Diffstat (limited to 'lib/hashserv')
-rw-r--r-- | lib/hashserv/__init__.py | 177
-rw-r--r-- | lib/hashserv/client.py | 312
-rw-r--r-- | lib/hashserv/server.py | 946
-rw-r--r-- | lib/hashserv/sqlalchemy.py | 598
-rw-r--r-- | lib/hashserv/sqlite.py | 562
-rw-r--r-- | lib/hashserv/tests.py | 1270
6 files changed, 3416 insertions, 449 deletions
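
This change replaces the sqlite-only backend with pluggable database engines (sqlite or SQLAlchemy), and adds user authentication, websocket addresses, a streaming unihash existence check, and mark-and-sweep garbage collection. As a rough, hypothetical illustration only (not part of the commit), the extended synchronous client API could be exercised along these lines; the address, credentials, mark name, method name, and age values below are invented placeholders:

    import hashserv

    # Credentials are passed up front; the client authenticates on connect.
    client = hashserv.create_client("localhost:8686", "admin", "SOME-TOKEN")

    # New streaming existence check ("exists-stream" mode).
    print(client.unihash_exists("0123456789abcdef"))

    # Mark the unihash rows to keep, sweep away everything unmarked, then
    # prune outhash rows older than 30 days with no surviving unihash.
    client.gc_mark("gc-2024", {"method": "example.method"})
    client.gc_sweep("gc-2024")
    client.clean_unused(30 * 24 * 60 * 60)

    client.close()

Note that gc-mark, gc-sweep, and clean-unused are only registered when the server is not read-only and, by default, require the @db-admin permission.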
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py index 5f2e101e5..74367eb6b 100644 --- a/lib/hashserv/__init__.py +++ b/lib/hashserv/__init__.py @@ -5,129 +5,102 @@ import asyncio from contextlib import closing -import re -import sqlite3 import itertools import json +from collections import namedtuple +from urllib.parse import urlparse +from bb.asyncrpc.client import parse_address, ADDR_TYPE_UNIX, ADDR_TYPE_WS + +User = namedtuple("User", ("username", "permissions")) + +def create_server( + addr, + dbname, + *, + sync=True, + upstream=None, + read_only=False, + db_username=None, + db_password=None, + anon_perms=None, + admin_username=None, + admin_password=None, +): + def sqlite_engine(): + from .sqlite import DatabaseEngine + + return DatabaseEngine(dbname, sync) + + def sqlalchemy_engine(): + from .sqlalchemy import DatabaseEngine + + return DatabaseEngine(dbname, db_username, db_password) -UNIX_PREFIX = "unix://" - -ADDR_TYPE_UNIX = 0 -ADDR_TYPE_TCP = 1 - -# The Python async server defaults to a 64K receive buffer, so we hardcode our -# maximum chunk size. It would be better if the client and server reported to -# each other what the maximum chunk sizes were, but that will slow down the -# connection setup with a round trip delay so I'd rather not do that unless it -# is necessary -DEFAULT_MAX_CHUNK = 32 * 1024 - -TABLE_DEFINITION = ( - ("method", "TEXT NOT NULL"), - ("outhash", "TEXT NOT NULL"), - ("taskhash", "TEXT NOT NULL"), - ("unihash", "TEXT NOT NULL"), - ("created", "DATETIME"), - - # Optional fields - ("owner", "TEXT"), - ("PN", "TEXT"), - ("PV", "TEXT"), - ("PR", "TEXT"), - ("task", "TEXT"), - ("outhash_siginfo", "TEXT"), -) - -TABLE_COLUMNS = tuple(name for name, _ in TABLE_DEFINITION) - -def setup_database(database, sync=True): - db = sqlite3.connect(database) - db.row_factory = sqlite3.Row - - with closing(db.cursor()) as cursor: - cursor.execute(''' - CREATE TABLE IF NOT EXISTS tasks_v2 ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - %s - UNIQUE(method, outhash, taskhash) - ) - ''' % " ".join("%s %s," % (name, typ) for name, typ in TABLE_DEFINITION)) - cursor.execute('PRAGMA journal_mode = WAL') - cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF')) - - # Drop old indexes - cursor.execute('DROP INDEX IF EXISTS taskhash_lookup') - cursor.execute('DROP INDEX IF EXISTS outhash_lookup') - - # Create new indexes - cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)') - cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)') - - return db - - -def parse_address(addr): - if addr.startswith(UNIX_PREFIX): - return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],)) - else: - m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr) - if m is not None: - host = m.group('host') - port = m.group('port') - else: - host, port = addr.split(':') - - return (ADDR_TYPE_TCP, (host, int(port))) - + from . import server -def chunkify(msg, max_chunk): - if len(msg) < max_chunk - 1: - yield ''.join((msg, "\n")) + if "://" in dbname: + db_engine = sqlalchemy_engine() else: - yield ''.join((json.dumps({ - 'chunk-stream': None - }), "\n")) + db_engine = sqlite_engine() - args = [iter(msg)] * (max_chunk - 1) - for m in map(''.join, itertools.zip_longest(*args, fillvalue='')): - yield ''.join(itertools.chain(m, "\n")) - yield "\n" + if anon_perms is None: + anon_perms = server.DEFAULT_ANON_PERMS - -def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False): - from . 
import server - db = setup_database(dbname, sync=sync) - s = server.Server(db, upstream=upstream, read_only=read_only) + s = server.Server( + db_engine, + upstream=upstream, + read_only=read_only, + anon_perms=anon_perms, + admin_username=admin_username, + admin_password=admin_password, + ) (typ, a) = parse_address(addr) if typ == ADDR_TYPE_UNIX: s.start_unix_server(*a) + elif typ == ADDR_TYPE_WS: + url = urlparse(a[0]) + s.start_websocket_server(url.hostname, url.port) else: s.start_tcp_server(*a) return s -def create_client(addr): +def create_client(addr, username=None, password=None): from . import client - c = client.Client() - (typ, a) = parse_address(addr) - if typ == ADDR_TYPE_UNIX: - c.connect_unix(*a) - else: - c.connect_tcp(*a) + c = client.Client(username, password) + + try: + (typ, a) = parse_address(addr) + if typ == ADDR_TYPE_UNIX: + c.connect_unix(*a) + elif typ == ADDR_TYPE_WS: + c.connect_websocket(*a) + else: + c.connect_tcp(*a) + return c + except Exception as e: + c.close() + raise e - return c -async def create_async_client(addr): +async def create_async_client(addr, username=None, password=None): from . import client - c = client.AsyncClient() - (typ, a) = parse_address(addr) - if typ == ADDR_TYPE_UNIX: - await c.connect_unix(*a) - else: - await c.connect_tcp(*a) + c = client.AsyncClient(username, password) + + try: + (typ, a) = parse_address(addr) + if typ == ADDR_TYPE_UNIX: + await c.connect_unix(*a) + elif typ == ADDR_TYPE_WS: + await c.connect_websocket(*a) + else: + await c.connect_tcp(*a) - return c + return c + except Exception as e: + await c.close() + raise e diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py index 531170967..0b254bedd 100644 --- a/lib/hashserv/client.py +++ b/lib/hashserv/client.py @@ -3,12 +3,10 @@ # SPDX-License-Identifier: GPL-2.0-only # -import asyncio -import json import logging import socket -import os import bb.asyncrpc +import json from . 
import create_async_client @@ -18,107 +16,331 @@ logger = logging.getLogger("hashserv.client") class AsyncClient(bb.asyncrpc.AsyncClient): MODE_NORMAL = 0 MODE_GET_STREAM = 1 + MODE_EXIST_STREAM = 2 - def __init__(self): - super().__init__('OEHASHEQUIV', '1.1', logger) + def __init__(self, username=None, password=None): + super().__init__("OEHASHEQUIV", "1.1", logger) self.mode = self.MODE_NORMAL + self.username = username + self.password = password + self.saved_become_user = None async def setup_connection(self): await super().setup_connection() - cur_mode = self.mode self.mode = self.MODE_NORMAL - await self._set_mode(cur_mode) + if self.username: + # Save off become user temporarily because auth() resets it + become = self.saved_become_user + await self.auth(self.username, self.password) - async def send_stream(self, msg): + if become: + await self.become_user(become) + + async def send_stream(self, mode, msg): async def proc(): - self.writer.write(("%s\n" % msg).encode("utf-8")) - await self.writer.drain() - l = await self.reader.readline() - if not l: - raise ConnectionError("Connection closed") - return l.decode("utf-8").rstrip() + await self._set_mode(mode) + await self.socket.send(msg) + return await self.socket.recv() return await self._send_wrapper(proc) + async def invoke(self, *args, **kwargs): + # It's OK if connection errors cause a failure here, because the mode + # is also reset to normal on a new connection + await self._set_mode(self.MODE_NORMAL) + return await super().invoke(*args, **kwargs) + async def _set_mode(self, new_mode): - if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM: - r = await self.send_stream("END") + async def stream_to_normal(): + await self.socket.send("END") + return await self.socket.recv() + + async def normal_to_stream(command): + r = await self.invoke({command: None}) if r != "ok": - raise ConnectionError("Bad response from server %r" % r) - elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL: - r = await self.send_message({"get-stream": None}) + raise ConnectionError( + f"Unable to transition to stream mode: Bad response from server {r!r}" + ) + + self.logger.debug("Mode is now %s", command) + + if new_mode == self.mode: + return + + self.logger.debug("Transitioning mode %s -> %s", self.mode, new_mode) + + # Always transition to normal mode before switching to any other mode + if self.mode != self.MODE_NORMAL: + r = await self._send_wrapper(stream_to_normal) if r != "ok": - raise ConnectionError("Bad response from server %r" % r) - elif new_mode != self.mode: - raise Exception( - "Undefined mode transition %r -> %r" % (self.mode, new_mode) - ) + self.check_invoke_error(r) + raise ConnectionError( + f"Unable to transition to normal mode: Bad response from server {r!r}" + ) + self.logger.debug("Mode is now normal") + + if new_mode == self.MODE_GET_STREAM: + await normal_to_stream("get-stream") + elif new_mode == self.MODE_EXIST_STREAM: + await normal_to_stream("exists-stream") + elif new_mode != self.MODE_NORMAL: + raise Exception("Undefined mode transition {self.mode!r} -> {new_mode!r}") self.mode = new_mode async def get_unihash(self, method, taskhash): - await self._set_mode(self.MODE_GET_STREAM) - r = await self.send_stream("%s %s" % (method, taskhash)) + r = await self.send_stream(self.MODE_GET_STREAM, "%s %s" % (method, taskhash)) if not r: return None return r async def report_unihash(self, taskhash, method, outhash, unihash, extra={}): - await self._set_mode(self.MODE_NORMAL) m = extra.copy() 
m["taskhash"] = taskhash m["method"] = method m["outhash"] = outhash m["unihash"] = unihash - return await self.send_message({"report": m}) + return await self.invoke({"report": m}) async def report_unihash_equiv(self, taskhash, method, unihash, extra={}): - await self._set_mode(self.MODE_NORMAL) m = extra.copy() m["taskhash"] = taskhash m["method"] = method m["unihash"] = unihash - return await self.send_message({"report-equiv": m}) + return await self.invoke({"report-equiv": m}) async def get_taskhash(self, method, taskhash, all_properties=False): - await self._set_mode(self.MODE_NORMAL) - return await self.send_message( + return await self.invoke( {"get": {"taskhash": taskhash, "method": method, "all": all_properties}} ) - async def get_outhash(self, method, outhash, taskhash): - await self._set_mode(self.MODE_NORMAL) - return await self.send_message( - {"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method}} + async def unihash_exists(self, unihash): + r = await self.send_stream(self.MODE_EXIST_STREAM, unihash) + return r == "true" + + async def get_outhash(self, method, outhash, taskhash, with_unihash=True): + return await self.invoke( + { + "get-outhash": { + "outhash": outhash, + "taskhash": taskhash, + "method": method, + "with_unihash": with_unihash, + } + } ) async def get_stats(self): - await self._set_mode(self.MODE_NORMAL) - return await self.send_message({"get-stats": None}) + return await self.invoke({"get-stats": None}) async def reset_stats(self): - await self._set_mode(self.MODE_NORMAL) - return await self.send_message({"reset-stats": None}) + return await self.invoke({"reset-stats": None}) async def backfill_wait(self): - await self._set_mode(self.MODE_NORMAL) - return (await self.send_message({"backfill-wait": None}))["tasks"] + return (await self.invoke({"backfill-wait": None}))["tasks"] + + async def remove(self, where): + return await self.invoke({"remove": {"where": where}}) + + async def clean_unused(self, max_age): + return await self.invoke({"clean-unused": {"max_age_seconds": max_age}}) + + async def auth(self, username, token): + result = await self.invoke({"auth": {"username": username, "token": token}}) + self.username = username + self.password = token + self.saved_become_user = None + return result + + async def refresh_token(self, username=None): + m = {} + if username: + m["username"] = username + result = await self.invoke({"refresh-token": m}) + if ( + self.username + and not self.saved_become_user + and result["username"] == self.username + ): + self.password = result["token"] + return result + + async def set_user_perms(self, username, permissions): + return await self.invoke( + {"set-user-perms": {"username": username, "permissions": permissions}} + ) + + async def get_user(self, username=None): + m = {} + if username: + m["username"] = username + return await self.invoke({"get-user": m}) + + async def get_all_users(self): + return (await self.invoke({"get-all-users": {}}))["users"] + + async def new_user(self, username, permissions): + return await self.invoke( + {"new-user": {"username": username, "permissions": permissions}} + ) + + async def delete_user(self, username): + return await self.invoke({"delete-user": {"username": username}}) + + async def become_user(self, username): + result = await self.invoke({"become-user": {"username": username}}) + if username == self.username: + self.saved_become_user = None + else: + self.saved_become_user = username + return result + + async def get_db_usage(self): + return (await 
self.invoke({"get-db-usage": {}}))["usage"] + + async def get_db_query_columns(self): + return (await self.invoke({"get-db-query-columns": {}}))["columns"] + + async def gc_status(self): + return await self.invoke({"gc-status": {}}) + + async def gc_mark(self, mark, where): + """ + Starts a new garbage collection operation identified by "mark". If + garbage collection is already in progress with "mark", the collection + is continued. + + All unihash entries that match the "where" clause are marked to be + kept. In addition, any new entries added to the database after this + command will be automatically marked with "mark" + """ + return await self.invoke({"gc-mark": {"mark": mark, "where": where}}) + + async def gc_sweep(self, mark): + """ + Finishes garbage collection for "mark". All unihash entries that have + not been marked will be deleted. + + It is recommended to clean unused outhash entries after running this to + cleanup any dangling outhashes + """ + return await self.invoke({"gc-sweep": {"mark": mark}}) class Client(bb.asyncrpc.Client): - def __init__(self): + def __init__(self, username=None, password=None): + self.username = username + self.password = password + super().__init__() self._add_methods( "connect_tcp", - "close", + "connect_websocket", "get_unihash", "report_unihash", "report_unihash_equiv", "get_taskhash", + "unihash_exists", + "get_outhash", "get_stats", "reset_stats", "backfill_wait", + "remove", + "clean_unused", + "auth", + "refresh_token", + "set_user_perms", + "get_user", + "get_all_users", + "new_user", + "delete_user", + "become_user", + "get_db_usage", + "get_db_query_columns", + "gc_status", + "gc_mark", + "gc_sweep", ) def _get_async_client(self): - return AsyncClient() + return AsyncClient(self.username, self.password) + + +class ClientPool(bb.asyncrpc.ClientPool): + def __init__( + self, + address, + max_clients, + *, + username=None, + password=None, + become=None, + ): + super().__init__(max_clients) + self.address = address + self.username = username + self.password = password + self.become = become + + async def _new_client(self): + client = await create_async_client( + self.address, + username=self.username, + password=self.password, + ) + if self.become: + await client.become_user(self.become) + return client + + def _run_key_tasks(self, queries, call): + results = {key: None for key in queries.keys()} + + def make_task(key, args): + async def task(client): + nonlocal results + unihash = await call(client, args) + results[key] = unihash + + return task + + def gen_tasks(): + for key, args in queries.items(): + yield make_task(key, args) + + self.run_tasks(gen_tasks()) + return results + + def get_unihashes(self, queries): + """ + Query multiple unihashes in parallel. + + The queries argument is a dictionary with arbitrary key. The values + must be a tuple of (method, taskhash). + + Returns a dictionary with a corresponding key for each input key, and + the value is the queried unihash (which might be none if the query + failed) + """ + + async def call(client, args): + method, taskhash = args + return await client.get_unihash(method, taskhash) + + return self._run_key_tasks(queries, call) + + def unihashes_exist(self, queries): + """ + Query multiple unihash existence checks in parallel. + + The queries argument is a dictionary with arbitrary key. The values + must be a unihash. 
+ + Returns a dictionary with a corresponding key for each input key, and + the value is True or False if the unihash is known by the server (or + None if there was a failure) + """ + + async def call(client, unihash): + return await client.unihash_exists(unihash) + + return self._run_key_tasks(queries, call) diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py index c941c0e9d..68f64f983 100644 --- a/lib/hashserv/server.py +++ b/lib/hashserv/server.py @@ -3,22 +3,51 @@ # SPDX-License-Identifier: GPL-2.0-only # -from contextlib import closing, contextmanager -from datetime import datetime +from datetime import datetime, timedelta import asyncio -import json import logging import math -import os -import signal -import socket -import sys import time -from . import create_async_client, TABLE_COLUMNS +import os +import base64 +import hashlib +from . import create_async_client import bb.asyncrpc +logger = logging.getLogger("hashserv.server") + + +# This permission only exists to match nothing +NONE_PERM = "@none" + +READ_PERM = "@read" +REPORT_PERM = "@report" +DB_ADMIN_PERM = "@db-admin" +USER_ADMIN_PERM = "@user-admin" +ALL_PERM = "@all" + +ALL_PERMISSIONS = { + READ_PERM, + REPORT_PERM, + DB_ADMIN_PERM, + USER_ADMIN_PERM, + ALL_PERM, +} + +DEFAULT_ANON_PERMS = ( + READ_PERM, + REPORT_PERM, + DB_ADMIN_PERM, +) + +TOKEN_ALGORITHM = "sha256" + +# 48 bytes of random data will result in 64 characters when base64 +# encoded. This number also ensures that the base64 encoding won't have any +# trailing '=' characters. +TOKEN_SIZE = 48 -logger = logging.getLogger('hashserv.server') +SALT_SIZE = 8 class Measurement(object): @@ -108,360 +137,745 @@ class Stats(object): return math.sqrt(self.s / (self.num - 1)) def todict(self): - return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} + return { + k: getattr(self, k) + for k in ("num", "total_time", "max_time", "average", "stdev") + } + + +token_refresh_semaphore = asyncio.Lock() + + +async def new_token(): + # Prevent malicious users from using this API to deduce the entropy + # pool on the server and thus be able to guess a token. *All* token + # refresh requests lock the same global semaphore and then sleep for a + # short time. 
The effectively rate limits the total number of requests + # than can be made across all clients to 10/second, which should be enough + # since you have to be an authenticated users to make the request in the + # first place + async with token_refresh_semaphore: + await asyncio.sleep(0.1) + raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK) + + return base64.b64encode(raw, b"._").decode("utf-8") + + +def new_salt(): + return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex() -def insert_task(cursor, data, ignore=False): - keys = sorted(data.keys()) - query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % ( - " OR IGNORE" if ignore else "", - ', '.join(keys), - ', '.join(':' + k for k in keys)) - cursor.execute(query, data) +def hash_token(algo, salt, token): + h = hashlib.new(algo) + h.update(salt.encode("utf-8")) + h.update(token.encode("utf-8")) + return ":".join([algo, salt, h.hexdigest()]) -async def copy_from_upstream(client, db, method, taskhash): - d = await client.get_taskhash(method, taskhash, True) - if d is not None: - # Filter out unknown columns - d = {k: v for k, v in d.items() if k in TABLE_COLUMNS} - keys = sorted(d.keys()) - with closing(db.cursor()) as cursor: - insert_task(cursor, d) - db.commit() +def permissions(*permissions, allow_anon=True, allow_self_service=False): + """ + Function decorator that can be used to decorate an RPC function call and + check that the current users permissions match the require permissions. - return d + If allow_anon is True, the user will also be allowed to make the RPC call + if the anonymous user permissions match the permissions. -async def copy_outhash_from_upstream(client, db, method, outhash, taskhash): - d = await client.get_outhash(method, outhash, taskhash) - if d is not None: - # Filter out unknown columns - d = {k: v for k, v in d.items() if k in TABLE_COLUMNS} - keys = sorted(d.keys()) + If allow_self_service is True, and the "username" property in the request + is the currently logged in user, or not specified, the user will also be + allowed to make the request. This allows users to access normal privileged + API, as long as they are only modifying their own user properties (e.g. + users can be allowed to reset their own token without @user-admin + permissions, but not the token for any other user. + """ - with closing(db.cursor()) as cursor: - insert_task(cursor, d) - db.commit() + def wrapper(func): + async def wrap(self, request): + if allow_self_service and self.user is not None: + username = request.get("username", self.user.username) + if username == self.user.username: + request["username"] = self.user.username + return await func(self, request) + + if not self.user_has_permissions(*permissions, allow_anon=allow_anon): + if not self.user: + username = "Anonymous user" + user_perms = self.server.anon_perms + else: + username = self.user.username + user_perms = self.user.permissions + + self.logger.info( + "User %s with permissions %r denied from calling %s. 
Missing permissions(s) %r", + username, + ", ".join(user_perms), + func.__name__, + ", ".join(permissions), + ) + raise bb.asyncrpc.InvokeError( + f"{username} is not allowed to access permissions(s) {', '.join(permissions)}" + ) + + return await func(self, request) + + return wrap + + return wrapper - return d class ServerClient(bb.asyncrpc.AsyncServerConnection): - FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' - ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' - OUTHASH_QUERY = ''' - -- Find tasks with a matching outhash (that is, tasks that - -- are equivalent) - SELECT * FROM tasks_v2 WHERE method=:method AND outhash=:outhash - - -- If there is an exact match on the taskhash, return it. - -- Otherwise return the oldest matching outhash of any - -- taskhash - ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END, - created ASC - - -- Only return one row - LIMIT 1 - ''' - - def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only): - super().__init__(reader, writer, 'OEHASHEQUIV', logger) - self.db = db - self.request_stats = request_stats + def __init__(self, socket, server): + super().__init__(socket, "OEHASHEQUIV", server.logger) + self.server = server self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK - self.backfill_queue = backfill_queue - self.upstream = upstream + self.user = None + + self.handlers.update( + { + "get": self.handle_get, + "get-outhash": self.handle_get_outhash, + "get-stream": self.handle_get_stream, + "exists-stream": self.handle_exists_stream, + "get-stats": self.handle_get_stats, + "get-db-usage": self.handle_get_db_usage, + "get-db-query-columns": self.handle_get_db_query_columns, + # Not always read-only, but internally checks if the server is + # read-only + "report": self.handle_report, + "auth": self.handle_auth, + "get-user": self.handle_get_user, + "get-all-users": self.handle_get_all_users, + "become-user": self.handle_become_user, + } + ) + + if not self.server.read_only: + self.handlers.update( + { + "report-equiv": self.handle_equivreport, + "reset-stats": self.handle_reset_stats, + "backfill-wait": self.handle_backfill_wait, + "remove": self.handle_remove, + "gc-mark": self.handle_gc_mark, + "gc-sweep": self.handle_gc_sweep, + "gc-status": self.handle_gc_status, + "clean-unused": self.handle_clean_unused, + "refresh-token": self.handle_refresh_token, + "set-user-perms": self.handle_set_perms, + "new-user": self.handle_new_user, + "delete-user": self.handle_delete_user, + } + ) - self.handlers.update({ - 'get': self.handle_get, - 'get-outhash': self.handle_get_outhash, - 'get-stream': self.handle_get_stream, - 'get-stats': self.handle_get_stats, - }) - - if not read_only: - self.handlers.update({ - 'report': self.handle_report, - 'report-equiv': self.handle_equivreport, - 'reset-stats': self.handle_reset_stats, - 'backfill-wait': self.handle_backfill_wait, - }) + def raise_no_user_error(self, username): + raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists") + + def user_has_permissions(self, *permissions, allow_anon=True): + permissions = set(permissions) + if allow_anon: + if ALL_PERM in self.server.anon_perms: + return True + + if not permissions - self.server.anon_perms: + return True + + if self.user is None: + return False + + if ALL_PERM in self.user.permissions: + return True + + if not permissions - self.user.permissions: + return True + + return 
False def validate_proto_version(self): - return (self.proto_version > (1, 0) and self.proto_version <= (1, 1)) + return self.proto_version > (1, 0) and self.proto_version <= (1, 1) async def process_requests(self): - if self.upstream is not None: - self.upstream_client = await create_async_client(self.upstream) - else: - self.upstream_client = None - - await super().process_requests() + async with self.server.db_engine.connect(self.logger) as db: + self.db = db + if self.server.upstream is not None: + self.upstream_client = await create_async_client(self.server.upstream) + else: + self.upstream_client = None - if self.upstream_client is not None: - await self.upstream_client.close() + try: + await super().process_requests() + finally: + if self.upstream_client is not None: + await self.upstream_client.close() async def dispatch_message(self, msg): for k in self.handlers.keys(): if k in msg: - logger.debug('Handling %s' % k) - if 'stream' in k: - await self.handlers[k](msg[k]) + self.logger.debug("Handling %s" % k) + if "stream" in k: + return await self.handlers[k](msg[k]) else: - with self.request_stats.start_sample() as self.request_sample, \ - self.request_sample.measure(): - await self.handlers[k](msg[k]) - return + with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure(): + return await self.handlers[k](msg[k]) raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg) + @permissions(READ_PERM) async def handle_get(self, request): - method = request['method'] - taskhash = request['taskhash'] - - if request.get('all', False): - row = self.query_equivalent(method, taskhash, self.ALL_QUERY) + method = request["method"] + taskhash = request["taskhash"] + fetch_all = request.get("all", False) + + return await self.get_unihash(method, taskhash, fetch_all) + + async def get_unihash(self, method, taskhash, fetch_all=False): + d = None + + if fetch_all: + row = await self.db.get_unihash_by_taskhash_full(method, taskhash) + if row is not None: + d = {k: row[k] for k in row.keys()} + elif self.upstream_client is not None: + d = await self.upstream_client.get_taskhash(method, taskhash, True) + await self.update_unified(d) else: - row = self.query_equivalent(method, taskhash, self.FAST_QUERY) + row = await self.db.get_equivalent(method, taskhash) - if row is not None: - logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) - d = {k: row[k] for k in row.keys()} - elif self.upstream_client is not None: - d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash) - else: - d = None + if row is not None: + d = {k: row[k] for k in row.keys()} + elif self.upstream_client is not None: + d = await self.upstream_client.get_taskhash(method, taskhash) + await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"]) - self.write_message(d) + return d + @permissions(READ_PERM) async def handle_get_outhash(self, request): - with closing(self.db.cursor()) as cursor: - cursor.execute(self.OUTHASH_QUERY, - {k: request[k] for k in ('method', 'outhash', 'taskhash')}) + method = request["method"] + outhash = request["outhash"] + taskhash = request["taskhash"] + with_unihash = request.get("with_unihash", True) - row = cursor.fetchone() + return await self.get_outhash(method, outhash, taskhash, with_unihash) + + async def get_outhash(self, method, outhash, taskhash, with_unihash=True): + d = None + if with_unihash: + row = await self.db.get_unihash_by_outhash(method, outhash) + else: + row = await 
self.db.get_outhash(method, outhash) if row is not None: - logger.debug('Found equivalent outhash %s -> %s', (row['outhash'], row['unihash'])) d = {k: row[k] for k in row.keys()} - else: - d = None + elif self.upstream_client is not None: + d = await self.upstream_client.get_outhash(method, outhash, taskhash) + await self.update_unified(d) - self.write_message(d) + return d - async def handle_get_stream(self, request): - self.write_message('ok') + async def update_unified(self, data): + if data is None: + return + + await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"]) + await self.db.insert_outhash(data) + + async def _stream_handler(self, handler): + await self.socket.send_message("ok") while True: upstream = None - l = await self.reader.readline() + l = await self.socket.recv() if not l: - return + break try: # This inner loop is very sensitive and must be as fast as # possible (which is why the request sample is handled manually # instead of using 'with', and also why logging statements are # commented out. - self.request_sample = self.request_stats.start_sample() + self.request_sample = self.server.request_stats.start_sample() request_measure = self.request_sample.measure() request_measure.start() - l = l.decode('utf-8').rstrip() - if l == 'END': - self.writer.write('ok\n'.encode('utf-8')) - return - - (method, taskhash) = l.split() - #logger.debug('Looking up %s %s' % (method, taskhash)) - row = self.query_equivalent(method, taskhash, self.FAST_QUERY) - if row is not None: - msg = ('%s\n' % row['unihash']).encode('utf-8') - #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) - elif self.upstream_client is not None: - upstream = await self.upstream_client.get_unihash(method, taskhash) - if upstream: - msg = ("%s\n" % upstream).encode("utf-8") - else: - msg = "\n".encode("utf-8") - else: - msg = '\n'.encode('utf-8') + if l == "END": + break - self.writer.write(msg) + msg = await handler(l) + await self.socket.send(msg) finally: request_measure.end() self.request_sample.end() - await self.writer.drain() + await self.socket.send("ok") + return self.NO_RESPONSE + + @permissions(READ_PERM) + async def handle_get_stream(self, request): + async def handler(l): + (method, taskhash) = l.split() + # self.logger.debug('Looking up %s %s' % (method, taskhash)) + row = await self.db.get_equivalent(method, taskhash) - # Post to the backfill queue after writing the result to minimize - # the turn around time on a request - if upstream is not None: - await self.backfill_queue.put((method, taskhash)) + if row is not None: + # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) + return row["unihash"] - async def handle_report(self, data): - with closing(self.db.cursor()) as cursor: - cursor.execute(self.OUTHASH_QUERY, - {k: data[k] for k in ('method', 'outhash', 'taskhash')}) - - row = cursor.fetchone() - - if row is None and self.upstream_client: - # Try upstream - row = await copy_outhash_from_upstream(self.upstream_client, - self.db, - data['method'], - data['outhash'], - data['taskhash']) - - # If no matching outhash was found, or one *was* found but it - # wasn't an exact match on the taskhash, a new entry for this - # taskhash should be added - if row is None or row['taskhash'] != data['taskhash']: - # If a row matching the outhash was found, the unihash for - # the new taskhash should be the same as that one. - # Otherwise the caller provided unihash is used. 
- unihash = data['unihash'] - if row is not None: - unihash = row['unihash'] - - insert_data = { - 'method': data['method'], - 'outhash': data['outhash'], - 'taskhash': data['taskhash'], - 'unihash': unihash, - 'created': datetime.now() - } + if self.upstream_client is not None: + upstream = await self.upstream_client.get_unihash(method, taskhash) + if upstream: + await self.server.backfill_queue.put((method, taskhash)) + return upstream - for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): - if k in data: - insert_data[k] = data[k] + return "" - insert_task(cursor, insert_data) - self.db.commit() + return await self._stream_handler(handler) - logger.info('Adding taskhash %s with unihash %s', - data['taskhash'], unihash) + @permissions(READ_PERM) + async def handle_exists_stream(self, request): + async def handler(l): + if await self.db.unihash_exists(l): + return "true" - d = { - 'taskhash': data['taskhash'], - 'method': data['method'], - 'unihash': unihash - } - else: - d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} + if self.upstream_client is not None: + if await self.upstream_client.unihash_exists(l): + return "true" - self.write_message(d) + return "false" - async def handle_equivreport(self, data): - with closing(self.db.cursor()) as cursor: - insert_data = { - 'method': data['method'], - 'outhash': "", - 'taskhash': data['taskhash'], - 'unihash': data['unihash'], - 'created': datetime.now() - } + return await self._stream_handler(handler) - for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): - if k in data: - insert_data[k] = data[k] + async def report_readonly(self, data): + method = data["method"] + outhash = data["outhash"] + taskhash = data["taskhash"] - insert_task(cursor, insert_data, ignore=True) - self.db.commit() + info = await self.get_outhash(method, outhash, taskhash) + if info: + unihash = info["unihash"] + else: + unihash = data["unihash"] - # Fetch the unihash that will be reported for the taskhash. If the - # unihash matches, it means this row was inserted (or the mapping - # was already valid) - row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY) + return { + "taskhash": taskhash, + "method": method, + "unihash": unihash, + } + + # Since this can be called either read only or to report, the check to + # report is made inside the function + @permissions(READ_PERM) + async def handle_report(self, data): + if self.server.read_only or not self.user_has_permissions(REPORT_PERM): + return await self.report_readonly(data) + + outhash_data = { + "method": data["method"], + "outhash": data["outhash"], + "taskhash": data["taskhash"], + "created": datetime.now(), + } + + for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"): + if k in data: + outhash_data[k] = data[k] + + if self.user: + outhash_data["owner"] = self.user.username + + # Insert the new entry, unless it already exists + if await self.db.insert_outhash(outhash_data): + # If this row is new, check if it is equivalent to another + # output hash + row = await self.db.get_equivalent_for_outhash( + data["method"], data["outhash"], data["taskhash"] + ) + + if row is not None: + # A matching output hash was found. Set our taskhash to the + # same unihash since they are equivalent + unihash = row["unihash"] + else: + # No matching output hash was found. This is probably the + # first outhash to be added. 
+ unihash = data["unihash"] + + # Query upstream to see if it has a unihash we can use + if self.upstream_client is not None: + upstream_data = await self.upstream_client.get_outhash( + data["method"], data["outhash"], data["taskhash"] + ) + if upstream_data is not None: + unihash = upstream_data["unihash"] + + await self.db.insert_unihash(data["method"], data["taskhash"], unihash) + + unihash_data = await self.get_unihash(data["method"], data["taskhash"]) + if unihash_data is not None: + unihash = unihash_data["unihash"] + else: + unihash = data["unihash"] + + return { + "taskhash": data["taskhash"], + "method": data["method"], + "unihash": unihash, + } - if row['unihash'] == data['unihash']: - logger.info('Adding taskhash equivalence for %s with unihash %s', - data['taskhash'], row['unihash']) + @permissions(READ_PERM, REPORT_PERM) + async def handle_equivreport(self, data): + await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"]) - d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} + # Fetch the unihash that will be reported for the taskhash. If the + # unihash matches, it means this row was inserted (or the mapping + # was already valid) + row = await self.db.get_equivalent(data["method"], data["taskhash"]) - self.write_message(d) + if row["unihash"] == data["unihash"]: + self.logger.info( + "Adding taskhash equivalence for %s with unihash %s", + data["taskhash"], + row["unihash"], + ) + return {k: row[k] for k in ("taskhash", "method", "unihash")} + @permissions(READ_PERM) async def handle_get_stats(self, request): - d = { - 'requests': self.request_stats.todict(), + return { + "requests": self.server.request_stats.todict(), } - self.write_message(d) - + @permissions(DB_ADMIN_PERM) async def handle_reset_stats(self, request): d = { - 'requests': self.request_stats.todict(), + "requests": self.server.request_stats.todict(), } - self.request_stats.reset() - self.write_message(d) + self.server.request_stats.reset() + return d + @permissions(READ_PERM) async def handle_backfill_wait(self, request): d = { - 'tasks': self.backfill_queue.qsize(), + "tasks": self.server.backfill_queue.qsize(), + } + await self.server.backfill_queue.join() + return d + + @permissions(DB_ADMIN_PERM) + async def handle_remove(self, request): + condition = request["where"] + if not isinstance(condition, dict): + raise TypeError("Bad condition type %s" % type(condition)) + + return {"count": await self.db.remove(condition)} + + @permissions(DB_ADMIN_PERM) + async def handle_gc_mark(self, request): + condition = request["where"] + mark = request["mark"] + + if not isinstance(condition, dict): + raise TypeError("Bad condition type %s" % type(condition)) + + if not isinstance(mark, str): + raise TypeError("Bad mark type %s" % type(mark)) + + return {"count": await self.db.gc_mark(mark, condition)} + + @permissions(DB_ADMIN_PERM) + async def handle_gc_sweep(self, request): + mark = request["mark"] + + if not isinstance(mark, str): + raise TypeError("Bad mark type %s" % type(mark)) + + current_mark = await self.db.get_current_gc_mark() + + if not current_mark or mark != current_mark: + raise bb.asyncrpc.InvokeError( + f"'{mark}' is not the current mark. 
Refusing to sweep" + ) + + count = await self.db.gc_sweep() + + return {"count": count} + + @permissions(DB_ADMIN_PERM) + async def handle_gc_status(self, request): + (keep_rows, remove_rows, current_mark) = await self.db.gc_status() + return { + "keep": keep_rows, + "remove": remove_rows, + "mark": current_mark, } - await self.backfill_queue.join() - self.write_message(d) - def query_equivalent(self, method, taskhash, query): - # This is part of the inner loop and must be as fast as possible + @permissions(DB_ADMIN_PERM) + async def handle_clean_unused(self, request): + max_age = request["max_age_seconds"] + oldest = datetime.now() - timedelta(seconds=-max_age) + return {"count": await self.db.clean_unused(oldest)} + + @permissions(DB_ADMIN_PERM) + async def handle_get_db_usage(self, request): + return {"usage": await self.db.get_usage()} + + @permissions(DB_ADMIN_PERM) + async def handle_get_db_query_columns(self, request): + return {"columns": await self.db.get_query_columns()} + + # The authentication API is always allowed + async def handle_auth(self, request): + username = str(request["username"]) + token = str(request["token"]) + + async def fail_auth(): + nonlocal username + # Rate limit bad login attempts + await asyncio.sleep(1) + raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}") + + user, db_token = await self.db.lookup_user_token(username) + + if not user or not db_token: + await fail_auth() + + try: + algo, salt, _ = db_token.split(":") + except ValueError: + await fail_auth() + + if hash_token(algo, salt, token) != db_token: + await fail_auth() + + self.user = user + + self.logger.info("Authenticated as %s", username) + + return { + "result": True, + "username": self.user.username, + "permissions": sorted(list(self.user.permissions)), + } + + @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) + async def handle_refresh_token(self, request): + username = str(request["username"]) + + token = await new_token() + + updated = await self.db.set_user_token( + username, + hash_token(TOKEN_ALGORITHM, new_salt(), token), + ) + if not updated: + self.raise_no_user_error(username) + + return {"username": username, "token": token} + + def get_perm_arg(self, arg): + if not isinstance(arg, list): + raise bb.asyncrpc.InvokeError("Unexpected type for permissions") + + arg = set(arg) try: - cursor = self.db.cursor() - cursor.execute(query, {'method': method, 'taskhash': taskhash}) - return cursor.fetchone() - except: - cursor.close() + arg.remove(NONE_PERM) + except KeyError: + pass + + unknown_perms = arg - ALL_PERMISSIONS + if unknown_perms: + raise bb.asyncrpc.InvokeError( + "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms))) + ) + + return sorted(list(arg)) + + def return_perms(self, permissions): + if ALL_PERM in permissions: + return sorted(list(ALL_PERMISSIONS)) + return sorted(list(permissions)) + + @permissions(USER_ADMIN_PERM, allow_anon=False) + async def handle_set_perms(self, request): + username = str(request["username"]) + permissions = self.get_perm_arg(request["permissions"]) + + if not await self.db.set_user_perms(username, permissions): + self.raise_no_user_error(username) + + return { + "username": username, + "permissions": self.return_perms(permissions), + } + + @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) + async def handle_get_user(self, request): + username = str(request["username"]) + + user = await self.db.lookup_user(username) + if user is None: + return None + + return { + 
"username": user.username, + "permissions": self.return_perms(user.permissions), + } + + @permissions(USER_ADMIN_PERM, allow_anon=False) + async def handle_get_all_users(self, request): + users = await self.db.get_all_users() + return { + "users": [ + { + "username": u.username, + "permissions": self.return_perms(u.permissions), + } + for u in users + ] + } + + @permissions(USER_ADMIN_PERM, allow_anon=False) + async def handle_new_user(self, request): + username = str(request["username"]) + permissions = self.get_perm_arg(request["permissions"]) + + token = await new_token() + + inserted = await self.db.new_user( + username, + permissions, + hash_token(TOKEN_ALGORITHM, new_salt(), token), + ) + if not inserted: + raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'") + + return { + "username": username, + "permissions": self.return_perms(permissions), + "token": token, + } + + @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) + async def handle_delete_user(self, request): + username = str(request["username"]) + + if not await self.db.delete_user(username): + self.raise_no_user_error(username) + + return {"username": username} + + @permissions(USER_ADMIN_PERM, allow_anon=False) + async def handle_become_user(self, request): + username = str(request["username"]) + + user = await self.db.lookup_user(username) + if user is None: + raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist") + + self.user = user + + self.logger.info("Became user %s", username) + + return { + "username": self.user.username, + "permissions": self.return_perms(self.user.permissions), + } class Server(bb.asyncrpc.AsyncServer): - def __init__(self, db, loop=None, upstream=None, read_only=False): + def __init__( + self, + db_engine, + upstream=None, + read_only=False, + anon_perms=DEFAULT_ANON_PERMS, + admin_username=None, + admin_password=None, + ): if upstream and read_only: - raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server") + raise bb.asyncrpc.ServerError( + "Read-only hashserv cannot pull from an upstream server" + ) + + disallowed_perms = set(anon_perms) - set( + [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM] + ) - super().__init__(logger, loop) + if disallowed_perms: + raise bb.asyncrpc.ServerError( + f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users" + ) + + super().__init__(logger) self.request_stats = Stats() - self.db = db + self.db_engine = db_engine self.upstream = upstream self.read_only = read_only + self.backfill_queue = None + self.anon_perms = set(anon_perms) + self.admin_username = admin_username + self.admin_password = admin_password + + self.logger.info( + "Anonymous user permissions are: %s", ", ".join(self.anon_perms) + ) + + def accept_client(self, socket): + return ServerClient(socket, self) + + async def create_admin_user(self): + admin_permissions = (ALL_PERM,) + async with self.db_engine.connect(self.logger) as db: + added = await db.new_user( + self.admin_username, + admin_permissions, + hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password), + ) + if added: + self.logger.info("Created admin user '%s'", self.admin_username) + else: + await db.set_user_perms( + self.admin_username, + admin_permissions, + ) + await db.set_user_token( + self.admin_username, + hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password), + ) + self.logger.info("Admin user '%s' updated", self.admin_username) + + async def backfill_worker_task(self): + async with await create_async_client( + 
self.upstream + ) as client, self.db_engine.connect(self.logger) as db: + while True: + item = await self.backfill_queue.get() + if item is None: + self.backfill_queue.task_done() + break - def accept_client(self, reader, writer): - return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only) + method, taskhash = item + d = await client.get_taskhash(method, taskhash) + if d is not None: + await db.insert_unihash(d["method"], d["taskhash"], d["unihash"]) + self.backfill_queue.task_done() - @contextmanager - def _backfill_worker(self): - async def backfill_worker_task(): - client = await create_async_client(self.upstream) - try: - while True: - item = await self.backfill_queue.get() - if item is None: - self.backfill_queue.task_done() - break - method, taskhash = item - await copy_from_upstream(client, self.db, method, taskhash) - self.backfill_queue.task_done() - finally: - await client.close() + def start(self): + tasks = super().start() + if self.upstream: + self.backfill_queue = asyncio.Queue() + tasks += [self.backfill_worker_task()] - async def join_worker(worker): - await self.backfill_queue.put(None) - await worker + self.loop.run_until_complete(self.db_engine.create()) - if self.upstream is not None: - worker = asyncio.ensure_future(backfill_worker_task()) - try: - yield - finally: - self.loop.run_until_complete(join_worker(worker)) - else: - yield + if self.admin_username: + self.loop.run_until_complete(self.create_admin_user()) - def run_loop_forever(self): - self.backfill_queue = asyncio.Queue() + return tasks - with self._backfill_worker(): - super().run_loop_forever() + async def stop(self): + if self.backfill_queue is not None: + await self.backfill_queue.put(None) + await super().stop() diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py new file mode 100644 index 000000000..f7b0226a7 --- /dev/null +++ b/lib/hashserv/sqlalchemy.py @@ -0,0 +1,598 @@ +#! /usr/bin/env python3 +# +# Copyright (C) 2023 Garmin Ltd. +# +# SPDX-License-Identifier: GPL-2.0-only +# + +import logging +from datetime import datetime +from . 
import User + +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.pool import NullPool +from sqlalchemy import ( + MetaData, + Column, + Table, + Text, + Integer, + UniqueConstraint, + DateTime, + Index, + select, + insert, + exists, + literal, + and_, + delete, + update, + func, + inspect, +) +import sqlalchemy.engine +from sqlalchemy.orm import declarative_base +from sqlalchemy.exc import IntegrityError +from sqlalchemy.dialects.postgresql import insert as postgres_insert + +Base = declarative_base() + + +class UnihashesV3(Base): + __tablename__ = "unihashes_v3" + id = Column(Integer, primary_key=True, autoincrement=True) + method = Column(Text, nullable=False) + taskhash = Column(Text, nullable=False) + unihash = Column(Text, nullable=False) + gc_mark = Column(Text, nullable=False) + + __table_args__ = ( + UniqueConstraint("method", "taskhash"), + Index("taskhash_lookup_v4", "method", "taskhash"), + Index("unihash_lookup_v1", "unihash"), + ) + + +class OuthashesV2(Base): + __tablename__ = "outhashes_v2" + id = Column(Integer, primary_key=True, autoincrement=True) + method = Column(Text, nullable=False) + taskhash = Column(Text, nullable=False) + outhash = Column(Text, nullable=False) + created = Column(DateTime) + owner = Column(Text) + PN = Column(Text) + PV = Column(Text) + PR = Column(Text) + task = Column(Text) + outhash_siginfo = Column(Text) + + __table_args__ = ( + UniqueConstraint("method", "taskhash", "outhash"), + Index("outhash_lookup_v3", "method", "outhash"), + ) + + +class Users(Base): + __tablename__ = "users" + id = Column(Integer, primary_key=True, autoincrement=True) + username = Column(Text, nullable=False) + token = Column(Text, nullable=False) + permissions = Column(Text) + + __table_args__ = (UniqueConstraint("username"),) + + +class Config(Base): + __tablename__ = "config" + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(Text, nullable=False) + value = Column(Text) + __table_args__ = ( + UniqueConstraint("name"), + Index("config_lookup", "name"), + ) + + +# +# Old table versions +# +DeprecatedBase = declarative_base() + + +class UnihashesV2(DeprecatedBase): + __tablename__ = "unihashes_v2" + id = Column(Integer, primary_key=True, autoincrement=True) + method = Column(Text, nullable=False) + taskhash = Column(Text, nullable=False) + unihash = Column(Text, nullable=False) + + __table_args__ = ( + UniqueConstraint("method", "taskhash"), + Index("taskhash_lookup_v3", "method", "taskhash"), + ) + + +class DatabaseEngine(object): + def __init__(self, url, username=None, password=None): + self.logger = logging.getLogger("hashserv.sqlalchemy") + self.url = sqlalchemy.engine.make_url(url) + + if username is not None: + self.url = self.url.set(username=username) + + if password is not None: + self.url = self.url.set(password=password) + + async def create(self): + def check_table_exists(conn, name): + return inspect(conn).has_table(name) + + self.logger.info("Using database %s", self.url) + if self.url.drivername == 'postgresql+psycopg': + # Psygopg 3 (psygopg) driver can handle async connection pooling + self.engine = create_async_engine(self.url, max_overflow=-1) + else: + self.engine = create_async_engine(self.url, poolclass=NullPool) + + async with self.engine.begin() as conn: + # Create tables + self.logger.info("Creating tables...") + await conn.run_sync(Base.metadata.create_all) + + if await conn.run_sync(check_table_exists, UnihashesV2.__tablename__): + self.logger.info("Upgrading Unihashes V2 -> V3...") + statement = 
insert(UnihashesV3).from_select( + ["id", "method", "unihash", "taskhash", "gc_mark"], + select( + UnihashesV2.id, + UnihashesV2.method, + UnihashesV2.unihash, + UnihashesV2.taskhash, + literal("").label("gc_mark"), + ), + ) + self.logger.debug("%s", statement) + await conn.execute(statement) + + await conn.run_sync(Base.metadata.drop_all, [UnihashesV2.__table__]) + self.logger.info("Upgrade complete") + + def connect(self, logger): + return Database(self.engine, logger) + + +def map_row(row): + if row is None: + return None + return dict(**row._mapping) + + +def map_user(row): + if row is None: + return None + return User( + username=row.username, + permissions=set(row.permissions.split()), + ) + + +def _make_condition_statement(table, condition): + where = {} + for c in table.__table__.columns: + if c.key in condition and condition[c.key] is not None: + where[c] = condition[c.key] + + return [(k == v) for k, v in where.items()] + + +class Database(object): + def __init__(self, engine, logger): + self.engine = engine + self.db = None + self.logger = logger + + async def __aenter__(self): + self.db = await self.engine.connect() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + async def close(self): + await self.db.close() + self.db = None + + async def _execute(self, statement): + self.logger.debug("%s", statement) + return await self.db.execute(statement) + + async def _set_config(self, name, value): + while True: + result = await self._execute( + update(Config).where(Config.name == name).values(value=value) + ) + + if result.rowcount == 0: + self.logger.debug("Config '%s' not found. Adding it", name) + try: + await self._execute(insert(Config).values(name=name, value=value)) + except IntegrityError: + # Race. 
Try again + continue + + break + + def _get_config_subquery(self, name, default=None): + if default is not None: + return func.coalesce( + select(Config.value).where(Config.name == name).scalar_subquery(), + default, + ) + return select(Config.value).where(Config.name == name).scalar_subquery() + + async def _get_config(self, name): + result = await self._execute(select(Config.value).where(Config.name == name)) + row = result.first() + if row is None: + return None + return row.value + + async def get_unihash_by_taskhash_full(self, method, taskhash): + async with self.db.begin(): + result = await self._execute( + select( + OuthashesV2, + UnihashesV3.unihash.label("unihash"), + ) + .join( + UnihashesV3, + and_( + UnihashesV3.method == OuthashesV2.method, + UnihashesV3.taskhash == OuthashesV2.taskhash, + ), + ) + .where( + OuthashesV2.method == method, + OuthashesV2.taskhash == taskhash, + ) + .order_by( + OuthashesV2.created.asc(), + ) + .limit(1) + ) + return map_row(result.first()) + + async def get_unihash_by_outhash(self, method, outhash): + async with self.db.begin(): + result = await self._execute( + select(OuthashesV2, UnihashesV3.unihash.label("unihash")) + .join( + UnihashesV3, + and_( + UnihashesV3.method == OuthashesV2.method, + UnihashesV3.taskhash == OuthashesV2.taskhash, + ), + ) + .where( + OuthashesV2.method == method, + OuthashesV2.outhash == outhash, + ) + .order_by( + OuthashesV2.created.asc(), + ) + .limit(1) + ) + return map_row(result.first()) + + async def unihash_exists(self, unihash): + async with self.db.begin(): + result = await self._execute( + select(UnihashesV3).where(UnihashesV3.unihash == unihash).limit(1) + ) + + return result.first() is not None + + async def get_outhash(self, method, outhash): + async with self.db.begin(): + result = await self._execute( + select(OuthashesV2) + .where( + OuthashesV2.method == method, + OuthashesV2.outhash == outhash, + ) + .order_by( + OuthashesV2.created.asc(), + ) + .limit(1) + ) + return map_row(result.first()) + + async def get_equivalent_for_outhash(self, method, outhash, taskhash): + async with self.db.begin(): + result = await self._execute( + select( + OuthashesV2.taskhash.label("taskhash"), + UnihashesV3.unihash.label("unihash"), + ) + .join( + UnihashesV3, + and_( + UnihashesV3.method == OuthashesV2.method, + UnihashesV3.taskhash == OuthashesV2.taskhash, + ), + ) + .where( + OuthashesV2.method == method, + OuthashesV2.outhash == outhash, + OuthashesV2.taskhash != taskhash, + ) + .order_by( + OuthashesV2.created.asc(), + ) + .limit(1) + ) + return map_row(result.first()) + + async def get_equivalent(self, method, taskhash): + async with self.db.begin(): + result = await self._execute( + select( + UnihashesV3.unihash, + UnihashesV3.method, + UnihashesV3.taskhash, + ).where( + UnihashesV3.method == method, + UnihashesV3.taskhash == taskhash, + ) + ) + return map_row(result.first()) + + async def remove(self, condition): + async def do_remove(table): + where = _make_condition_statement(table, condition) + if where: + async with self.db.begin(): + result = await self._execute(delete(table).where(*where)) + return result.rowcount + + return 0 + + count = 0 + count += await do_remove(UnihashesV3) + count += await do_remove(OuthashesV2) + + return count + + async def get_current_gc_mark(self): + async with self.db.begin(): + return await self._get_config("gc-mark") + + async def gc_status(self): + async with self.db.begin(): + gc_mark_subquery = self._get_config_subquery("gc-mark", "") + + result = await self._execute( + 
select(func.count()) + .select_from(UnihashesV3) + .where(UnihashesV3.gc_mark == gc_mark_subquery) + ) + keep_rows = result.scalar() + + result = await self._execute( + select(func.count()) + .select_from(UnihashesV3) + .where(UnihashesV3.gc_mark != gc_mark_subquery) + ) + remove_rows = result.scalar() + + return (keep_rows, remove_rows, await self._get_config("gc-mark")) + + async def gc_mark(self, mark, condition): + async with self.db.begin(): + await self._set_config("gc-mark", mark) + + where = _make_condition_statement(UnihashesV3, condition) + if not where: + return 0 + + result = await self._execute( + update(UnihashesV3) + .values(gc_mark=self._get_config_subquery("gc-mark", "")) + .where(*where) + ) + return result.rowcount + + async def gc_sweep(self): + async with self.db.begin(): + result = await self._execute( + delete(UnihashesV3).where( + # A sneaky conditional that provides some errant use + # protection: If the config mark is NULL, this will not + # match any rows because No default is specified in the + # select statement + UnihashesV3.gc_mark + != self._get_config_subquery("gc-mark") + ) + ) + await self._set_config("gc-mark", None) + + return result.rowcount + + async def clean_unused(self, oldest): + async with self.db.begin(): + result = await self._execute( + delete(OuthashesV2).where( + OuthashesV2.created < oldest, + ~( + select(UnihashesV3.id) + .where( + UnihashesV3.method == OuthashesV2.method, + UnihashesV3.taskhash == OuthashesV2.taskhash, + ) + .limit(1) + .exists() + ), + ) + ) + return result.rowcount + + async def insert_unihash(self, method, taskhash, unihash): + # Postgres specific ignore on insert duplicate + if self.engine.name == "postgresql": + statement = ( + postgres_insert(UnihashesV3) + .values( + method=method, + taskhash=taskhash, + unihash=unihash, + gc_mark=self._get_config_subquery("gc-mark", ""), + ) + .on_conflict_do_nothing(index_elements=("method", "taskhash")) + ) + else: + statement = insert(UnihashesV3).values( + method=method, + taskhash=taskhash, + unihash=unihash, + gc_mark=self._get_config_subquery("gc-mark", ""), + ) + + try: + async with self.db.begin(): + result = await self._execute(statement) + return result.rowcount != 0 + except IntegrityError: + self.logger.debug( + "%s, %s, %s already in unihash database", method, taskhash, unihash + ) + return False + + async def insert_outhash(self, data): + outhash_columns = set(c.key for c in OuthashesV2.__table__.columns) + + data = {k: v for k, v in data.items() if k in outhash_columns} + + if "created" in data and not isinstance(data["created"], datetime): + data["created"] = datetime.fromisoformat(data["created"]) + + # Postgres specific ignore on insert duplicate + if self.engine.name == "postgresql": + statement = ( + postgres_insert(OuthashesV2) + .values(**data) + .on_conflict_do_nothing( + index_elements=("method", "taskhash", "outhash") + ) + ) + else: + statement = insert(OuthashesV2).values(**data) + + try: + async with self.db.begin(): + result = await self._execute(statement) + return result.rowcount != 0 + except IntegrityError: + self.logger.debug( + "%s, %s already in outhash database", data["method"], data["outhash"] + ) + return False + + async def _get_user(self, username): + async with self.db.begin(): + result = await self._execute( + select( + Users.username, + Users.permissions, + Users.token, + ).where( + Users.username == username, + ) + ) + return result.first() + + async def lookup_user_token(self, username): + row = await self._get_user(username) + if not 
row: + return None, None + return map_user(row), row.token + + async def lookup_user(self, username): + return map_user(await self._get_user(username)) + + async def set_user_token(self, username, token): + async with self.db.begin(): + result = await self._execute( + update(Users) + .where( + Users.username == username, + ) + .values( + token=token, + ) + ) + return result.rowcount != 0 + + async def set_user_perms(self, username, permissions): + async with self.db.begin(): + result = await self._execute( + update(Users) + .where(Users.username == username) + .values(permissions=" ".join(permissions)) + ) + return result.rowcount != 0 + + async def get_all_users(self): + async with self.db.begin(): + result = await self._execute( + select( + Users.username, + Users.permissions, + ) + ) + return [map_user(row) for row in result] + + async def new_user(self, username, permissions, token): + try: + async with self.db.begin(): + await self._execute( + insert(Users).values( + username=username, + permissions=" ".join(permissions), + token=token, + ) + ) + return True + except IntegrityError as e: + self.logger.debug("Cannot create new user %s: %s", username, e) + return False + + async def delete_user(self, username): + async with self.db.begin(): + result = await self._execute( + delete(Users).where(Users.username == username) + ) + return result.rowcount != 0 + + async def get_usage(self): + usage = {} + async with self.db.begin() as session: + for name, table in Base.metadata.tables.items(): + result = await self._execute( + statement=select(func.count()).select_from(table) + ) + usage[name] = { + "rows": result.scalar(), + } + + return usage + + async def get_query_columns(self): + columns = set() + for table in (UnihashesV3, OuthashesV2): + for c in table.__table__.columns: + if not isinstance(c.type, Text): + continue + columns.add(c.key) + + return list(columns) diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py new file mode 100644 index 000000000..da2e844a0 --- /dev/null +++ b/lib/hashserv/sqlite.py @@ -0,0 +1,562 @@ +#! /usr/bin/env python3 +# +# Copyright (C) 2023 Garmin Ltd. +# +# SPDX-License-Identifier: GPL-2.0-only +# +import sqlite3 +import logging +from contextlib import closing +from . 
import User + +logger = logging.getLogger("hashserv.sqlite") + +UNIHASH_TABLE_DEFINITION = ( + ("method", "TEXT NOT NULL", "UNIQUE"), + ("taskhash", "TEXT NOT NULL", "UNIQUE"), + ("unihash", "TEXT NOT NULL", ""), + ("gc_mark", "TEXT NOT NULL", ""), +) + +UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION) + +OUTHASH_TABLE_DEFINITION = ( + ("method", "TEXT NOT NULL", "UNIQUE"), + ("taskhash", "TEXT NOT NULL", "UNIQUE"), + ("outhash", "TEXT NOT NULL", "UNIQUE"), + ("created", "DATETIME", ""), + # Optional fields + ("owner", "TEXT", ""), + ("PN", "TEXT", ""), + ("PV", "TEXT", ""), + ("PR", "TEXT", ""), + ("task", "TEXT", ""), + ("outhash_siginfo", "TEXT", ""), +) + +OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION) + +USERS_TABLE_DEFINITION = ( + ("username", "TEXT NOT NULL", "UNIQUE"), + ("token", "TEXT NOT NULL", ""), + ("permissions", "TEXT NOT NULL", ""), +) + +USERS_TABLE_COLUMNS = tuple(name for name, _, _ in USERS_TABLE_DEFINITION) + + +CONFIG_TABLE_DEFINITION = ( + ("name", "TEXT NOT NULL", "UNIQUE"), + ("value", "TEXT", ""), +) + +CONFIG_TABLE_COLUMNS = tuple(name for name, _, _ in CONFIG_TABLE_DEFINITION) + + +def _make_table(cursor, name, definition): + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS {name} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + {fields} + UNIQUE({unique}) + ) + """.format( + name=name, + fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition), + unique=", ".join( + name for name, _, flags in definition if "UNIQUE" in flags + ), + ) + ) + + +def map_user(row): + if row is None: + return None + return User( + username=row["username"], + permissions=set(row["permissions"].split()), + ) + + +def _make_condition_statement(columns, condition): + where = {} + for c in columns: + if c in condition and condition[c] is not None: + where[c] = condition[c] + + return where, " AND ".join("%s=:%s" % (k, k) for k in where.keys()) + + +def _get_sqlite_version(cursor): + cursor.execute("SELECT sqlite_version()") + + version = [] + for v in cursor.fetchone()[0].split("."): + try: + version.append(int(v)) + except ValueError: + version.append(v) + + return tuple(version) + + +def _schema_table_name(version): + if version >= (3, 33): + return "sqlite_schema" + + return "sqlite_master" + + +class DatabaseEngine(object): + def __init__(self, dbname, sync): + self.dbname = dbname + self.logger = logger + self.sync = sync + + async def create(self): + db = sqlite3.connect(self.dbname) + db.row_factory = sqlite3.Row + + with closing(db.cursor()) as cursor: + _make_table(cursor, "unihashes_v3", UNIHASH_TABLE_DEFINITION) + _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION) + _make_table(cursor, "users", USERS_TABLE_DEFINITION) + _make_table(cursor, "config", CONFIG_TABLE_DEFINITION) + + cursor.execute("PRAGMA journal_mode = WAL") + cursor.execute( + "PRAGMA synchronous = %s" % ("NORMAL" if self.sync else "OFF") + ) + + # Drop old indexes + cursor.execute("DROP INDEX IF EXISTS taskhash_lookup") + cursor.execute("DROP INDEX IF EXISTS outhash_lookup") + cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2") + cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2") + cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v3") + + # TODO: Upgrade from tasks_v2? 
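+            # For now the legacy tasks_v2 table is simply dropped rather than migrated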
+ cursor.execute("DROP TABLE IF EXISTS tasks_v2") + + # Create new indexes + cursor.execute( + "CREATE INDEX IF NOT EXISTS taskhash_lookup_v4 ON unihashes_v3 (method, taskhash)" + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS unihash_lookup_v1 ON unihashes_v3 (unihash)" + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)" + ) + cursor.execute("CREATE INDEX IF NOT EXISTS config_lookup ON config (name)") + + sqlite_version = _get_sqlite_version(cursor) + + cursor.execute( + f""" + SELECT name FROM {_schema_table_name(sqlite_version)} WHERE type = 'table' AND name = 'unihashes_v2' + """ + ) + if cursor.fetchone(): + self.logger.info("Upgrading Unihashes V2 -> V3...") + cursor.execute( + """ + INSERT INTO unihashes_v3 (id, method, unihash, taskhash, gc_mark) + SELECT id, method, unihash, taskhash, '' FROM unihashes_v2 + """ + ) + cursor.execute("DROP TABLE unihashes_v2") + db.commit() + self.logger.info("Upgrade complete") + + def connect(self, logger): + return Database(logger, self.dbname, self.sync) + + +class Database(object): + def __init__(self, logger, dbname, sync): + self.dbname = dbname + self.logger = logger + + self.db = sqlite3.connect(self.dbname) + self.db.row_factory = sqlite3.Row + + with closing(self.db.cursor()) as cursor: + cursor.execute("PRAGMA journal_mode = WAL") + cursor.execute( + "PRAGMA synchronous = %s" % ("NORMAL" if sync else "OFF") + ) + + self.sqlite_version = _get_sqlite_version(cursor) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + async def _set_config(self, cursor, name, value): + cursor.execute( + """ + INSERT OR REPLACE INTO config (id, name, value) VALUES + ((SELECT id FROM config WHERE name=:name), :name, :value) + """, + { + "name": name, + "value": value, + }, + ) + + async def _get_config(self, cursor, name): + cursor.execute( + "SELECT value FROM config WHERE name=:name", + { + "name": name, + }, + ) + row = cursor.fetchone() + if row is None: + return None + return row["value"] + + async def close(self): + self.db.close() + + async def get_unihash_by_taskhash_full(self, method, taskhash): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2 + INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash + WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash + ORDER BY outhashes_v2.created ASC + LIMIT 1 + """, + { + "method": method, + "taskhash": taskhash, + }, + ) + return cursor.fetchone() + + async def get_unihash_by_outhash(self, method, outhash): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2 + INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash + WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash + ORDER BY outhashes_v2.created ASC + LIMIT 1 + """, + { + "method": method, + "outhash": outhash, + }, + ) + return cursor.fetchone() + + async def unihash_exists(self, unihash): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + SELECT * FROM unihashes_v3 WHERE unihash=:unihash + LIMIT 1 + """, + { + "unihash": unihash, + }, + ) + return cursor.fetchone() is not None + + async def get_outhash(self, method, outhash): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + 
SELECT * FROM outhashes_v2 + WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash + ORDER BY outhashes_v2.created ASC + LIMIT 1 + """, + { + "method": method, + "outhash": outhash, + }, + ) + return cursor.fetchone() + + async def get_equivalent_for_outhash(self, method, outhash, taskhash): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + SELECT outhashes_v2.taskhash AS taskhash, unihashes_v3.unihash AS unihash FROM outhashes_v2 + INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash + -- Select any matching output hash except the one we just inserted + WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash + -- Pick the oldest hash + ORDER BY outhashes_v2.created ASC + LIMIT 1 + """, + { + "method": method, + "outhash": outhash, + "taskhash": taskhash, + }, + ) + return cursor.fetchone() + + async def get_equivalent(self, method, taskhash): + with closing(self.db.cursor()) as cursor: + cursor.execute( + "SELECT taskhash, method, unihash FROM unihashes_v3 WHERE method=:method AND taskhash=:taskhash", + { + "method": method, + "taskhash": taskhash, + }, + ) + return cursor.fetchone() + + async def remove(self, condition): + def do_remove(columns, table_name, cursor): + where, clause = _make_condition_statement(columns, condition) + if where: + query = f"DELETE FROM {table_name} WHERE {clause}" + cursor.execute(query, where) + return cursor.rowcount + + return 0 + + count = 0 + with closing(self.db.cursor()) as cursor: + count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor) + count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v3", cursor) + self.db.commit() + + return count + + async def get_current_gc_mark(self): + with closing(self.db.cursor()) as cursor: + return await self._get_config(cursor, "gc-mark") + + async def gc_status(self): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + SELECT COUNT() FROM unihashes_v3 WHERE + gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '') + """ + ) + keep_rows = cursor.fetchone()[0] + + cursor.execute( + """ + SELECT COUNT() FROM unihashes_v3 WHERE + gc_mark!=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '') + """ + ) + remove_rows = cursor.fetchone()[0] + + current_mark = await self._get_config(cursor, "gc-mark") + + return (keep_rows, remove_rows, current_mark) + + async def gc_mark(self, mark, condition): + with closing(self.db.cursor()) as cursor: + await self._set_config(cursor, "gc-mark", mark) + + where, clause = _make_condition_statement(UNIHASH_TABLE_COLUMNS, condition) + + new_rows = 0 + if where: + cursor.execute( + f""" + UPDATE unihashes_v3 SET + gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '') + WHERE {clause} + """, + where, + ) + new_rows = cursor.rowcount + + self.db.commit() + return new_rows + + async def gc_sweep(self): + with closing(self.db.cursor()) as cursor: + # NOTE: COALESCE is not used in this query so that if the current + # mark is NULL, nothing will happen + cursor.execute( + """ + DELETE FROM unihashes_v3 WHERE + gc_mark!=(SELECT value FROM config WHERE name='gc-mark') + """ + ) + count = cursor.rowcount + await self._set_config(cursor, "gc-mark", None) + + self.db.commit() + return count + + async def clean_unused(self, oldest): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS ( + SELECT unihashes_v3.id FROM 
unihashes_v3 WHERE unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash LIMIT 1 + ) + """, + { + "oldest": oldest, + }, + ) + self.db.commit() + return cursor.rowcount + + async def insert_unihash(self, method, taskhash, unihash): + with closing(self.db.cursor()) as cursor: + prevrowid = cursor.lastrowid + cursor.execute( + """ + INSERT OR IGNORE INTO unihashes_v3 (method, taskhash, unihash, gc_mark) VALUES + ( + :method, + :taskhash, + :unihash, + COALESCE((SELECT value FROM config WHERE name='gc-mark'), '') + ) + """, + { + "method": method, + "taskhash": taskhash, + "unihash": unihash, + }, + ) + self.db.commit() + return cursor.lastrowid != prevrowid + + async def insert_outhash(self, data): + data = {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS} + keys = sorted(data.keys()) + query = "INSERT OR IGNORE INTO outhashes_v2 ({fields}) VALUES({values})".format( + fields=", ".join(keys), + values=", ".join(":" + k for k in keys), + ) + with closing(self.db.cursor()) as cursor: + prevrowid = cursor.lastrowid + cursor.execute(query, data) + self.db.commit() + return cursor.lastrowid != prevrowid + + def _get_user(self, username): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + SELECT username, permissions, token FROM users WHERE username=:username + """, + { + "username": username, + }, + ) + return cursor.fetchone() + + async def lookup_user_token(self, username): + row = self._get_user(username) + if row is None: + return None, None + return map_user(row), row["token"] + + async def lookup_user(self, username): + return map_user(self._get_user(username)) + + async def set_user_token(self, username, token): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + UPDATE users SET token=:token WHERE username=:username + """, + { + "username": username, + "token": token, + }, + ) + self.db.commit() + return cursor.rowcount != 0 + + async def set_user_perms(self, username, permissions): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + UPDATE users SET permissions=:permissions WHERE username=:username + """, + { + "username": username, + "permissions": " ".join(permissions), + }, + ) + self.db.commit() + return cursor.rowcount != 0 + + async def get_all_users(self): + with closing(self.db.cursor()) as cursor: + cursor.execute("SELECT username, permissions FROM users") + return [map_user(r) for r in cursor.fetchall()] + + async def new_user(self, username, permissions, token): + with closing(self.db.cursor()) as cursor: + try: + cursor.execute( + """ + INSERT INTO users (username, token, permissions) VALUES (:username, :token, :permissions) + """, + { + "username": username, + "token": token, + "permissions": " ".join(permissions), + }, + ) + self.db.commit() + return True + except sqlite3.IntegrityError: + return False + + async def delete_user(self, username): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + DELETE FROM users WHERE username=:username + """, + { + "username": username, + }, + ) + self.db.commit() + return cursor.rowcount != 0 + + async def get_usage(self): + usage = {} + with closing(self.db.cursor()) as cursor: + cursor.execute( + f""" + SELECT name FROM {_schema_table_name(self.sqlite_version)} WHERE type = 'table' AND name NOT LIKE 'sqlite_%' + """ + ) + for row in cursor.fetchall(): + cursor.execute( + """ + SELECT COUNT() FROM %s + """ + % row["name"], + ) + usage[row["name"]] = { + "rows": cursor.fetchone()[0], + } + return usage + + async def 
get_query_columns(self): + columns = set() + for name, typ, _ in UNIHASH_TABLE_DEFINITION + OUTHASH_TABLE_DEFINITION: + if typ.startswith("TEXT"): + columns.add(name) + return list(columns) diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py index e2b762dbf..0809453cf 100644 --- a/lib/hashserv/tests.py +++ b/lib/hashserv/tests.py @@ -6,6 +6,9 @@ # from . import create_server, create_client +from .server import DEFAULT_ANON_PERMS, ALL_PERMISSIONS +from bb.asyncrpc import InvokeError +from .client import ClientPool import hashlib import logging import multiprocessing @@ -15,46 +18,80 @@ import tempfile import threading import unittest import socket - -def _run_server(server, idx): - # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', - # format='%(levelname)s %(filename)s:%(lineno)d %(message)s') - sys.stdout = open('bbhashserv-%d.log' % idx, 'w') +import time +import signal +import subprocess +import json +import re +from pathlib import Path + + +THIS_DIR = Path(__file__).parent +BIN_DIR = THIS_DIR.parent.parent / "bin" + +def server_prefunc(server, idx): + logging.basicConfig(level=logging.DEBUG, filename='bbhashserv-%d.log' % idx, filemode='w', + format='%(levelname)s %(filename)s:%(lineno)d %(message)s') + server.logger.debug("Running server %d" % idx) + sys.stdout = open('bbhashserv-stdout-%d.log' % idx, 'w') sys.stderr = sys.stdout - server.serve_forever() - class HashEquivalenceTestSetup(object): METHOD = 'TestMethod' server_index = 0 + client_index = 0 - def start_server(self, dbpath=None, upstream=None, read_only=False): + def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc, anon_perms=DEFAULT_ANON_PERMS, admin_username=None, admin_password=None): self.server_index += 1 if dbpath is None: - dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) + dbpath = self.make_dbpath() + + def cleanup_server(server): + if server.process.exitcode is not None: + return - def cleanup_thread(thread): - thread.terminate() - thread.join() + server.process.terminate() + server.process.join() server = create_server(self.get_server_addr(self.server_index), dbpath, upstream=upstream, - read_only=read_only) + read_only=read_only, + anon_perms=anon_perms, + admin_username=admin_username, + admin_password=admin_password) server.dbpath = dbpath - server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index)) - server.thread.start() - self.addCleanup(cleanup_thread, server.thread) + server.serve_as_process(prefunc=prefunc, args=(self.server_index,)) + self.addCleanup(cleanup_server, server) + + return server + + def make_dbpath(self): + return os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) + def start_client(self, server_address, username=None, password=None): def cleanup_client(client): client.close() - client = create_client(server.address) + client = create_client(server_address, username=username, password=password) self.addCleanup(cleanup_client, client) - return (client, server) + return client + + def start_test_server(self): + self.server = self.start_server() + return self.server.address + + def start_auth_server(self): + auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password") + self.auth_server_address = auth_server.address + self.admin_client = self.start_client(auth_server.address, username="admin", password="password") + return self.admin_client + + def auth_client(self, user): + return 
self.start_client(self.auth_server_address, user["username"], user["token"]) def setUp(self): if sys.version_info < (3, 5, 0): @@ -63,24 +100,82 @@ class HashEquivalenceTestSetup(object): self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv') self.addCleanup(self.temp_dir.cleanup) - (self.client, self.server) = self.start_server() + self.server_address = self.start_test_server() + + self.client = self.start_client(self.server_address) def assertClientGetHash(self, client, taskhash, unihash): result = client.get_unihash(self.METHOD, taskhash) self.assertEqual(result, unihash) + def assertUserPerms(self, user, permissions): + with self.auth_client(user) as client: + info = client.get_user() + self.assertEqual(info, { + "username": user["username"], + "permissions": permissions, + }) -class HashEquivalenceCommonTests(object): - def test_create_hash(self): + def assertUserCanAuth(self, user): + with self.start_client(self.auth_server_address) as client: + client.auth(user["username"], user["token"]) + + def assertUserCannotAuth(self, user): + with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError): + client.auth(user["username"], user["token"]) + + def create_test_hash(self, client): # Simple test that hashes can be created taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' - self.assertClientGetHash(self.client, taskhash, None) + self.assertClientGetHash(client, taskhash, None) - result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + result = client.report_unihash(taskhash, self.METHOD, outhash, unihash) self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') + return taskhash, outhash, unihash + + def run_hashclient(self, args, **kwargs): + try: + p = subprocess.run( + [BIN_DIR / "bitbake-hashclient"] + args, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + encoding="utf-8", + **kwargs + ) + except subprocess.CalledProcessError as e: + print(e.output) + raise e + + print(p.stdout) + return p + + +class HashEquivalenceCommonTests(object): + def auth_perms(self, *permissions): + self.client_index += 1 + user = self.create_user(f"user-{self.client_index}", permissions) + return self.auth_client(user) + + def create_user(self, username, permissions, *, client=None): + def remove_user(username): + try: + self.admin_client.delete_user(username) + except bb.asyncrpc.InvokeError: + pass + + if client is None: + client = self.admin_client + + user = client.new_user(username, permissions) + self.addCleanup(remove_user, username) + + return user + + def test_create_hash(self): + return self.create_test_hash(self.client) def test_create_equivalent(self): # Tests that a second reported task with the same outhash will be @@ -122,6 +217,57 @@ class HashEquivalenceCommonTests(object): self.assertClientGetHash(self.client, taskhash, unihash) + def test_remove_taskhash(self): + taskhash, outhash, unihash = self.create_test_hash(self.client) + result = self.client.remove({"taskhash": taskhash}) + self.assertGreater(result["count"], 0) + self.assertClientGetHash(self.client, taskhash, None) + + result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) + self.assertIsNone(result_outhash) + + def test_remove_unihash(self): + taskhash, outhash, unihash = self.create_test_hash(self.client) + result = self.client.remove({"unihash": unihash}) + 
self.assertGreater(result["count"], 0) + self.assertClientGetHash(self.client, taskhash, None) + + def test_remove_outhash(self): + taskhash, outhash, unihash = self.create_test_hash(self.client) + result = self.client.remove({"outhash": outhash}) + self.assertGreater(result["count"], 0) + + result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) + self.assertIsNone(result_outhash) + + def test_remove_method(self): + taskhash, outhash, unihash = self.create_test_hash(self.client) + result = self.client.remove({"method": self.METHOD}) + self.assertGreater(result["count"], 0) + self.assertClientGetHash(self.client, taskhash, None) + + result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) + self.assertIsNone(result_outhash) + + def test_clean_unused(self): + taskhash, outhash, unihash = self.create_test_hash(self.client) + + # Clean the database, which should not remove anything because all hashes an in-use + result = self.client.clean_unused(0) + self.assertEqual(result["count"], 0) + self.assertClientGetHash(self.client, taskhash, unihash) + + # Remove the unihash. The row in the outhash table should still be present + self.client.remove({"unihash": unihash}) + result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False) + self.assertIsNotNone(result_outhash) + + # Now clean with no minimum age which will remove the outhash + result = self.client.clean_unused(0) + self.assertEqual(result["count"], 1) + result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False) + self.assertIsNone(result_outhash) + def test_huge_message(self): # Simple test that hashes can be created taskhash = 'c665584ee6817aa99edfc77a44dd853828279370' @@ -137,16 +283,21 @@ class HashEquivalenceCommonTests(object): }) self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') - result = self.client.get_taskhash(self.METHOD, taskhash, True) - self.assertEqual(result['taskhash'], taskhash) - self.assertEqual(result['unihash'], unihash) - self.assertEqual(result['method'], self.METHOD) - self.assertEqual(result['outhash'], outhash) - self.assertEqual(result['outhash_siginfo'], siginfo) + result_unihash = self.client.get_taskhash(self.METHOD, taskhash, True) + self.assertEqual(result_unihash['taskhash'], taskhash) + self.assertEqual(result_unihash['unihash'], unihash) + self.assertEqual(result_unihash['method'], self.METHOD) + + result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) + self.assertEqual(result_outhash['taskhash'], taskhash) + self.assertEqual(result_outhash['method'], self.METHOD) + self.assertEqual(result_outhash['unihash'], unihash) + self.assertEqual(result_outhash['outhash'], outhash) + self.assertEqual(result_outhash['outhash_siginfo'], siginfo) def test_stress(self): def query_server(failures): - client = Client(self.server.address) + client = Client(self.server_address) try: for i in range(1000): taskhash = hashlib.sha256() @@ -185,8 +336,10 @@ class HashEquivalenceCommonTests(object): # the side client. 
It also verifies that the results are pulled into # the downstream database by checking that the downstream and side servers # match after the downstream is done waiting for all backfill tasks - (down_client, down_server) = self.start_server(upstream=self.server.address) - (side_client, side_server) = self.start_server(dbpath=down_server.dbpath) + down_server = self.start_server(upstream=self.server_address) + down_client = self.start_client(down_server.address) + side_server = self.start_server(dbpath=down_server.dbpath) + side_client = self.start_client(side_server.address) def check_hash(taskhash, unihash, old_sidehash): nonlocal down_client @@ -257,15 +410,57 @@ class HashEquivalenceCommonTests(object): result = down_client.report_unihash(taskhash6, self.METHOD, outhash5, unihash6) self.assertEqual(result['unihash'], unihash5, 'Server failed to copy unihash from upstream') + # Tests read through from server with + taskhash7 = '9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74' + outhash7 = '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69' + unihash7 = '05d2a63c81e32f0a36542ca677e8ad852365c538' + self.client.report_unihash(taskhash7, self.METHOD, outhash7, unihash7) + + result = down_client.get_taskhash(self.METHOD, taskhash7, True) + self.assertEqual(result['unihash'], unihash7, 'Server failed to copy unihash from upstream') + self.assertEqual(result['outhash'], outhash7, 'Server failed to copy unihash from upstream') + self.assertEqual(result['taskhash'], taskhash7, 'Server failed to copy unihash from upstream') + self.assertEqual(result['method'], self.METHOD) + + taskhash8 = '86978a4c8c71b9b487330b0152aade10c1ee58aa' + outhash8 = 'ca8c128e9d9e4a28ef24d0508aa20b5cf880604eacd8f65c0e366f7e0cc5fbcf' + unihash8 = 'd8bcf25369d40590ad7d08c84d538982f2023e01' + self.client.report_unihash(taskhash8, self.METHOD, outhash8, unihash8) + + result = down_client.get_outhash(self.METHOD, outhash8, taskhash8) + self.assertEqual(result['unihash'], unihash8, 'Server failed to copy unihash from upstream') + self.assertEqual(result['outhash'], outhash8, 'Server failed to copy unihash from upstream') + self.assertEqual(result['taskhash'], taskhash8, 'Server failed to copy unihash from upstream') + self.assertEqual(result['method'], self.METHOD) + + taskhash9 = 'ae6339531895ddf5b67e663e6a374ad8ec71d81c' + outhash9 = 'afc78172c81880ae10a1fec994b5b4ee33d196a001a1b66212a15ebe573e00b5' + unihash9 = '6662e699d6e3d894b24408ff9a4031ef9b038ee8' + self.client.report_unihash(taskhash9, self.METHOD, outhash9, unihash9) + + result = down_client.get_taskhash(self.METHOD, taskhash9, False) + self.assertEqual(result['unihash'], unihash9, 'Server failed to copy unihash from upstream') + self.assertEqual(result['taskhash'], taskhash9, 'Server failed to copy unihash from upstream') + self.assertEqual(result['method'], self.METHOD) + + def test_unihash_exsits(self): + taskhash, outhash, unihash = self.create_test_hash(self.client) + self.assertTrue(self.client.unihash_exists(unihash)) + self.assertFalse(self.client.unihash_exists('6662e699d6e3d894b24408ff9a4031ef9b038ee8')) + def test_ro_server(self): - (ro_client, ro_server) = self.start_server(dbpath=self.server.dbpath, read_only=True) + rw_server = self.start_server() + rw_client = self.start_client(rw_server.address) + + ro_server = self.start_server(dbpath=rw_server.dbpath, read_only=True) + ro_client = self.start_client(ro_server.address) # Report a hash via the read-write server taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' outhash = 
'2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' - result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + result = rw_client.report_unihash(taskhash, self.METHOD, outhash, unihash) self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') # Check the hash via the read-only server @@ -276,11 +471,940 @@ class HashEquivalenceCommonTests(object): outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44' unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824' - with self.assertRaises(ConnectionError): - ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) + result = ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) + self.assertEqual(result['unihash'], unihash2) # Ensure that the database was not modified + self.assertClientGetHash(rw_client, taskhash2, None) + + + def test_slow_server_start(self): + # Ensures that the server will exit correctly even if it gets a SIGTERM + # before entering the main loop + + event = multiprocessing.Event() + + def prefunc(server, idx): + nonlocal event + server_prefunc(server, idx) + event.wait() + + def do_nothing(signum, frame): + pass + + old_signal = signal.signal(signal.SIGTERM, do_nothing) + self.addCleanup(signal.signal, signal.SIGTERM, old_signal) + + server = self.start_server(prefunc=prefunc) + server.process.terminate() + time.sleep(30) + event.set() + server.process.join(300) + self.assertIsNotNone(server.process.exitcode, "Server did not exit in a timely manner!") + + def test_diverging_report_race(self): + # Tests that a reported task will correctly pick up an updated unihash + + # This is a baseline report added to the database to ensure that there + # is something to match against as equivalent + outhash1 = 'afd11c366050bcd75ad763e898e4430e2a60659b26f83fbb22201a60672019fa' + taskhash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab' + unihash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab' + result = self.client.report_unihash(taskhash1, self.METHOD, outhash1, unihash1) + + # Add a report that is equivalent to Task 1. It should ignore the + # provided unihash and report the unihash from task 1 + taskhash2 = '6259ae8263bd94d454c086f501c37e64c4e83cae806902ca95b4ab513546b273' + unihash2 = taskhash2 + result = self.client.report_unihash(taskhash2, self.METHOD, outhash1, unihash2) + self.assertEqual(result['unihash'], unihash1) + + # Add another report for Task 2, but with a different outhash (e.g. the + # task is non-deterministic). 
It should still be marked with the Task 1 + # unihash because it has the Task 2 taskhash, which is equivalent to + # Task 1 + outhash3 = 'd2187ee3a8966db10b34fe0e863482288d9a6185cb8ef58a6c1c6ace87a2f24c' + result = self.client.report_unihash(taskhash2, self.METHOD, outhash3, unihash2) + self.assertEqual(result['unihash'], unihash1) + + + def test_diverging_report_reverse_race(self): + # Same idea as the previous test, but Tasks 2 and 3 are reported in + # reverse order the opposite order + + outhash1 = 'afd11c366050bcd75ad763e898e4430e2a60659b26f83fbb22201a60672019fa' + taskhash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab' + unihash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab' + result = self.client.report_unihash(taskhash1, self.METHOD, outhash1, unihash1) + + taskhash2 = '6259ae8263bd94d454c086f501c37e64c4e83cae806902ca95b4ab513546b273' + unihash2 = taskhash2 + + # Report Task 3 first. Since there is nothing else in the database it + # will use the client provided unihash + outhash3 = 'd2187ee3a8966db10b34fe0e863482288d9a6185cb8ef58a6c1c6ace87a2f24c' + result = self.client.report_unihash(taskhash2, self.METHOD, outhash3, unihash2) + self.assertEqual(result['unihash'], unihash2) + + # Report Task 2. This is equivalent to Task 1 but there is already a mapping for + # taskhash2 so it will report unihash2 + result = self.client.report_unihash(taskhash2, self.METHOD, outhash1, unihash2) + self.assertEqual(result['unihash'], unihash2) + + # The originally reported unihash for Task 3 should be unchanged even if it + # shares a taskhash with Task 2 + self.assertClientGetHash(self.client, taskhash2, unihash2) + + + def test_client_pool_get_unihashes(self): + TEST_INPUT = ( + # taskhash outhash unihash + ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e','218e57509998197d570e2c98512d0105985dffc9'), + # Duplicated taskhash with multiple output hashes and unihashes. 
+ ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'), + # Equivalent hash + ("044c2ec8aaf480685a00ff6ff49e6162e6ad34e1", '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', "def64766090d28f627e816454ed46894bb3aab36"), + ("e3da00593d6a7fb435c7e2114976c59c5fd6d561", "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a", "3b5d3d83f07f259e9086fcb422c855286e18a57d"), + ('35788efcb8dfb0a02659d81cf2bfd695fb30faf9', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2cd'), + ('35788efcb8dfb0a02659d81cf2bfd695fb30fafa', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2ce'), + ('9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74', '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69', '05d2a63c81e32f0a36542ca677e8ad852365c538'), + ) + EXTRA_QUERIES = ( + "6b6be7a84ab179b4240c4302518dc3f6", + ) + + with ClientPool(self.server_address, 10) as client_pool: + for taskhash, outhash, unihash in TEST_INPUT: + self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + + query = {idx: (self.METHOD, data[0]) for idx, data in enumerate(TEST_INPUT)} + for idx, taskhash in enumerate(EXTRA_QUERIES): + query[idx + len(TEST_INPUT)] = (self.METHOD, taskhash) + + result = client_pool.get_unihashes(query) + + self.assertDictEqual(result, { + 0: "218e57509998197d570e2c98512d0105985dffc9", + 1: "218e57509998197d570e2c98512d0105985dffc9", + 2: "218e57509998197d570e2c98512d0105985dffc9", + 3: "3b5d3d83f07f259e9086fcb422c855286e18a57d", + 4: "f46d3fbb439bd9b921095da657a4de906510d2cd", + 5: "f46d3fbb439bd9b921095da657a4de906510d2cd", + 6: "05d2a63c81e32f0a36542ca677e8ad852365c538", + 7: None, + }) + + def test_client_pool_unihash_exists(self): + TEST_INPUT = ( + # taskhash outhash unihash + ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e','218e57509998197d570e2c98512d0105985dffc9'), + # Duplicated taskhash with multiple output hashes and unihashes. 
+ ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'), + # Equivalent hash + ("044c2ec8aaf480685a00ff6ff49e6162e6ad34e1", '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', "def64766090d28f627e816454ed46894bb3aab36"), + ("e3da00593d6a7fb435c7e2114976c59c5fd6d561", "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a", "3b5d3d83f07f259e9086fcb422c855286e18a57d"), + ('35788efcb8dfb0a02659d81cf2bfd695fb30faf9', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2cd'), + ('35788efcb8dfb0a02659d81cf2bfd695fb30fafa', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2ce'), + ('9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74', '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69', '05d2a63c81e32f0a36542ca677e8ad852365c538'), + ) + EXTRA_QUERIES = ( + "6b6be7a84ab179b4240c4302518dc3f6", + ) + + result_unihashes = set() + + + with ClientPool(self.server_address, 10) as client_pool: + for taskhash, outhash, unihash in TEST_INPUT: + result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + result_unihashes.add(result["unihash"]) + + query = {} + expected = {} + + for _, _, unihash in TEST_INPUT: + idx = len(query) + query[idx] = unihash + expected[idx] = unihash in result_unihashes + + + for unihash in EXTRA_QUERIES: + idx = len(query) + query[idx] = unihash + expected[idx] = False + + result = client_pool.unihashes_exist(query) + self.assertDictEqual(result, expected) + + + def test_auth_read_perms(self): + admin_client = self.start_auth_server() + + # Create hashes with non-authenticated server + taskhash, outhash, unihash = self.create_test_hash(self.client) + + # Validate hash can be retrieved using authenticated client + with self.auth_perms("@read") as client: + self.assertClientGetHash(client, taskhash, unihash) + + with self.auth_perms() as client, self.assertRaises(InvokeError): + self.assertClientGetHash(client, taskhash, unihash) + + def test_auth_report_perms(self): + admin_client = self.start_auth_server() + + # Without read permission, the user is completely denied + with self.auth_perms() as client, self.assertRaises(InvokeError): + self.create_test_hash(client) + + # Read permission allows the call to succeed, but it doesn't record + # anythin in the database + with self.auth_perms("@read") as client: + taskhash, outhash, unihash = self.create_test_hash(client) + self.assertClientGetHash(client, taskhash, None) + + # Report permission alone is insufficient + with self.auth_perms("@report") as client, self.assertRaises(InvokeError): + self.create_test_hash(client) + + # Read and report permission actually modify the database + with self.auth_perms("@read", "@report") as client: + taskhash, outhash, unihash = self.create_test_hash(client) + self.assertClientGetHash(client, taskhash, unihash) + + def test_auth_no_token_refresh_from_anon_user(self): + self.start_auth_server() + + with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError): + client.refresh_token() + + def test_auth_self_token_refresh(self): + admin_client = self.start_auth_server() + + # Create a new user with no permissions + user = self.create_user("test-user", []) + + with self.auth_client(user) as client: + new_user = client.refresh_token() + + self.assertEqual(user["username"], new_user["username"]) + 
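+        # Same account, but a new token should have been issued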
self.assertNotEqual(user["token"], new_user["token"]) + self.assertUserCanAuth(new_user) + self.assertUserCannotAuth(user) + + # Explicitly specifying with your own username is fine also + with self.auth_client(new_user) as client: + new_user2 = client.refresh_token(user["username"]) + + self.assertEqual(user["username"], new_user2["username"]) + self.assertNotEqual(user["token"], new_user2["token"]) + self.assertUserCanAuth(new_user2) + self.assertUserCannotAuth(new_user) + self.assertUserCannotAuth(user) + + def test_auth_token_refresh(self): + admin_client = self.start_auth_server() + + user = self.create_user("test-user", []) + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.refresh_token(user["username"]) + + with self.auth_perms("@user-admin") as client: + new_user = client.refresh_token(user["username"]) + + self.assertEqual(user["username"], new_user["username"]) + self.assertNotEqual(user["token"], new_user["token"]) + self.assertUserCanAuth(new_user) + self.assertUserCannotAuth(user) + + def test_auth_self_get_user(self): + admin_client = self.start_auth_server() + + user = self.create_user("test-user", []) + user_info = user.copy() + del user_info["token"] + + with self.auth_client(user) as client: + info = client.get_user() + self.assertEqual(info, user_info) + + # Explicitly asking for your own username is fine also + info = client.get_user(user["username"]) + self.assertEqual(info, user_info) + + def test_auth_get_user(self): + admin_client = self.start_auth_server() + + user = self.create_user("test-user", []) + user_info = user.copy() + del user_info["token"] + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.get_user(user["username"]) + + with self.auth_perms("@user-admin") as client: + info = client.get_user(user["username"]) + self.assertEqual(info, user_info) + + info = client.get_user("nonexist-user") + self.assertIsNone(info) + + def test_auth_reconnect(self): + admin_client = self.start_auth_server() + + user = self.create_user("test-user", []) + user_info = user.copy() + del user_info["token"] + + with self.auth_client(user) as client: + info = client.get_user() + self.assertEqual(info, user_info) + + client.disconnect() + + info = client.get_user() + self.assertEqual(info, user_info) + + def test_auth_delete_user(self): + admin_client = self.start_auth_server() + + user = self.create_user("test-user", []) + + # self service + with self.auth_client(user) as client: + client.delete_user(user["username"]) + + self.assertIsNone(admin_client.get_user(user["username"])) + user = self.create_user("test-user", []) + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.delete_user(user["username"]) + + with self.auth_perms("@user-admin") as client: + client.delete_user(user["username"]) + + # User doesn't exist, so even though the permission is correct, it's an + # error + with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError): + client.delete_user(user["username"]) + + def test_auth_set_user_perms(self): + admin_client = self.start_auth_server() + + user = self.create_user("test-user", []) + + self.assertUserPerms(user, []) + + # No self service to change permissions + with self.auth_client(user) as client, self.assertRaises(InvokeError): + client.set_user_perms(user["username"], ["@all"]) + self.assertUserPerms(user, []) + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.set_user_perms(user["username"], ["@all"]) + self.assertUserPerms(user, 
[]) + + with self.auth_perms("@user-admin") as client: + client.set_user_perms(user["username"], ["@all"]) + self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS))) + + # Bad permissions + with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError): + client.set_user_perms(user["username"], ["@this-is-not-a-permission"]) + self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS))) + + def test_auth_get_all_users(self): + admin_client = self.start_auth_server() + + user = self.create_user("test-user", []) + + with self.auth_client(user) as client, self.assertRaises(InvokeError): + client.get_all_users() + + # Give the test user the correct permission + admin_client.set_user_perms(user["username"], ["@user-admin"]) + + with self.auth_client(user) as client: + all_users = client.get_all_users() + + # Convert to a dictionary for easier comparison + all_users = {u["username"]: u for u in all_users} + + self.assertEqual(all_users, + { + "admin": { + "username": "admin", + "permissions": sorted(list(ALL_PERMISSIONS)), + }, + "test-user": { + "username": "test-user", + "permissions": ["@user-admin"], + } + } + ) + + def test_auth_new_user(self): + self.start_auth_server() + + permissions = ["@read", "@report", "@db-admin", "@user-admin"] + permissions.sort() + + with self.auth_perms() as client, self.assertRaises(InvokeError): + self.create_user("test-user", permissions, client=client) + + with self.auth_perms("@user-admin") as client: + user = self.create_user("test-user", permissions, client=client) + self.assertIn("token", user) + self.assertEqual(user["username"], "test-user") + self.assertEqual(user["permissions"], permissions) + + def test_auth_become_user(self): + admin_client = self.start_auth_server() + + user = self.create_user("test-user", ["@read", "@report"]) + user_info = user.copy() + del user_info["token"] + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.become_user(user["username"]) + + with self.auth_perms("@user-admin") as client: + become = client.become_user(user["username"]) + self.assertEqual(become, user_info) + + info = client.get_user() + self.assertEqual(info, user_info) + + # Verify become user is preserved across disconnect + client.disconnect() + + info = client.get_user() + self.assertEqual(info, user_info) + + # test-user doesn't have become_user permissions, so this should + # not work + with self.assertRaises(InvokeError): + client.become_user(user["username"]) + + # No self-service of become + with self.auth_client(user) as client, self.assertRaises(InvokeError): + client.become_user(user["username"]) + + # Give test user permissions to become + admin_client.set_user_perms(user["username"], ["@user-admin"]) + + # It's possible to become yourself (effectively a noop) + with self.auth_perms("@user-admin") as client: + become = client.become_user(client.username) + + def test_auth_gc(self): + admin_client = self.start_auth_server() + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.gc_mark("ABC", {"unihash": "123"}) + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.gc_status() + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.gc_sweep("ABC") + + with self.auth_perms("@db-admin") as client: + client.gc_mark("ABC", {"unihash": "123"}) + + with self.auth_perms("@db-admin") as client: + client.gc_status() + + with self.auth_perms("@db-admin") as client: + client.gc_sweep("ABC") + + def test_get_db_usage(self): + usage = 
self.client.get_db_usage() + + self.assertTrue(isinstance(usage, dict)) + for name in usage.keys(): + self.assertTrue(isinstance(usage[name], dict)) + self.assertIn("rows", usage[name]) + self.assertTrue(isinstance(usage[name]["rows"], int)) + + def test_get_db_query_columns(self): + columns = self.client.get_db_query_columns() + + self.assertTrue(isinstance(columns, list)) + self.assertTrue(len(columns) > 0) + + for col in columns: + self.client.remove({col: ""}) + + def test_auth_is_owner(self): + admin_client = self.start_auth_server() + + user = self.create_user("test-user", ["@read", "@report"]) + with self.auth_client(user) as client: + taskhash, outhash, unihash = self.create_test_hash(client) + data = client.get_taskhash(self.METHOD, taskhash, True) + self.assertEqual(data["owner"], user["username"]) + + def test_gc(self): + taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' + outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' + unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' + + result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') + + taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' + outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' + unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' + + result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) + self.assertClientGetHash(self.client, taskhash2, unihash2) + + # Mark the first unihash to be kept + ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD}) + self.assertEqual(ret, {"count": 1}) + + ret = self.client.gc_status() + self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1}) + + # Second hash is still there; mark doesn't delete hashes + self.assertClientGetHash(self.client, taskhash2, unihash2) + + ret = self.client.gc_sweep("ABC") + self.assertEqual(ret, {"count": 1}) + + # Hash is gone. Taskhash is returned for second hash self.assertClientGetHash(self.client, taskhash2, None) + # First hash is still present + self.assertClientGetHash(self.client, taskhash, unihash) + + def test_gc_switch_mark(self): + taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' + outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' + unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' + + result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') + + taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' + outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' + unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' + + result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) + self.assertClientGetHash(self.client, taskhash2, unihash2) + + # Mark the first unihash to be kept + ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD}) + self.assertEqual(ret, {"count": 1}) + + ret = self.client.gc_status() + self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1}) + + # Second hash is still there; mark doesn't delete hashes + self.assertClientGetHash(self.client, taskhash2, unihash2) + + # Switch to a different mark and mark the second hash. 
This will start + # a new collection cycle + ret = self.client.gc_mark("DEF", {"unihash": unihash2, "method": self.METHOD}) + self.assertEqual(ret, {"count": 1}) + + ret = self.client.gc_status() + self.assertEqual(ret, {"mark": "DEF", "keep": 1, "remove": 1}) + + # Both hashes are still present + self.assertClientGetHash(self.client, taskhash2, unihash2) + self.assertClientGetHash(self.client, taskhash, unihash) + + # Sweep with the new mark + ret = self.client.gc_sweep("DEF") + self.assertEqual(ret, {"count": 1}) + + # First hash is gone, second is kept + self.assertClientGetHash(self.client, taskhash2, unihash2) + self.assertClientGetHash(self.client, taskhash, None) + + def test_gc_switch_sweep_mark(self): + taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' + outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' + unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' + + result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') + + taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' + outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' + unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' + + result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) + self.assertClientGetHash(self.client, taskhash2, unihash2) + + # Mark the first unihash to be kept + ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD}) + self.assertEqual(ret, {"count": 1}) + + ret = self.client.gc_status() + self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1}) + + # Sweeping with a different mark raises an error + with self.assertRaises(InvokeError): + self.client.gc_sweep("DEF") + + # Both hashes are present + self.assertClientGetHash(self.client, taskhash2, unihash2) + self.assertClientGetHash(self.client, taskhash, unihash) + + def test_gc_new_hashes(self): + taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' + outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' + unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' + + result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') + + # Start a new garbage collection + ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD}) + self.assertEqual(ret, {"count": 1}) + + ret = self.client.gc_status() + self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 0}) + + # Add second hash. 
It should inherit the mark from the current garbage + # collection operation + + taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' + outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' + unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' + + result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) + self.assertClientGetHash(self.client, taskhash2, unihash2) + + # Sweep should remove nothing + ret = self.client.gc_sweep("ABC") + self.assertEqual(ret, {"count": 0}) + + # Both hashes are present + self.assertClientGetHash(self.client, taskhash2, unihash2) + self.assertClientGetHash(self.client, taskhash, unihash) + + +class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase): + def get_server_addr(self, server_idx): + return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx) + + def test_get(self): + taskhash, outhash, unihash = self.create_test_hash(self.client) + + p = self.run_hashclient(["--address", self.server_address, "get", self.METHOD, taskhash]) + data = json.loads(p.stdout) + self.assertEqual(data["unihash"], unihash) + self.assertEqual(data["outhash"], outhash) + self.assertEqual(data["taskhash"], taskhash) + self.assertEqual(data["method"], self.METHOD) + + def test_get_outhash(self): + taskhash, outhash, unihash = self.create_test_hash(self.client) + + p = self.run_hashclient(["--address", self.server_address, "get-outhash", self.METHOD, outhash, taskhash]) + data = json.loads(p.stdout) + self.assertEqual(data["unihash"], unihash) + self.assertEqual(data["outhash"], outhash) + self.assertEqual(data["taskhash"], taskhash) + self.assertEqual(data["method"], self.METHOD) + + def test_stats(self): + p = self.run_hashclient(["--address", self.server_address, "stats"], check=True) + json.loads(p.stdout) + + def test_stress(self): + self.run_hashclient(["--address", self.server_address, "stress"], check=True) + + def test_unihash_exsits(self): + taskhash, outhash, unihash = self.create_test_hash(self.client) + + p = self.run_hashclient([ + "--address", self.server_address, + "unihash-exists", unihash, + ], check=True) + self.assertEqual(p.stdout.strip(), "true") + + p = self.run_hashclient([ + "--address", self.server_address, + "unihash-exists", '6662e699d6e3d894b24408ff9a4031ef9b038ee8', + ], check=True) + self.assertEqual(p.stdout.strip(), "false") + + def test_unihash_exsits_quiet(self): + taskhash, outhash, unihash = self.create_test_hash(self.client) + + p = self.run_hashclient([ + "--address", self.server_address, + "unihash-exists", unihash, + "--quiet", + ]) + self.assertEqual(p.returncode, 0) + self.assertEqual(p.stdout.strip(), "") + + p = self.run_hashclient([ + "--address", self.server_address, + "unihash-exists", '6662e699d6e3d894b24408ff9a4031ef9b038ee8', + "--quiet", + ]) + self.assertEqual(p.returncode, 1) + self.assertEqual(p.stdout.strip(), "") + + def test_remove_taskhash(self): + taskhash, outhash, unihash = self.create_test_hash(self.client) + self.run_hashclient([ + "--address", self.server_address, + "remove", + "--where", "taskhash", taskhash, + ], check=True) + self.assertClientGetHash(self.client, taskhash, None) + + result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) + self.assertIsNone(result_outhash) + + def test_remove_unihash(self): + taskhash, outhash, unihash = self.create_test_hash(self.client) + self.run_hashclient([ + "--address", self.server_address, + "remove", + "--where", "unihash", unihash, + ], check=True) + 
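+        # With the unihash mapping removed, the taskhash should no longer resolve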
+        self.assertClientGetHash(self.client, taskhash, None)
+
+    def test_remove_outhash(self):
+        taskhash, outhash, unihash = self.create_test_hash(self.client)
+        self.run_hashclient([
+            "--address", self.server_address,
+            "remove",
+            "--where", "outhash", outhash,
+        ], check=True)
+
+        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
+        self.assertIsNone(result_outhash)
+
+    def test_remove_method(self):
+        taskhash, outhash, unihash = self.create_test_hash(self.client)
+        self.run_hashclient([
+            "--address", self.server_address,
+            "remove",
+            "--where", "method", self.METHOD,
+        ], check=True)
+        self.assertClientGetHash(self.client, taskhash, None)
+
+        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
+        self.assertIsNone(result_outhash)
+
+    def test_clean_unused(self):
+        taskhash, outhash, unihash = self.create_test_hash(self.client)
+
+        # Clean the database, which should not remove anything because all hashes are in use
+        self.run_hashclient([
+            "--address", self.server_address,
+            "clean-unused", "0",
+        ], check=True)
+        self.assertClientGetHash(self.client, taskhash, unihash)
+
+        # Remove the unihash. The row in the outhash table should still be present
+        self.run_hashclient([
+            "--address", self.server_address,
+            "remove",
+            "--where", "unihash", unihash,
+        ], check=True)
+        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
+        self.assertIsNotNone(result_outhash)
+
+        # Now clean with no minimum age, which will remove the outhash
+        self.run_hashclient([
+            "--address", self.server_address,
+            "clean-unused", "0",
+        ], check=True)
+        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
+        self.assertIsNone(result_outhash)
+
+    def test_refresh_token(self):
+        admin_client = self.start_auth_server()
+
+        user = admin_client.new_user("test-user", ["@read", "@report"])
+
+        p = self.run_hashclient([
+            "--address", self.auth_server_address,
+            "--login", user["username"],
+            "--password", user["token"],
+            "refresh-token"
+        ], check=True)
+
+        new_token = None
+        for l in p.stdout.splitlines():
+            l = l.rstrip()
+            m = re.match(r'Token: +(.*)$', l)
+            if m is not None:
+                new_token = m.group(1)
+
+        self.assertTrue(new_token)
+
+        print("New token is %r" % new_token)
+
+        self.run_hashclient([
+            "--address", self.auth_server_address,
+            "--login", user["username"],
+            "--password", new_token,
+            "get-user"
+        ], check=True)
+
+    def test_set_user_perms(self):
+        admin_client = self.start_auth_server()
+
+        user = admin_client.new_user("test-user", ["@read"])
+
+        self.run_hashclient([
+            "--address", self.auth_server_address,
+            "--login", admin_client.username,
+            "--password", admin_client.password,
+            "set-user-perms",
+            "-u", user["username"],
+            "@read", "@report",
+        ], check=True)
+
+        new_user = admin_client.get_user(user["username"])
+
+        self.assertEqual(set(new_user["permissions"]), {"@read", "@report"})
+
+    def test_get_user(self):
+        admin_client = self.start_auth_server()
+
+        user = admin_client.new_user("test-user", ["@read"])
+
+        p = self.run_hashclient([
+            "--address", self.auth_server_address,
+            "--login", admin_client.username,
+            "--password", admin_client.password,
+            "get-user",
+            "-u", user["username"],
+        ], check=True)
+
+        self.assertIn("Username:", p.stdout)
+        self.assertIn("Permissions:", p.stdout)
+
+        p = self.run_hashclient([
+            "--address", self.auth_server_address,
+            "--login", user["username"],
+            "--password", user["token"],
+            "get-user",
+        ], check=True)
+
+        self.assertIn("Username:", p.stdout)
+        self.assertIn("Permissions:", p.stdout)
+
+    def test_get_all_users(self):
+        admin_client = self.start_auth_server()
+
+        admin_client.new_user("test-user1", ["@read"])
+        admin_client.new_user("test-user2", ["@read"])
+
+        p = self.run_hashclient([
+            "--address", self.auth_server_address,
+            "--login", admin_client.username,
+            "--password", admin_client.password,
+            "get-all-users",
+        ], check=True)
+
+        self.assertIn("admin", p.stdout)
+        self.assertIn("test-user1", p.stdout)
+        self.assertIn("test-user2", p.stdout)
+
+    def test_new_user(self):
+        admin_client = self.start_auth_server()
+
+        p = self.run_hashclient([
+            "--address", self.auth_server_address,
+            "--login", admin_client.username,
+            "--password", admin_client.password,
+            "new-user",
+            "-u", "test-user",
+            "@read", "@report",
+        ], check=True)
+
+        new_token = None
+        for l in p.stdout.splitlines():
+            l = l.rstrip()
+            m = re.match(r'Token: +(.*)$', l)
+            if m is not None:
+                new_token = m.group(1)
+
+        self.assertTrue(new_token)
+
+        user = {
+            "username": "test-user",
+            "token": new_token,
+        }
+
+        self.assertUserPerms(user, ["@read", "@report"])
+
+    def test_delete_user(self):
+        admin_client = self.start_auth_server()
+
+        user = admin_client.new_user("test-user", ["@read"])
+
+        p = self.run_hashclient([
+            "--address", self.auth_server_address,
+            "--login", admin_client.username,
+            "--password", admin_client.password,
+            "delete-user",
+            "-u", user["username"],
+        ], check=True)
+
+        self.assertIsNone(admin_client.get_user(user["username"]))
+
+    def test_get_db_usage(self):
+        p = self.run_hashclient([
+            "--address", self.server_address,
+            "get-db-usage",
+        ], check=True)
+
+    def test_get_db_query_columns(self):
+        p = self.run_hashclient([
+            "--address", self.server_address,
+            "get-db-query-columns",
+        ], check=True)
+
+    def test_gc(self):
+        taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
+        outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
+        unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
+
+        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
+
+        taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
+        outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
+        unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
+
+        result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
+        self.assertClientGetHash(self.client, taskhash2, unihash2)
+
+        # Mark the first unihash to be kept
+        self.run_hashclient([
+            "--address", self.server_address,
+            "gc-mark", "ABC",
+            "--where", "unihash", unihash,
+            "--where", "method", self.METHOD
+        ], check=True)
+
+        # Second hash is still there; mark doesn't delete hashes
+        self.assertClientGetHash(self.client, taskhash2, unihash2)
+
+        self.run_hashclient([
+            "--address", self.server_address,
+            "gc-sweep", "ABC",
+        ], check=True)
+
+        # Second hash is gone after the sweep; the lookup returns None
+        self.assertClientGetHash(self.client, taskhash2, None)
+        # First hash is still present
+        self.assertClientGetHash(self.client, taskhash, unihash)
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
@@ -313,3 +1437,77 @@ class TestHashEquivalenceTCPServer(HashEquivalenceTestSetup, HashEquivalenceComm
        # If IPv6 is enabled, it should be safe to use localhost directly, in general
        # case it is more reliable to resolve the IP address explicitly.
        return socket.gethostbyname("localhost") + ":0"
+
+
+class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
+    def setUp(self):
+        try:
+            import websockets
+        except ImportError as e:
+            self.skipTest(str(e))
+
+        super().setUp()
+
+    def get_server_addr(self, server_idx):
+        # Some hosts cause asyncio module to misbehave, when IPv6 is not enabled.
+        # If IPv6 is enabled, it should be safe to use localhost directly, in general
+        # case it is more reliable to resolve the IP address explicitly.
+        host = socket.gethostbyname("localhost")
+        return "ws://%s:0" % host
+
+
+class TestHashEquivalenceWebsocketsSQLAlchemyServer(TestHashEquivalenceWebsocketServer):
+    def setUp(self):
+        try:
+            import sqlalchemy
+            import aiosqlite
+        except ImportError as e:
+            self.skipTest(str(e))
+
+        super().setUp()
+
+    def make_dbpath(self):
+        return "sqlite+aiosqlite:///%s" % os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
+
+
+class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
+    def get_env(self, name):
+        v = os.environ.get(name)
+        if not v:
+            self.skipTest(f'{name} not defined to test an external server')
+        return v
+
+    def start_test_server(self):
+        return self.get_env('BB_TEST_HASHSERV')
+
+    def start_server(self, *args, **kwargs):
+        self.skipTest('Cannot start local server when testing external servers')
+
+    def start_auth_server(self):
+
+        self.auth_server_address = self.server_address
+        self.admin_client = self.start_client(
+            self.server_address,
+            username=self.get_env('BB_TEST_HASHSERV_USERNAME'),
+            password=self.get_env('BB_TEST_HASHSERV_PASSWORD'),
+        )
+        return self.admin_client
+
+    def setUp(self):
+        super().setUp()
+        if "BB_TEST_HASHSERV_USERNAME" in os.environ:
+            self.client = self.start_client(
+                self.server_address,
+                username=os.environ["BB_TEST_HASHSERV_USERNAME"],
+                password=os.environ["BB_TEST_HASHSERV_PASSWORD"],
+            )
+        self.client.remove({"method": self.METHOD})
+
+    def tearDown(self):
+        self.client.remove({"method": self.METHOD})
+        super().tearDown()
+
+
+    def test_auth_get_all_users(self):
+        self.skipTest("Cannot test all users with external server")
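
As a rough guide to the client API exercised by the garbage-collection tests above, the following minimal sketch walks through the mark-and-sweep flow against a running hash equivalence server. It assumes BitBake's lib/ directory is on sys.path and that a server is already listening at the chosen address; the address, method name, hashes, and mark name are illustrative placeholders, and the client is obtained with hashserv.create_client(), which lives in this module but is not part of the hunks shown here.

    import hashserv

    ADDRESS = "unix:///tmp/hashserv.sock"  # assumed: any reachable hash equivalence server
    METHOD = "TestMethod"                  # assumed: illustrative task method name

    client = hashserv.create_client(ADDRESS)

    # Record an equivalence: (method, taskhash, outhash) -> unihash
    client.report_unihash("aa" * 20, METHOD, "bb" * 32, "cc" * 20)

    # Start a collection cycle named "MARK-1"; every row matching the
    # query is marked so that it survives the sweep
    client.gc_mark("MARK-1", {"method": METHOD, "unihash": "cc" * 20})

    # Report how many rows the sweep would keep and how many it would remove
    print(client.gc_status())

    # Remove everything that was not marked during cycle "MARK-1"
    client.gc_sweep("MARK-1")

    client.close()

Starting a gc_mark with a new mark name begins a fresh cycle, and hashes reported while a cycle is open inherit its mark, which is exactly the behaviour the test_gc_switch_mark, test_gc_switch_sweep_mark, and test_gc_new_hashes cases verify.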