# Copyright (C) 2018-2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#
"""HTTP hash equivalence server backed by an SQLite ``tasks_v2`` table.

The server exposes a single endpoint, ``<prefix>/v1/equivalent``:

* ``GET``  looks up the oldest row for a ``method``/``taskhash`` pair.
* ``POST`` reports a ``taskhash``/``outhash`` pair and returns the unihash
  that should be used for it, recording a new row when needed.
"""

from http.server import BaseHTTPRequestHandler, HTTPServer
import contextlib
import urllib.parse
import sqlite3
import json
import traceback
import logging
import socketserver
import queue
import threading
from datetime import datetime

logger = logging.getLogger('hashserv')


class HashEquivalenceServer(BaseHTTPRequestHandler):
    """Request handler for the ``/v1/equivalent`` endpoint.

    The handler reads three attributes that ``create_server()`` sets on a
    subclass: ``prefix`` (URL prefix), ``dbname`` (SQLite database path) and
    ``db`` (initially ``None``; replaced per instance by ``opendb()``).
    """

    def log_message(self, f, *args):
        """Route the base class's request logging to the module logger."""
        logger.debug(f, *args)

    def opendb(self):
        """Open ``self.dbname`` and store the connection on this instance."""
        self.db = sqlite3.connect(self.dbname)
        self.db.row_factory = sqlite3.Row
        self.db.execute("PRAGMA synchronous = OFF;")
        self.db.execute("PRAGMA journal_mode = MEMORY;")

    def _endpoint_url(self):
        """Return the parsed request URL, or ``None`` after replying 404.

        Only ``<prefix>/v1/equivalent`` is served; any other path gets a 404.
        """
        url = urllib.parse.urlparse(self.path)
        if url.path != self.prefix + '/v1/equivalent':
            self.send_error(404)
            return None
        return url

    def _send_json(self, payload):
        """Send *payload* as a ``200`` response with a UTF-8 JSON body."""
        self.send_response(200)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.end_headers()
        self.wfile.write(json.dumps(payload).encode('utf-8'))

    def do_GET(self):
        """Look up the unihash recorded for a ``method``/``taskhash`` pair.

        Both values come from the query string.  The response body is a JSON
        object with ``taskhash``, ``method`` and ``unihash``, or JSON ``null``
        when no row matches.  Any failure is logged and answered with a 400
        whose explanation is the traceback.
        """
        try:
            if not self.db:
                self.opendb()

            url = self._endpoint_url()
            if url is None:
                return

            # strict_parsing makes a malformed or empty query raise, which
            # ends up as a 400 below.
            query = urllib.parse.parse_qs(url.query, strict_parsing=True)
            method = query['method'][0]
            taskhash = query['taskhash'][0]

            d = None
            with contextlib.closing(self.db.cursor()) as cursor:
                cursor.execute(
                    'SELECT taskhash, method, unihash FROM tasks_v2 '
                    'WHERE method=:method AND taskhash=:taskhash '
                    'ORDER BY created ASC LIMIT 1',
                    {'method': method, 'taskhash': taskhash})

                row = cursor.fetchone()
                if row is not None:
                    logger.debug('Found equivalent task %s', row['taskhash'])
                    d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

            self._send_json(d)
        except Exception:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit are not converted into an HTTP error response.
            logger.exception('Error in GET')
            self.send_error(400, explain=traceback.format_exc())
            return

    def do_POST(self):
        """Report a task's output hash and return the unihash to use.

        The request body is a JSON object that must contain ``method``,
        ``outhash``, ``taskhash`` and ``unihash``; ``owner``, ``PN``, ``PV``,
        ``PR``, ``task`` and ``outhash_siginfo`` are stored when present.
        The response body is a JSON object with ``taskhash``, ``method`` and
        ``unihash``.  Any failure is logged and answered with a 400 whose
        explanation is the traceback.
        """
        try:
            if not self.db:
                self.opendb()

            if self._endpoint_url() is None:
                return

            length = int(self.headers['content-length'])
            data = json.loads(self.rfile.read(length).decode('utf-8'))

            with contextlib.closing(self.db.cursor()) as cursor:
                cursor.execute(
                    '''
                    -- Find tasks with a matching outhash (that is, tasks that
                    -- are equivalent)
                    SELECT taskhash, method, unihash FROM tasks_v2
                    WHERE method=:method AND outhash=:outhash

                    -- If there is an exact match on the taskhash, return it.
                    -- Otherwise return the oldest matching outhash of any
                    -- taskhash
                    ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                        created ASC

                    -- Only return one row
                    LIMIT 1
                    ''',
                    {k: data[k] for k in ('method', 'outhash', 'taskhash')})

                row = cursor.fetchone()

                # If no matching outhash was found, or one *was* found but it
                # wasn't an exact match on the taskhash, a new entry for this
                # taskhash should be added
                if row is None or row['taskhash'] != data['taskhash']:
                    # If a row matching the outhash was found, the unihash for
                    # the new taskhash should be the same as that one.
                    # Otherwise the caller provided unihash is used.
                    unihash = data['unihash']
                    if row is not None:
                        unihash = row['unihash']

                    insert_data = {
                        'method': data['method'],
                        'outhash': data['outhash'],
                        'taskhash': data['taskhash'],
                        'unihash': unihash,
                        # Bind the timestamp as text explicitly.  This is the
                        # same "YYYY-MM-DD HH:MM:SS.ffffff" form sqlite3's
                        # default datetime adapter produced, without relying
                        # on that adapter (deprecated since Python 3.12).
                        'created': datetime.now().isoformat(' '),
                    }

                    for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                        if k in data:
                            insert_data[k] = data[k]

                    # Column names come only from the fixed key lists above,
                    # never from the request, so formatting them into the
                    # statement is safe; values are bound as parameters.
                    columns = sorted(insert_data)
                    cursor.execute(
                        '''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
                            ', '.join(columns),
                            ', '.join(':' + k for k in columns)),
                        insert_data)

                    logger.info('Adding taskhash %s with unihash %s',
                                data['taskhash'], unihash)

                    self.db.commit()
                    d = {'taskhash': data['taskhash'],
                         'method': data['method'],
                         'unihash': unihash}
                else:
                    d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

            self._send_json(d)
        except Exception:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit are not converted into an HTTP error response.
            logger.exception('Error in POST')
            self.send_error(400, explain=traceback.format_exc())
            return


class ThreadedHTTPServer(HTTPServer):
    """HTTPServer that handles requests on one dedicated worker thread.

    ``serve_forever()`` accepts connections and queues them; a single worker
    thread dequeues and processes them one at a time.
    """

    quit = False

    # Both stay None until serve_forever() runs, so server_close() can tell
    # whether there is a worker to stop.
    requestqueue = None
    handlerthread = None

    def serve_forever(self, poll_interval=0.5):
        """Start the worker thread, then run the base class's accept loop.

        poll_interval is forwarded unchanged to ``HTTPServer.serve_forever``.
        """
        self.requestqueue = queue.Queue()
        self.handlerthread = threading.Thread(target=self.process_request_thread)
        self.handlerthread.daemon = False

        self.handlerthread.start()
        super().serve_forever(poll_interval)

    def process_request_thread(self):
        """Worker loop: handle queued requests until ``quit`` is set."""
        while not self.quit:
            # Blocks until an item arrives; (None, None) is the wake-up
            # sentinel queued by server_close().
            request, client_address = self.requestqueue.get()
            if request is None:
                continue

            try:
                self.finish_request(request, client_address)
            except Exception:
                self.handle_error(request, client_address)
            finally:
                self.shutdown_request(request)

    def process_request(self, request, client_address):
        """Queue an accepted request for the worker thread."""
        self.requestqueue.put((request, client_address))

    def server_close(self):
        """Close the listening socket and stop the worker thread, if any.

        The base class constructor calls this itself when binding fails, and
        callers may close a server that never served; in both cases there is
        no queue or thread yet, so only the socket is closed.
        """
        super().server_close()
        self.quit = True
        if self.handlerthread is not None:
            # Wake the worker so it notices ``quit`` and exits.
            self.requestqueue.put((None, None))
            self.handlerthread.join()
            self.handlerthread = None


def create_server(addr, dbname, prefix=''):
    """Create the database schema if needed and return a ready server.

    addr   -- address passed to ``ThreadedHTTPServer``
    dbname -- SQLite database path used for schema setup and by the handlers
    prefix -- string prepended to ``/v1/equivalent`` to form the served path

    Returns a ``ThreadedHTTPServer`` whose handler class carries ``prefix``,
    ``dbname`` and ``db = None`` (handlers open their own connection on first
    use).  The caller is expected to call ``serve_forever()`` on it.
    """
    class Handler(HashEquivalenceServer):
        pass

    Handler.prefix = prefix
    Handler.db = None
    Handler.dbname = dbname

    # This connection is only needed for schema setup; close it explicitly
    # instead of leaving it to the garbage collector.
    with contextlib.closing(sqlite3.connect(dbname)) as db:
        with contextlib.closing(db.cursor()) as cursor:
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS tasks_v2 (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    method TEXT NOT NULL,
                    outhash TEXT NOT NULL,
                    taskhash TEXT NOT NULL,
                    unihash TEXT NOT NULL,
                    created DATETIME,

                    -- Optional fields
                    owner TEXT,
                    PN TEXT,
                    PV TEXT,
                    PR TEXT,
                    task TEXT,
                    outhash_siginfo TEXT,

                    UNIQUE(method, outhash, taskhash)
                    )
                ''')
            cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup ON tasks_v2 (method, taskhash)')
            cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup ON tasks_v2 (method, outhash)')
        db.commit()

    logger.info('Starting server on %s', addr)
    return ThreadedHTTPServer(addr, Handler)