Diffstat (limited to 'lib/hashserv/server.py')
-rw-r--r--  lib/hashserv/server.py  105
1 file changed, 72 insertions(+), 33 deletions(-)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index cc7e48233..81050715e 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -13,6 +13,7 @@ import os
import signal
import socket
import time
+from . import chunkify, DEFAULT_MAX_CHUNK
logger = logging.getLogger('hashserv.server')
@@ -107,12 +108,29 @@ class Stats(object):
return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
+class ClientError(Exception):
+ pass
+
class ServerClient(object):
+ FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
+ ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
+
def __init__(self, reader, writer, db, request_stats):
self.reader = reader
self.writer = writer
self.db = db
self.request_stats = request_stats
+ self.max_chunk = DEFAULT_MAX_CHUNK
+
+ self.handlers = {
+ 'get': self.handle_get,
+ 'report': self.handle_report,
+ 'report-equiv': self.handle_equivreport,
+ 'get-stream': self.handle_get_stream,
+ 'get-stats': self.handle_get_stats,
+ 'reset-stats': self.handle_reset_stats,
+ 'chunk-stream': self.handle_chunk,
+ }
async def process_requests(self):
try:
@@ -125,7 +143,11 @@ class ServerClient(object):
return
(proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
- if proto_name != 'OEHASHEQUIV' or proto_version != '1.0':
+ if proto_name != 'OEHASHEQUIV':
+ return
+
+ proto_version = tuple(int(v) for v in proto_version.split('.'))
+ if proto_version < (1, 0) or proto_version > (1, 1):
return
# Read headers. Currently, no headers are implemented, so look for
# an empty line to signal the end of the headers
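Review note: the version check above parses the client's 'major.minor' string into an integer tuple so that plain tuple comparison expresses the supported range. A minimal standalone sketch of that negotiation, using the (1, 0) to (1, 1) range from the hunk above; the helper name is illustrative, not part of the diff:

# Hypothetical helper mirroring the range check above. Tuple
# comparison orders (major, minor) pairs lexicographically, so
# versions 1.0 and 1.1 pass while 0.9 and 2.0 are rejected.
def version_supported(version_string):
    try:
        version = tuple(int(v) for v in version_string.split('.'))
    except ValueError:
        return False
    return (1, 0) <= version <= (1, 1)

assert version_supported('1.1') and not version_supported('2.0')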
@@ -140,40 +162,34 @@ class ServerClient(object):
break
# Handle messages
- handlers = {
- 'get': self.handle_get,
- 'report': self.handle_report,
- 'report-equiv': self.handle_equivreport,
- 'get-stream': self.handle_get_stream,
- 'get-stats': self.handle_get_stats,
- 'reset-stats': self.handle_reset_stats,
- }
-
while True:
d = await self.read_message()
if d is None:
break
-
- for k in handlers.keys():
- if k in d:
- logger.debug('Handling %s' % k)
- if 'stream' in k:
- await handlers[k](d[k])
- else:
- with self.request_stats.start_sample() as self.request_sample, \
- self.request_sample.measure():
- await handlers[k](d[k])
- break
- else:
- logger.warning("Unrecognized command %r" % d)
- break
-
+ await self.dispatch_message(d)
await self.writer.drain()
+ except ClientError as e:
+ logger.error(str(e))
finally:
self.writer.close()
+ async def dispatch_message(self, msg):
+ for k in self.handlers.keys():
+ if k in msg:
+ logger.debug('Handling %s' % k)
+ if 'stream' in k:
+ await self.handlers[k](msg[k])
+ else:
+ with self.request_stats.start_sample() as self.request_sample, \
+ self.request_sample.measure():
+ await self.handlers[k](msg[k])
+ return
+
+ raise ClientError("Unrecognized command %r" % msg)
+
def write_message(self, msg):
- self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8'))
+ for c in chunkify(json.dumps(msg), self.max_chunk):
+ self.writer.write(c.encode('utf-8'))
async def read_message(self):
l = await self.reader.readline()
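Review note: write_message now funnels every outgoing JSON document through chunkify, imported from the hashserv package at the top of the file, so a reply larger than max_chunk is sent as a series of newline-terminated pieces instead of one oversized line. The diff does not show chunkify itself; the following is a sketch of the behavior write_message relies on, assuming the helper lives in lib/hashserv/__init__.py alongside DEFAULT_MAX_CHUNK:

import itertools
import json

# Sketch of the chunkify generator (an assumption; the real helper is
# defined in the hashserv package, not in this diff). A short message
# is yielded as a single newline-terminated line. A long message is
# announced with a 'chunk-stream' envelope, sliced into pieces of at
# most max_chunk - 1 characters, and terminated with an empty line so
# the reader knows when the stream ends.
def chunkify(msg, max_chunk):
    if len(msg) < max_chunk - 1:
        yield ''.join((msg, '\n'))
    else:
        yield ''.join((json.dumps({'chunk-stream': None}), '\n'))
        args = [iter(msg)] * (max_chunk - 1)
        for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
            yield ''.join((m, '\n'))
        yield '\n'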
@@ -191,14 +207,38 @@ class ServerClient(object):
logger.error('Bad message from client: %r' % message)
raise e
+ async def handle_chunk(self, request):
+ lines = []
+ try:
+ while True:
+ l = await self.reader.readline()
+ l = l.rstrip(b"\n").decode("utf-8")
+ if not l:
+ break
+ lines.append(l)
+
+ msg = json.loads(''.join(lines))
+ except (json.JSONDecodeError, UnicodeDecodeError) as e:
+ logger.error('Bad message from client: %r' % ''.join(lines))
+ raise e
+
+ if 'chunk-stream' in msg:
+ raise ClientError("Nested chunks are not allowed")
+
+ await self.dispatch_message(msg)
+
async def handle_get(self, request):
method = request['method']
taskhash = request['taskhash']
- row = self.query_equivalent(method, taskhash)
+ if request.get('all', False):
+ row = self.query_equivalent(method, taskhash, self.ALL_QUERY)
+ else:
+ row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
+
if row is not None:
logger.debug('Found equivalent task %s -> %s', row['taskhash'], row['unihash'])
- d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
+ d = {k: row[k] for k in row.keys()}
self.write_message(d)
else:
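Review note: handle_chunk is the receiving half of the protocol. It keeps reading newline-terminated lines until the empty terminator, joins and JSON-decodes the buffer, refuses a nested 'chunk-stream' envelope, and re-dispatches the reassembled message as if it had arrived whole. A round-trip illustration under the chunkify sketch above (payload and chunk size are arbitrary):

# Round trip: a message larger than max_chunk is split on the wire.
payload = json.dumps({'get': {'taskhash': 'abc', 'method': 'm'}})
wire = list(chunkify(payload, max_chunk=16))

# The first line is the envelope that routes the peer into handle_chunk.
assert json.loads(wire[0]) == {'chunk-stream': None}

# Reassembly mirrors handle_chunk: strip newlines, collect lines
# until the empty terminator, then decode the joined buffer.
lines = []
for l in (piece.rstrip('\n') for piece in wire[1:]):
    if not l:
        break
    lines.append(l)
assert json.loads(''.join(lines)) == {'get': {'taskhash': 'abc', 'method': 'm'}}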
@@ -228,7 +268,7 @@ class ServerClient(object):
(method, taskhash) = l.split()
#logger.debug('Looking up %s %s' % (method, taskhash))
- row = self.query_equivalent(method, taskhash)
+ row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
if row is not None:
msg = ('%s\n' % row['unihash']).encode('utf-8')
#logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
@@ -328,7 +368,7 @@ class ServerClient(object):
# Fetch the unihash that will be reported for the taskhash. If the
# unihash matches, it means this row was inserted (or the mapping
# was already valid)
- row = self.query_equivalent(data['method'], data['taskhash'])
+ row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY)
if row['unihash'] == data['unihash']:
logger.info('Adding taskhash equivalence for %s with unihash %s',
@@ -354,12 +394,11 @@ class ServerClient(object):
self.request_stats.reset()
self.write_message(d)
- def query_equivalent(self, method, taskhash):
+ def query_equivalent(self, method, taskhash, query):
# This is part of the inner loop and must be as fast as possible
try:
cursor = self.db.cursor()
- cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
- {'method': method, 'taskhash': taskhash})
+ cursor.execute(query, {'method': method, 'taskhash': taskhash})
return cursor.fetchone()
except:
cursor.close()
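Review note: passing the SQL text into query_equivalent lets handle_get choose between FAST_QUERY (only the three columns the old code selected) and ALL_QUERY (every column, for clients that request 'all') while keeping a single parameter-bound execution path. A self-contained sketch of that lookup pattern against an in-memory database, with the schema reduced to the columns the queries name:

import sqlite3

db = sqlite3.connect(':memory:')
db.row_factory = sqlite3.Row   # allows row['unihash'] access, as above
db.execute('CREATE TABLE tasks_v2 '
           '(taskhash TEXT, method TEXT, unihash TEXT, created DATETIME)')
db.execute("INSERT INTO tasks_v2 VALUES ('abc', 'm', 'uni1', 1)")

FAST_QUERY = ('SELECT taskhash, method, unihash FROM tasks_v2 '
              'WHERE method=:method AND taskhash=:taskhash '
              'ORDER BY created ASC LIMIT 1')

cursor = db.cursor()
cursor.execute(FAST_QUERY, {'method': 'm', 'taskhash': 'abc'})
row = cursor.fetchone()
assert row is not None and row['unihash'] == 'uni1'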