diff options
author | Joshua Watt <JPEWhacker@gmail.com> | 2024-02-18 15:59:46 -0700 |
---|---|---|
committer | Richard Purdie <richard.purdie@linuxfoundation.org> | 2024-02-19 11:53:15 +0000 |
commit | 433d4a075a1acfbd2a2913061739353a84bb01ed (patch) | |
tree | c8f884b95594b013eb84df644b6eebbccf826d53 /lib/hashserv | |
parent | df184b2a4e80fca847cfe90644110b74a1af613e (diff) | |
download | bitbake-433d4a075a1acfbd2a2913061739353a84bb01ed.tar.gz |
hashserv: Add Unihash Garbage Collection
Adds support for removing unused unihashes from the database. This is
done using a "mark and sweep" style of garbage collection where a
collection is started by marking which unihashes should be kept in the
database, then performing a sweep to remove any unmarked hashes.
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
Diffstat (limited to 'lib/hashserv')
-rw-r--r-- lib/hashserv/client.py     |  31
-rw-r--r-- lib/hashserv/server.py     | 105
-rw-r--r-- lib/hashserv/sqlalchemy.py | 226
-rw-r--r-- lib/hashserv/sqlite.py     | 205
-rw-r--r-- lib/hashserv/tests.py      | 198
5 files changed, 649 insertions(+), 116 deletions(-)
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py index 35a97687f..e6dc41791 100644 --- a/lib/hashserv/client.py +++ b/lib/hashserv/client.py @@ -194,6 +194,34 @@ class AsyncClient(bb.asyncrpc.AsyncClient): await self._set_mode(self.MODE_NORMAL) return (await self.invoke({"get-db-query-columns": {}}))["columns"] + async def gc_status(self): + await self._set_mode(self.MODE_NORMAL) + return await self.invoke({"gc-status": {}}) + + async def gc_mark(self, mark, where): + """ + Starts a new garbage collection operation identified by "mark". If + garbage collection is already in progress with "mark", the collection + is continued. + + All unihash entries that match the "where" clause are marked to be + kept. In addition, any new entries added to the database after this + command will be automatically marked with "mark" + """ + await self._set_mode(self.MODE_NORMAL) + return await self.invoke({"gc-mark": {"mark": mark, "where": where}}) + + async def gc_sweep(self, mark): + """ + Finishes garbage collection for "mark". All unihash entries that have + not been marked will be deleted. 
+ + It is recommended to clean unused outhash entries after running this to + cleanup any dangling outhashes + """ + await self._set_mode(self.MODE_NORMAL) + return await self.invoke({"gc-sweep": {"mark": mark}}) + class Client(bb.asyncrpc.Client): def __init__(self, username=None, password=None): @@ -224,6 +252,9 @@ class Client(bb.asyncrpc.Client): "become_user", "get_db_usage", "get_db_query_columns", + "gc_status", + "gc_mark", + "gc_sweep", ) def _get_async_client(self): diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py index a86507830..5ed852d1f 100644 --- a/lib/hashserv/server.py +++ b/lib/hashserv/server.py @@ -199,7 +199,7 @@ def permissions(*permissions, allow_anon=True, allow_self_service=False): if not self.user_has_permissions(*permissions, allow_anon=allow_anon): if not self.user: username = "Anonymous user" - user_perms = self.anon_perms + user_perms = self.server.anon_perms else: username = self.user.username user_perms = self.user.permissions @@ -223,25 +223,11 @@ def permissions(*permissions, allow_anon=True, allow_self_service=False): class ServerClient(bb.asyncrpc.AsyncServerConnection): - def __init__( - self, - socket, - db_engine, - request_stats, - backfill_queue, - upstream, - read_only, - anon_perms, - ): - super().__init__(socket, "OEHASHEQUIV", logger) - self.db_engine = db_engine - self.request_stats = request_stats + def __init__(self, socket, server): + super().__init__(socket, "OEHASHEQUIV", server.logger) + self.server = server self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK - self.backfill_queue = backfill_queue - self.upstream = upstream - self.read_only = read_only self.user = None - self.anon_perms = anon_perms self.handlers.update( { @@ -261,13 +247,16 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): } ) - if not read_only: + if not self.server.read_only: self.handlers.update( { "report-equiv": self.handle_equivreport, "reset-stats": self.handle_reset_stats, "backfill-wait": self.handle_backfill_wait, 
"remove": self.handle_remove, + "gc-mark": self.handle_gc_mark, + "gc-sweep": self.handle_gc_sweep, + "gc-status": self.handle_gc_status, "clean-unused": self.handle_clean_unused, "refresh-token": self.handle_refresh_token, "set-user-perms": self.handle_set_perms, @@ -282,10 +271,10 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): def user_has_permissions(self, *permissions, allow_anon=True): permissions = set(permissions) if allow_anon: - if ALL_PERM in self.anon_perms: + if ALL_PERM in self.server.anon_perms: return True - if not permissions - self.anon_perms: + if not permissions - self.server.anon_perms: return True if self.user is None: @@ -303,10 +292,10 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): return self.proto_version > (1, 0) and self.proto_version <= (1, 1) async def process_requests(self): - async with self.db_engine.connect(self.logger) as db: + async with self.server.db_engine.connect(self.logger) as db: self.db = db - if self.upstream is not None: - self.upstream_client = await create_async_client(self.upstream) + if self.server.upstream is not None: + self.upstream_client = await create_async_client(self.server.upstream) else: self.upstream_client = None @@ -323,7 +312,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): if "stream" in k: return await self.handlers[k](msg[k]) else: - with self.request_stats.start_sample() as self.request_sample, self.request_sample.measure(): + with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure(): return await self.handlers[k](msg[k]) raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg) @@ -404,7 +393,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): # possible (which is why the request sample is handled manually # instead of using 'with', and also why logging statements are # commented out. 
- self.request_sample = self.request_stats.start_sample() + self.request_sample = self.server.request_stats.start_sample() request_measure = self.request_sample.measure() request_measure.start() @@ -435,7 +424,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): # Post to the backfill queue after writing the result to minimize # the turn around time on a request if upstream is not None: - await self.backfill_queue.put((method, taskhash)) + await self.server.backfill_queue.put((method, taskhash)) await self.socket.send("ok") return self.NO_RESPONSE @@ -461,7 +450,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): # report is made inside the function @permissions(READ_PERM) async def handle_report(self, data): - if self.read_only or not self.user_has_permissions(REPORT_PERM): + if self.server.read_only or not self.user_has_permissions(REPORT_PERM): return await self.report_readonly(data) outhash_data = { @@ -538,24 +527,24 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): @permissions(READ_PERM) async def handle_get_stats(self, request): return { - "requests": self.request_stats.todict(), + "requests": self.server.request_stats.todict(), } @permissions(DB_ADMIN_PERM) async def handle_reset_stats(self, request): d = { - "requests": self.request_stats.todict(), + "requests": self.server.request_stats.todict(), } - self.request_stats.reset() + self.server.request_stats.reset() return d @permissions(READ_PERM) async def handle_backfill_wait(self, request): d = { - "tasks": self.backfill_queue.qsize(), + "tasks": self.server.backfill_queue.qsize(), } - await self.backfill_queue.join() + await self.server.backfill_queue.join() return d @permissions(DB_ADMIN_PERM) @@ -567,6 +556,46 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): return {"count": await self.db.remove(condition)} @permissions(DB_ADMIN_PERM) + async def handle_gc_mark(self, request): + condition = request["where"] + mark = request["mark"] + + if not isinstance(condition, 
dict): + raise TypeError("Bad condition type %s" % type(condition)) + + if not isinstance(mark, str): + raise TypeError("Bad mark type %s" % type(mark)) + + return {"count": await self.db.gc_mark(mark, condition)} + + @permissions(DB_ADMIN_PERM) + async def handle_gc_sweep(self, request): + mark = request["mark"] + + if not isinstance(mark, str): + raise TypeError("Bad mark type %s" % type(mark)) + + current_mark = await self.db.get_current_gc_mark() + + if not current_mark or mark != current_mark: + raise bb.asyncrpc.InvokeError( + f"'{mark}' is not the current mark. Refusing to sweep" + ) + + count = await self.db.gc_sweep() + + return {"count": count} + + @permissions(DB_ADMIN_PERM) + async def handle_gc_status(self, request): + (keep_rows, remove_rows, current_mark) = await self.db.gc_status() + return { + "keep": keep_rows, + "remove": remove_rows, + "mark": current_mark, + } + + @permissions(DB_ADMIN_PERM) async def handle_clean_unused(self, request): max_age = request["max_age_seconds"] oldest = datetime.now() - timedelta(seconds=-max_age) @@ -779,15 +808,7 @@ class Server(bb.asyncrpc.AsyncServer): ) def accept_client(self, socket): - return ServerClient( - socket, - self.db_engine, - self.request_stats, - self.backfill_queue, - self.upstream, - self.read_only, - self.anon_perms, - ) + return ServerClient(socket, self) async def create_admin_user(self): admin_permissions = (ALL_PERM,) diff --git a/lib/hashserv/sqlalchemy.py b/lib/hashserv/sqlalchemy.py index cee04bffb..89a6b86d9 100644 --- a/lib/hashserv/sqlalchemy.py +++ b/lib/hashserv/sqlalchemy.py @@ -28,6 +28,7 @@ from sqlalchemy import ( delete, update, func, + inspect, ) import sqlalchemy.engine from sqlalchemy.orm import declarative_base @@ -36,16 +37,17 @@ from sqlalchemy.exc import IntegrityError Base = declarative_base() -class UnihashesV2(Base): - __tablename__ = "unihashes_v2" +class UnihashesV3(Base): + __tablename__ = "unihashes_v3" id = Column(Integer, primary_key=True, autoincrement=True) 
method = Column(Text, nullable=False) taskhash = Column(Text, nullable=False) unihash = Column(Text, nullable=False) + gc_mark = Column(Text, nullable=False) __table_args__ = ( UniqueConstraint("method", "taskhash"), - Index("taskhash_lookup_v3", "method", "taskhash"), + Index("taskhash_lookup_v4", "method", "taskhash"), ) @@ -79,6 +81,36 @@ class Users(Base): __table_args__ = (UniqueConstraint("username"),) +class Config(Base): + __tablename__ = "config" + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(Text, nullable=False) + value = Column(Text) + __table_args__ = ( + UniqueConstraint("name"), + Index("config_lookup", "name"), + ) + + +# +# Old table versions +# +DeprecatedBase = declarative_base() + + +class UnihashesV2(DeprecatedBase): + __tablename__ = "unihashes_v2" + id = Column(Integer, primary_key=True, autoincrement=True) + method = Column(Text, nullable=False) + taskhash = Column(Text, nullable=False) + unihash = Column(Text, nullable=False) + + __table_args__ = ( + UniqueConstraint("method", "taskhash"), + Index("taskhash_lookup_v3", "method", "taskhash"), + ) + + class DatabaseEngine(object): def __init__(self, url, username=None, password=None): self.logger = logging.getLogger("hashserv.sqlalchemy") @@ -91,6 +123,9 @@ class DatabaseEngine(object): self.url = self.url.set(password=password) async def create(self): + def check_table_exists(conn, name): + return inspect(conn).has_table(name) + self.logger.info("Using database %s", self.url) self.engine = create_async_engine(self.url, poolclass=NullPool) @@ -99,6 +134,24 @@ class DatabaseEngine(object): self.logger.info("Creating tables...") await conn.run_sync(Base.metadata.create_all) + if await conn.run_sync(check_table_exists, UnihashesV2.__tablename__): + self.logger.info("Upgrading Unihashes V2 -> V3...") + statement = insert(UnihashesV3).from_select( + ["id", "method", "unihash", "taskhash", "gc_mark"], + select( + UnihashesV2.id, + UnihashesV2.method, + 
UnihashesV2.unihash, + UnihashesV2.taskhash, + literal("").label("gc_mark"), + ), + ) + self.logger.debug("%s", statement) + await conn.execute(statement) + + await conn.run_sync(Base.metadata.drop_all, [UnihashesV2.__table__]) + self.logger.info("Upgrade complete") + def connect(self, logger): return Database(self.engine, logger) @@ -118,6 +171,15 @@ def map_user(row): ) +def _make_condition_statement(table, condition): + where = {} + for c in table.__table__.columns: + if c.key in condition and condition[c.key] is not None: + where[c] = condition[c.key] + + return [(k == v) for k, v in where.items()] + + class Database(object): def __init__(self, engine, logger): self.engine = engine @@ -135,17 +197,52 @@ class Database(object): await self.db.close() self.db = None + async def _execute(self, statement): + self.logger.debug("%s", statement) + return await self.db.execute(statement) + + async def _set_config(self, name, value): + while True: + result = await self._execute( + update(Config).where(Config.name == name).values(value=value) + ) + + if result.rowcount == 0: + self.logger.debug("Config '%s' not found. Adding it", name) + try: + await self._execute(insert(Config).values(name=name, value=value)) + except IntegrityError: + # Race. 
Try again + continue + + break + + def _get_config_subquery(self, name, default=None): + if default is not None: + return func.coalesce( + select(Config.value).where(Config.name == name).scalar_subquery(), + default, + ) + return select(Config.value).where(Config.name == name).scalar_subquery() + + async def _get_config(self, name): + result = await self._execute(select(Config.value).where(Config.name == name)) + row = result.first() + if row is None: + return None + return row.value + async def get_unihash_by_taskhash_full(self, method, taskhash): statement = ( select( OuthashesV2, - UnihashesV2.unihash.label("unihash"), + UnihashesV3.unihash.label("unihash"), ) .join( - UnihashesV2, + UnihashesV3, and_( - UnihashesV2.method == OuthashesV2.method, - UnihashesV2.taskhash == OuthashesV2.taskhash, + UnihashesV3.method == OuthashesV2.method, + UnihashesV3.taskhash == OuthashesV2.taskhash, ), ) .where( @@ -164,12 +261,12 @@ class Database(object): async def get_unihash_by_outhash(self, method, outhash): statement = ( - select(OuthashesV2, UnihashesV2.unihash.label("unihash")) + select(OuthashesV2, UnihashesV3.unihash.label("unihash")) .join( - UnihashesV2, + UnihashesV3, and_( - UnihashesV2.method == OuthashesV2.method, - UnihashesV2.taskhash == OuthashesV2.taskhash, + UnihashesV3.method == OuthashesV2.method, + UnihashesV3.taskhash == OuthashesV2.taskhash, ), ) .where( @@ -208,13 +305,13 @@ class Database(object): statement = ( select( OuthashesV2.taskhash.label("taskhash"), - UnihashesV2.unihash.label("unihash"), + UnihashesV3.unihash.label("unihash"), ) .join( - UnihashesV2, + UnihashesV3, and_( - UnihashesV2.method == OuthashesV2.method, - UnihashesV2.taskhash == OuthashesV2.taskhash, + UnihashesV3.method == OuthashesV2.method, + UnihashesV3.taskhash == OuthashesV2.taskhash, ), ) .where( @@ -234,12 +331,12 @@ class Database(object): async def get_equivalent(self, method, taskhash): statement = select( - UnihashesV2.unihash, - UnihashesV2.method, - 
UnihashesV2.taskhash, + UnihashesV3.unihash, + UnihashesV3.method, + UnihashesV3.taskhash, ).where( - UnihashesV2.method == method, - UnihashesV2.taskhash == taskhash, + UnihashesV3.method == method, + UnihashesV3.taskhash == taskhash, ) self.logger.debug("%s", statement) async with self.db.begin(): @@ -248,13 +345,9 @@ class Database(object): async def remove(self, condition): async def do_remove(table): - where = {} - for c in table.__table__.columns: - if c.key in condition and condition[c.key] is not None: - where[c] = condition[c.key] - + where = _make_condition_statement(table, condition) if where: - statement = delete(table).where(*[(k == v) for k, v in where.items()]) + statement = delete(table).where(*where) self.logger.debug("%s", statement) async with self.db.begin(): result = await self.db.execute(statement) @@ -263,19 +356,74 @@ class Database(object): return 0 count = 0 - count += await do_remove(UnihashesV2) + count += await do_remove(UnihashesV3) count += await do_remove(OuthashesV2) return count + async def get_current_gc_mark(self): + async with self.db.begin(): + return await self._get_config("gc-mark") + + async def gc_status(self): + async with self.db.begin(): + gc_mark_subquery = self._get_config_subquery("gc-mark", "") + + result = await self._execute( + select(func.count()) + .select_from(UnihashesV3) + .where(UnihashesV3.gc_mark == gc_mark_subquery) + ) + keep_rows = result.scalar() + + result = await self._execute( + select(func.count()) + .select_from(UnihashesV3) + .where(UnihashesV3.gc_mark != gc_mark_subquery) + ) + remove_rows = result.scalar() + + return (keep_rows, remove_rows, await self._get_config("gc-mark")) + + async def gc_mark(self, mark, condition): + async with self.db.begin(): + await self._set_config("gc-mark", mark) + + where = _make_condition_statement(UnihashesV3, condition) + if not where: + return 0 + + result = await self._execute( + update(UnihashesV3) + .values(gc_mark=self._get_config_subquery("gc-mark", "")) + 
.where(*where) + ) + return result.rowcount + + async def gc_sweep(self): + async with self.db.begin(): + result = await self._execute( + delete(UnihashesV3).where( + # A sneaky conditional that provides some errant use + # protection: If the config mark is NULL, this will not + # match any rows because No default is specified in the + # select statement + UnihashesV3.gc_mark + != self._get_config_subquery("gc-mark") + ) + ) + await self._set_config("gc-mark", None) + + return result.rowcount + async def clean_unused(self, oldest): statement = delete(OuthashesV2).where( OuthashesV2.created < oldest, ~( - select(UnihashesV2.id) + select(UnihashesV3.id) .where( - UnihashesV2.method == OuthashesV2.method, - UnihashesV2.taskhash == OuthashesV2.taskhash, + UnihashesV3.method == OuthashesV2.method, + UnihashesV3.taskhash == OuthashesV2.taskhash, ) .limit(1) .exists() @@ -287,15 +435,17 @@ class Database(object): return result.rowcount async def insert_unihash(self, method, taskhash, unihash): - statement = insert(UnihashesV2).values( - method=method, - taskhash=taskhash, - unihash=unihash, - ) - self.logger.debug("%s", statement) try: async with self.db.begin(): - await self.db.execute(statement) + await self._execute( + insert(UnihashesV3).values( + method=method, + taskhash=taskhash, + unihash=unihash, + gc_mark=self._get_config_subquery("gc-mark", ""), + ) + ) + return True except IntegrityError: self.logger.debug( @@ -418,7 +568,7 @@ class Database(object): async def get_query_columns(self): columns = set() - for table in (UnihashesV2, OuthashesV2): + for table in (UnihashesV3, OuthashesV2): for c in table.__table__.columns: if not isinstance(c.type, Text): continue diff --git a/lib/hashserv/sqlite.py b/lib/hashserv/sqlite.py index f93cb2c1d..608490730 100644 --- a/lib/hashserv/sqlite.py +++ b/lib/hashserv/sqlite.py @@ -15,6 +15,7 @@ UNIHASH_TABLE_DEFINITION = ( ("method", "TEXT NOT NULL", "UNIQUE"), ("taskhash", "TEXT NOT NULL", "UNIQUE"), ("unihash", "TEXT NOT 
NULL", ""), + ("gc_mark", "TEXT NOT NULL", ""), ) UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION) @@ -44,6 +45,14 @@ USERS_TABLE_DEFINITION = ( USERS_TABLE_COLUMNS = tuple(name for name, _, _ in USERS_TABLE_DEFINITION) +CONFIG_TABLE_DEFINITION = ( + ("name", "TEXT NOT NULL", "UNIQUE"), + ("value", "TEXT", ""), +) + +CONFIG_TABLE_COLUMNS = tuple(name for name, _, _ in CONFIG_TABLE_DEFINITION) + + def _make_table(cursor, name, definition): cursor.execute( """ @@ -71,6 +80,35 @@ def map_user(row): ) +def _make_condition_statement(columns, condition): + where = {} + for c in columns: + if c in condition and condition[c] is not None: + where[c] = condition[c] + + return where, " AND ".join("%s=:%s" % (k, k) for k in where.keys()) + + +def _get_sqlite_version(cursor): + cursor.execute("SELECT sqlite_version()") + + version = [] + for v in cursor.fetchone()[0].split("."): + try: + version.append(int(v)) + except ValueError: + version.append(v) + + return tuple(version) + + +def _schema_table_name(version): + if version >= (3, 33): + return "sqlite_schema" + + return "sqlite_master" + + class DatabaseEngine(object): def __init__(self, dbname, sync): self.dbname = dbname @@ -82,9 +120,10 @@ class DatabaseEngine(object): db.row_factory = sqlite3.Row with closing(db.cursor()) as cursor: - _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION) + _make_table(cursor, "unihashes_v3", UNIHASH_TABLE_DEFINITION) _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION) _make_table(cursor, "users", USERS_TABLE_DEFINITION) + _make_table(cursor, "config", CONFIG_TABLE_DEFINITION) cursor.execute("PRAGMA journal_mode = WAL") cursor.execute( @@ -96,17 +135,38 @@ class DatabaseEngine(object): cursor.execute("DROP INDEX IF EXISTS outhash_lookup") cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2") cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2") + cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v3") # TODO: Upgrade from tasks_v2? 
cursor.execute("DROP TABLE IF EXISTS tasks_v2") # Create new indexes cursor.execute( - "CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)" + "CREATE INDEX IF NOT EXISTS taskhash_lookup_v4 ON unihashes_v3 (method, taskhash)" ) cursor.execute( "CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)" ) + cursor.execute("CREATE INDEX IF NOT EXISTS config_lookup ON config (name)") + + sqlite_version = _get_sqlite_version(cursor) + + cursor.execute( + f""" + SELECT name FROM {_schema_table_name(sqlite_version)} WHERE type = 'table' AND name = 'unihashes_v2' + """ + ) + if cursor.fetchone(): + self.logger.info("Upgrading Unihashes V2 -> V3...") + cursor.execute( + """ + INSERT INTO unihashes_v3 (id, method, unihash, taskhash, gc_mark) + SELECT id, method, unihash, taskhash, '' FROM unihashes_v2 + """ + ) + cursor.execute("DROP TABLE unihashes_v2") + db.commit() + self.logger.info("Upgrade complete") def connect(self, logger): return Database(logger, self.dbname, self.sync) @@ -126,16 +186,7 @@ class Database(object): "PRAGMA synchronous = %s" % ("NORMAL" if sync else "OFF") ) - cursor.execute("SELECT sqlite_version()") - - version = [] - for v in cursor.fetchone()[0].split("."): - try: - version.append(int(v)) - except ValueError: - version.append(v) - - self.sqlite_version = tuple(version) + self.sqlite_version = _get_sqlite_version(cursor) async def __aenter__(self): return self @@ -143,6 +194,30 @@ class Database(object): async def __aexit__(self, exc_type, exc_value, traceback): await self.close() + async def _set_config(self, cursor, name, value): + cursor.execute( + """ + INSERT OR REPLACE INTO config (id, name, value) VALUES + ((SELECT id FROM config WHERE name=:name), :name, :value) + """, + { + "name": name, + "value": value, + }, + ) + + async def _get_config(self, cursor, name): + cursor.execute( + "SELECT value FROM config WHERE name=:name", + { + "name": name, + }, + ) + row = cursor.fetchone() + if row 
is None: + return None + return row["value"] + async def close(self): self.db.close() @@ -150,8 +225,8 @@ class Database(object): with closing(self.db.cursor()) as cursor: cursor.execute( """ - SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2 - INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash + SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2 + INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash ORDER BY outhashes_v2.created ASC LIMIT 1 @@ -167,8 +242,8 @@ class Database(object): with closing(self.db.cursor()) as cursor: cursor.execute( """ - SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2 - INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash + SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2 + INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash ORDER BY outhashes_v2.created ASC LIMIT 1 @@ -200,8 +275,8 @@ class Database(object): with closing(self.db.cursor()) as cursor: cursor.execute( """ - SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2 - INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash + SELECT outhashes_v2.taskhash AS taskhash, unihashes_v3.unihash AS unihash FROM outhashes_v2 + INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash -- Select any matching output hash except the one we just inserted WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash -- Pick the oldest hash @@ -219,7 +294,7 @@ class 
Database(object): async def get_equivalent(self, method, taskhash): with closing(self.db.cursor()) as cursor: cursor.execute( - "SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash", + "SELECT taskhash, method, unihash FROM unihashes_v3 WHERE method=:method AND taskhash=:taskhash", { "method": method, "taskhash": taskhash, @@ -229,15 +304,9 @@ class Database(object): async def remove(self, condition): def do_remove(columns, table_name, cursor): - where = {} - for c in columns: - if c in condition and condition[c] is not None: - where[c] = condition[c] - + where, clause = _make_condition_statement(columns, condition) if where: - query = ("DELETE FROM %s WHERE " % table_name) + " AND ".join( - "%s=:%s" % (k, k) for k in where.keys() - ) + query = f"DELETE FROM {table_name} WHERE {clause}" cursor.execute(query, where) return cursor.rowcount @@ -246,17 +315,80 @@ class Database(object): count = 0 with closing(self.db.cursor()) as cursor: count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor) - count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor) + count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v3", cursor) self.db.commit() return count + async def get_current_gc_mark(self): + with closing(self.db.cursor()) as cursor: + return await self._get_config(cursor, "gc-mark") + + async def gc_status(self): + with closing(self.db.cursor()) as cursor: + cursor.execute( + """ + SELECT COUNT() FROM unihashes_v3 WHERE + gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '') + """ + ) + keep_rows = cursor.fetchone()[0] + + cursor.execute( + """ + SELECT COUNT() FROM unihashes_v3 WHERE + gc_mark!=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '') + """ + ) + remove_rows = cursor.fetchone()[0] + + current_mark = await self._get_config(cursor, "gc-mark") + + return (keep_rows, remove_rows, current_mark) + + async def gc_mark(self, mark, condition): + with closing(self.db.cursor()) as 
cursor: + await self._set_config(cursor, "gc-mark", mark) + + where, clause = _make_condition_statement(UNIHASH_TABLE_COLUMNS, condition) + + new_rows = 0 + if where: + cursor.execute( + f""" + UPDATE unihashes_v3 SET + gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '') + WHERE {clause} + """, + where, + ) + new_rows = cursor.rowcount + + self.db.commit() + return new_rows + + async def gc_sweep(self): + with closing(self.db.cursor()) as cursor: + # NOTE: COALESCE is not used in this query so that if the current + # mark is NULL, nothing will happen + cursor.execute( + """ + DELETE FROM unihashes_v3 WHERE + gc_mark!=(SELECT value FROM config WHERE name='gc-mark') + """ + ) + count = cursor.rowcount + await self._set_config(cursor, "gc-mark", None) + + self.db.commit() + return count + async def clean_unused(self, oldest): with closing(self.db.cursor()) as cursor: cursor.execute( """ DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS ( - SELECT unihashes_v2.id FROM unihashes_v2 WHERE unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash LIMIT 1 + SELECT unihashes_v3.id FROM unihashes_v3 WHERE unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash LIMIT 1 ) """, { @@ -271,7 +403,13 @@ class Database(object): prevrowid = cursor.lastrowid cursor.execute( """ - INSERT OR IGNORE INTO unihashes_v2 (method, taskhash, unihash) VALUES(:method, :taskhash, :unihash) + INSERT OR IGNORE INTO unihashes_v3 (method, taskhash, unihash, gc_mark) VALUES + ( + :method, + :taskhash, + :unihash, + COALESCE((SELECT value FROM config WHERE name='gc-mark'), '') + ) """, { "method": method, @@ -383,14 +521,9 @@ class Database(object): async def get_usage(self): usage = {} with closing(self.db.cursor()) as cursor: - if self.sqlite_version >= (3, 33): - table_name = "sqlite_schema" - else: - table_name = "sqlite_master" - cursor.execute( f""" - SELECT name FROM {table_name} WHERE type = 'table' AND 
name NOT LIKE 'sqlite_%' + SELECT name FROM {_schema_table_name(self.sqlite_version)} WHERE type = 'table' AND name NOT LIKE 'sqlite_%' """ ) for row in cursor.fetchall(): diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py index 869f7636c..aeedab357 100644 --- a/lib/hashserv/tests.py +++ b/lib/hashserv/tests.py @@ -810,6 +810,27 @@ class HashEquivalenceCommonTests(object): with self.auth_perms("@user-admin") as client: become = client.become_user(client.username) + def test_auth_gc(self): + admin_client = self.start_auth_server() + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.gc_mark("ABC", {"unihash": "123"}) + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.gc_status() + + with self.auth_perms() as client, self.assertRaises(InvokeError): + client.gc_sweep("ABC") + + with self.auth_perms("@db-admin") as client: + client.gc_mark("ABC", {"unihash": "123"}) + + with self.auth_perms("@db-admin") as client: + client.gc_status() + + with self.auth_perms("@db-admin") as client: + client.gc_sweep("ABC") + def test_get_db_usage(self): usage = self.client.get_db_usage() @@ -837,6 +858,147 @@ class HashEquivalenceCommonTests(object): data = client.get_taskhash(self.METHOD, taskhash, True) self.assertEqual(data["owner"], user["username"]) + def test_gc(self): + taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' + outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' + unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' + + result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') + + taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' + outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' + unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' + + result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) + 
self.assertClientGetHash(self.client, taskhash2, unihash2) + + # Mark the first unihash to be kept + ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD}) + self.assertEqual(ret, {"count": 1}) + + ret = self.client.gc_status() + self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1}) + + # Second hash is still there; mark doesn't delete hashes + self.assertClientGetHash(self.client, taskhash2, unihash2) + + ret = self.client.gc_sweep("ABC") + self.assertEqual(ret, {"count": 1}) + + # Hash is gone. Taskhash is returned for second hash + self.assertClientGetHash(self.client, taskhash2, None) + # First hash is still present + self.assertClientGetHash(self.client, taskhash, unihash) + + def test_gc_switch_mark(self): + taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' + outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' + unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' + + result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') + + taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' + outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' + unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' + + result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) + self.assertClientGetHash(self.client, taskhash2, unihash2) + + # Mark the first unihash to be kept + ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD}) + self.assertEqual(ret, {"count": 1}) + + ret = self.client.gc_status() + self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1}) + + # Second hash is still there; mark doesn't delete hashes + self.assertClientGetHash(self.client, taskhash2, unihash2) + + # Switch to a different mark and mark the second hash. 
This will start + # a new collection cycle + ret = self.client.gc_mark("DEF", {"unihash": unihash2, "method": self.METHOD}) + self.assertEqual(ret, {"count": 1}) + + ret = self.client.gc_status() + self.assertEqual(ret, {"mark": "DEF", "keep": 1, "remove": 1}) + + # Both hashes are still present + self.assertClientGetHash(self.client, taskhash2, unihash2) + self.assertClientGetHash(self.client, taskhash, unihash) + + # Sweep with the new mark + ret = self.client.gc_sweep("DEF") + self.assertEqual(ret, {"count": 1}) + + # First hash is gone, second is kept + self.assertClientGetHash(self.client, taskhash2, unihash2) + self.assertClientGetHash(self.client, taskhash, None) + + def test_gc_switch_sweep_mark(self): + taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' + outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' + unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' + + result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') + + taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' + outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' + unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' + + result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) + self.assertClientGetHash(self.client, taskhash2, unihash2) + + # Mark the first unihash to be kept + ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD}) + self.assertEqual(ret, {"count": 1}) + + ret = self.client.gc_status() + self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1}) + + # Sweeping with a different mark raises an error + with self.assertRaises(InvokeError): + self.client.gc_sweep("DEF") + + # Both hashes are present + self.assertClientGetHash(self.client, taskhash2, unihash2) + self.assertClientGetHash(self.client, taskhash, unihash) + + def test_gc_new_hashes(self): + taskhash = 
'53b8dce672cb6d0c73170be43f540460bfc347b4' + outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' + unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' + + result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') + + # Start a new garbage collection + ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD}) + self.assertEqual(ret, {"count": 1}) + + ret = self.client.gc_status() + self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 0}) + + # Add second hash. It should inherit the mark from the current garbage + # collection operation + + taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' + outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' + unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' + + result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) + self.assertClientGetHash(self.client, taskhash2, unihash2) + + # Sweep should remove nothing + ret = self.client.gc_sweep("ABC") + self.assertEqual(ret, {"count": 0}) + + # Both hashes are present + self.assertClientGetHash(self.client, taskhash2, unihash2) + self.assertClientGetHash(self.client, taskhash, unihash) + class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase): def get_server_addr(self, server_idx): @@ -1086,6 +1248,42 @@ class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase): "get-db-query-columns", ], check=True) + def test_gc(self): + taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' + outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' + unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' + + result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') + + taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' + outhash2 = 
'77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' + unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' + + result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) + self.assertClientGetHash(self.client, taskhash2, unihash2) + + # Mark the first unihash to be kept + self.run_hashclient([ + "--address", self.server_address, + "gc-mark", "ABC", + "--where", "unihash", unihash, + "--where", "method", self.METHOD + ], check=True) + + # Second hash is still there; mark doesn't delete hashes + self.assertClientGetHash(self.client, taskhash2, unihash2) + + self.run_hashclient([ + "--address", self.server_address, + "gc-sweep", "ABC", + ], check=True) + + # Hash is gone. Taskhash is returned for second hash + self.assertClientGetHash(self.client, taskhash2, None) + # First hash is still present + self.assertClientGetHash(self.client, taskhash, unihash) + class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): def get_server_addr(self, server_idx): |