aboutsummaryrefslogtreecommitdiffstats
path: root/lib/hashserv
diff options
context:
space:
mode:
Diffstat (limited to 'lib/hashserv')
-rw-r--r--lib/hashserv/__init__.py21
-rw-r--r--lib/hashserv/client.py38
-rw-r--r--lib/hashserv/server.py116
3 files changed, 69 insertions, 106 deletions
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 9cb3fd57a..3a4018353 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -15,13 +15,6 @@ UNIX_PREFIX = "unix://"
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
UNIHASH_TABLE_DEFINITION = (
("method", "TEXT NOT NULL", "UNIQUE"),
("taskhash", "TEXT NOT NULL", "UNIQUE"),
@@ -102,20 +95,6 @@ def parse_address(addr):
return (ADDR_TYPE_TCP, (host, int(port)))
-def chunkify(msg, max_chunk):
- if len(msg) < max_chunk - 1:
- yield ''.join((msg, "\n"))
- else:
- yield ''.join((json.dumps({
- 'chunk-stream': None
- }), "\n"))
-
- args = [iter(msg)] * (max_chunk - 1)
- for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
- yield ''.join(itertools.chain(m, "\n"))
- yield "\n"
-
-
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
from . import server
db = setup_database(dbname, sync=sync)
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index f676d267f..5f7d22ab1 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -28,24 +28,24 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
async def send_stream(self, msg):
async def proc():
- self.writer.write(("%s\n" % msg).encode("utf-8"))
- await self.writer.drain()
- l = await self.reader.readline()
- if not l:
- raise ConnectionError("Connection closed")
- return l.decode("utf-8").rstrip()
+ await self.socket.send(msg)
+ return await self.socket.recv()
return await self._send_wrapper(proc)
async def _set_mode(self, new_mode):
+ async def stream_to_normal():
+ await self.socket.send("END")
+ return await self.socket.recv()
+
if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
- r = await self.send_stream("END")
+ r = await self._send_wrapper(stream_to_normal)
if r != "ok":
- raise ConnectionError("Bad response from server %r" % r)
+ raise ConnectionError("Unable to transition to normal mode: Bad response from server %r" % r)
elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
- r = await self.send_message({"get-stream": None})
+ r = await self.invoke({"get-stream": None})
if r != "ok":
- raise ConnectionError("Bad response from server %r" % r)
+ raise ConnectionError("Unable to transition to stream mode: Bad response from server %r" % r)
elif new_mode != self.mode:
raise Exception(
"Undefined mode transition %r -> %r" % (self.mode, new_mode)
@@ -67,7 +67,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
m["method"] = method
m["outhash"] = outhash
m["unihash"] = unihash
- return await self.send_message({"report": m})
+ return await self.invoke({"report": m})
async def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
await self._set_mode(self.MODE_NORMAL)
@@ -75,39 +75,39 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
m["taskhash"] = taskhash
m["method"] = method
m["unihash"] = unihash
- return await self.send_message({"report-equiv": m})
+ return await self.invoke({"report-equiv": m})
async def get_taskhash(self, method, taskhash, all_properties=False):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message(
+ return await self.invoke(
{"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
)
async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message(
+ return await self.invoke(
{"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method, "with_unihash": with_unihash}}
)
async def get_stats(self):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"get-stats": None})
+ return await self.invoke({"get-stats": None})
async def reset_stats(self):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"reset-stats": None})
+ return await self.invoke({"reset-stats": None})
async def backfill_wait(self):
await self._set_mode(self.MODE_NORMAL)
- return (await self.send_message({"backfill-wait": None}))["tasks"]
+ return (await self.invoke({"backfill-wait": None}))["tasks"]
async def remove(self, where):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"remove": {"where": where}})
+ return await self.invoke({"remove": {"where": where}})
async def clean_unused(self, max_age):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"clean-unused": {"max_age_seconds": max_age}})
+ return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
class Client(bb.asyncrpc.Client):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 45bf476bf..13b754805 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -165,8 +165,8 @@ class ServerCursor(object):
class ServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
- super().__init__(reader, writer, 'OEHASHEQUIV', logger)
+ def __init__(self, socket, db, request_stats, backfill_queue, upstream, read_only):
+ super().__init__(socket, 'OEHASHEQUIV', logger)
self.db = db
self.request_stats = request_stats
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
@@ -209,12 +209,11 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if k in msg:
logger.debug('Handling %s' % k)
if 'stream' in k:
- await self.handlers[k](msg[k])
+ return await self.handlers[k](msg[k])
else:
with self.request_stats.start_sample() as self.request_sample, \
self.request_sample.measure():
- await self.handlers[k](msg[k])
- return
+ return await self.handlers[k](msg[k])
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
@@ -224,9 +223,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
fetch_all = request.get('all', False)
with closing(self.db.cursor()) as cursor:
- d = await self.get_unihash(cursor, method, taskhash, fetch_all)
-
- self.write_message(d)
+ return await self.get_unihash(cursor, method, taskhash, fetch_all)
async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
d = None
@@ -274,9 +271,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
with_unihash = request.get("with_unihash", True)
with closing(self.db.cursor()) as cursor:
- d = await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
-
- self.write_message(d)
+ return await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
async def get_outhash(self, cursor, method, outhash, taskhash, with_unihash=True):
d = None
@@ -334,14 +329,14 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
)
async def handle_get_stream(self, request):
- self.write_message('ok')
+ await self.socket.send_message("ok")
while True:
upstream = None
- l = await self.reader.readline()
+ l = await self.socket.recv()
if not l:
- return
+ break
try:
# This inner loop is very sensitive and must be as fast as
@@ -352,10 +347,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
request_measure = self.request_sample.measure()
request_measure.start()
- l = l.decode('utf-8').rstrip()
if l == 'END':
- self.writer.write('ok\n'.encode('utf-8'))
- return
+ break
(method, taskhash) = l.split()
#logger.debug('Looking up %s %s' % (method, taskhash))
@@ -366,29 +359,30 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
cursor.close()
if row is not None:
- msg = ('%s\n' % row['unihash']).encode('utf-8')
+ msg = row['unihash']
#logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
- msg = ("%s\n" % upstream).encode("utf-8")
+ msg = upstream
else:
- msg = "\n".encode("utf-8")
+ msg = ""
else:
- msg = '\n'.encode('utf-8')
+ msg = ""
- self.writer.write(msg)
+ await self.socket.send(msg)
finally:
request_measure.end()
self.request_sample.end()
- await self.writer.drain()
-
# Post to the backfill queue after writing the result to minimize
# the turn around time on a request
if upstream is not None:
await self.backfill_queue.put((method, taskhash))
+ await self.socket.send("ok")
+ return self.NO_RESPONSE
+
async def handle_report(self, data):
with closing(self.db.cursor()) as cursor:
outhash_data = {
@@ -468,7 +462,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
'unihash': unihash,
}
- self.write_message(d)
+ return d
async def handle_equivreport(self, data):
with closing(self.db.cursor()) as cursor:
@@ -491,30 +485,28 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
- self.write_message(d)
+ return d
async def handle_get_stats(self, request):
- d = {
+ return {
'requests': self.request_stats.todict(),
}
- self.write_message(d)
-
async def handle_reset_stats(self, request):
d = {
'requests': self.request_stats.todict(),
}
self.request_stats.reset()
- self.write_message(d)
+ return d
async def handle_backfill_wait(self, request):
d = {
'tasks': self.backfill_queue.qsize(),
}
await self.backfill_queue.join()
- self.write_message(d)
+ return d
async def handle_remove(self, request):
condition = request["where"]
@@ -541,7 +533,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
self.db.commit()
- self.write_message({"count": count})
+ return {"count": count}
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
@@ -558,7 +550,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
)
count = cursor.rowcount
- self.write_message({"count": count})
+ return {"count": count}
def query_equivalent(self, cursor, method, taskhash):
# This is part of the inner loop and must be as fast as possible
@@ -583,41 +575,33 @@ class Server(bb.asyncrpc.AsyncServer):
self.db = db
self.upstream = upstream
self.read_only = read_only
+ self.backfill_queue = None
- def accept_client(self, reader, writer):
- return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
+ def accept_client(self, socket):
+ return ServerClient(socket, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
- @contextmanager
- def _backfill_worker(self):
- async def backfill_worker_task():
- client = await create_async_client(self.upstream)
- try:
- while True:
- item = await self.backfill_queue.get()
- if item is None:
- self.backfill_queue.task_done()
- break
- method, taskhash = item
- await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ async def backfill_worker_task(self):
+ client = await create_async_client(self.upstream)
+ try:
+ while True:
+ item = await self.backfill_queue.get()
+ if item is None:
self.backfill_queue.task_done()
- finally:
- await client.close()
+ break
+ method, taskhash = item
+ await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ self.backfill_queue.task_done()
+ finally:
+ await client.close()
- async def join_worker(worker):
+ def start(self):
+ tasks = super().start()
+ if self.upstream:
+ self.backfill_queue = asyncio.Queue()
+ tasks += [self.backfill_worker_task()]
+ return tasks
+
+ async def stop(self):
+ if self.backfill_queue is not None:
await self.backfill_queue.put(None)
- await worker
-
- if self.upstream is not None:
- worker = asyncio.ensure_future(backfill_worker_task())
- try:
- yield
- finally:
- self.loop.run_until_complete(join_worker(worker))
- else:
- yield
-
- def run_loop_forever(self):
- self.backfill_queue = asyncio.Queue()
-
- with self._backfill_worker():
- super().run_loop_forever()
+ await super().stop()