diff options
-rw-r--r-- | lib/hashserv/client.py | 5 | ||||
-rw-r--r-- | lib/hashserv/server.py | 28 | ||||
-rw-r--r-- | lib/hashserv/tests.py | 33 |
3 files changed, 66 insertions, 0 deletions
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py index b2aa1026a..7446e4c9f 100644 --- a/lib/hashserv/client.py +++ b/lib/hashserv/client.py @@ -101,6 +101,10 @@ class AsyncClient(bb.asyncrpc.AsyncClient): await self._set_mode(self.MODE_NORMAL) return (await self.send_message({"backfill-wait": None}))["tasks"] + async def remove(self, where): + await self._set_mode(self.MODE_NORMAL) + return await self.send_message({"remove": {"where": where}}) + class Client(bb.asyncrpc.Client): def __init__(self): @@ -115,6 +119,7 @@ class Client(bb.asyncrpc.Client): "get_stats", "reset_stats", "backfill_wait", + "remove", ) def _get_async_client(self): diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py index d40a2ab8f..daf1ffacb 100644 --- a/lib/hashserv/server.py +++ b/lib/hashserv/server.py @@ -186,6 +186,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): 'report-equiv': self.handle_equivreport, 'reset-stats': self.handle_reset_stats, 'backfill-wait': self.handle_backfill_wait, + 'remove': self.handle_remove, }) def validate_proto_version(self): @@ -499,6 +500,33 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection): await self.backfill_queue.join() self.write_message(d) + async def handle_remove(self, request): + condition = request["where"] + if not isinstance(condition, dict): + raise TypeError("Bad condition type %s" % type(condition)) + + def do_remove(columns, table_name, cursor): + nonlocal condition + where = {} + for c in columns: + if c in condition and condition[c] is not None: + where[c] = condition[c] + + if where: + query = ('DELETE FROM %s WHERE ' % table_name) + ' AND '.join("%s=:%s" % (k, k) for k in where.keys()) + cursor.execute(query, where) + return cursor.rowcount + + return 0 + + count = 0 + with closing(self.db.cursor()) as cursor: + count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor) + count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor) + self.db.commit() + + self.write_message({"count": count}) + def query_equivalent(self, cursor, method, taskhash): # This is part of the inner loop and must be as fast as possible cursor.execute( diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py index f6b85aed8..a3e066406 100644 --- a/lib/hashserv/tests.py +++ b/lib/hashserv/tests.py @@ -84,6 +84,7 @@ class HashEquivalenceCommonTests(object): result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') + return taskhash, outhash, unihash def test_create_equivalent(self): # Tests that a second reported task with the same outhash will be @@ -125,6 +126,38 @@ class HashEquivalenceCommonTests(object): self.assertClientGetHash(self.client, taskhash, unihash) + def test_remove_taskhash(self): + taskhash, outhash, unihash = self.test_create_hash() + result = self.client.remove({"taskhash": taskhash}) + self.assertGreater(result["count"], 0) + self.assertClientGetHash(self.client, taskhash, None) + + result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) + self.assertIsNone(result_outhash) + + def test_remove_unihash(self): + taskhash, outhash, unihash = self.test_create_hash() + result = self.client.remove({"unihash": unihash}) + self.assertGreater(result["count"], 0) + self.assertClientGetHash(self.client, taskhash, None) + + def test_remove_outhash(self): + taskhash, outhash, unihash = self.test_create_hash() + result = self.client.remove({"outhash": outhash}) + self.assertGreater(result["count"], 0) + + result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) + self.assertIsNone(result_outhash) + + def test_remove_method(self): + taskhash, outhash, unihash = self.test_create_hash() + result = self.client.remove({"method": self.METHOD}) + self.assertGreater(result["count"], 0) + self.assertClientGetHash(self.client, taskhash, None) + + result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) + self.assertIsNone(result_outhash) + def test_huge_message(self): # Simple test that hashes can be created taskhash = 'c665584ee6817aa99edfc77a44dd853828279370' |