aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lib/bb/asyncrpc/__init__.py32
-rw-r--r--lib/bb/asyncrpc/client.py78
-rw-r--r--lib/bb/asyncrpc/connection.py95
-rw-r--r--lib/bb/asyncrpc/exceptions.py17
-rw-r--r--lib/bb/asyncrpc/serv.py304
-rw-r--r--lib/hashserv/__init__.py21
-rw-r--r--lib/hashserv/client.py38
-rw-r--r--lib/hashserv/server.py116
-rw-r--r--lib/prserv/client.py8
-rw-r--r--lib/prserv/serv.py31
10 files changed, 387 insertions, 353 deletions
diff --git a/lib/bb/asyncrpc/__init__.py b/lib/bb/asyncrpc/__init__.py
index 9a85e9965..9f677eac4 100644
--- a/lib/bb/asyncrpc/__init__.py
+++ b/lib/bb/asyncrpc/__init__.py
@@ -4,30 +4,12 @@
# SPDX-License-Identifier: GPL-2.0-only
#
-import itertools
-import json
-
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
-
-def chunkify(msg, max_chunk):
- if len(msg) < max_chunk - 1:
- yield ''.join((msg, "\n"))
- else:
- yield ''.join((json.dumps({
- 'chunk-stream': None
- }), "\n"))
-
- args = [iter(msg)] * (max_chunk - 1)
- for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
- yield ''.join(itertools.chain(m, "\n"))
- yield "\n"
-
from .client import AsyncClient, Client
-from .serv import AsyncServer, AsyncServerConnection, ClientError, ServerError
+from .serv import AsyncServer, AsyncServerConnection
+from .connection import DEFAULT_MAX_CHUNK
+from .exceptions import (
+ ClientError,
+ ServerError,
+ ConnectionClosedError,
+)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index fa042bbe8..7f33099b6 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -10,13 +10,13 @@ import json
import os
import socket
import sys
-from . import chunkify, DEFAULT_MAX_CHUNK
+from .connection import StreamConnection, DEFAULT_MAX_CHUNK
+from .exceptions import ConnectionClosedError
class AsyncClient(object):
def __init__(self, proto_name, proto_version, logger, timeout=30):
- self.reader = None
- self.writer = None
+ self.socket = None
self.max_chunk = DEFAULT_MAX_CHUNK
self.proto_name = proto_name
self.proto_version = proto_version
@@ -25,7 +25,8 @@ class AsyncClient(object):
async def connect_tcp(self, address, port):
async def connect_sock():
- return await asyncio.open_connection(address, port)
+ reader, writer = await asyncio.open_connection(address, port)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
@@ -40,27 +41,27 @@ class AsyncClient(object):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
sock.connect(os.path.basename(path))
finally:
- os.chdir(cwd)
- return await asyncio.open_unix_connection(sock=sock)
+ os.chdir(cwd)
+ reader, writer = await asyncio.open_unix_connection(sock=sock)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
async def setup_connection(self):
- s = '%s %s\n\n' % (self.proto_name, self.proto_version)
- self.writer.write(s.encode("utf-8"))
- await self.writer.drain()
+ # Send headers
+ await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
+ # End of headers
+ await self.socket.send("")
async def connect(self):
- if self.reader is None or self.writer is None:
- (self.reader, self.writer) = await self._connect_sock()
+ if self.socket is None:
+ self.socket = await self._connect_sock()
await self.setup_connection()
async def close(self):
- self.reader = None
-
- if self.writer is not None:
- self.writer.close()
- self.writer = None
+ if self.socket is not None:
+ await self.socket.close()
+ self.socket = None
async def _send_wrapper(self, proc):
count = 0
@@ -71,6 +72,7 @@ class AsyncClient(object):
except (
OSError,
ConnectionError,
+ ConnectionClosedError,
json.JSONDecodeError,
UnicodeDecodeError,
) as e:
@@ -82,49 +84,15 @@ class AsyncClient(object):
await self.close()
count += 1
- async def send_message(self, msg):
- async def get_line():
- try:
- line = await asyncio.wait_for(self.reader.readline(), self.timeout)
- except asyncio.TimeoutError:
- raise ConnectionError("Timed out waiting for server")
-
- if not line:
- raise ConnectionError("Connection closed")
-
- line = line.decode("utf-8")
-
- if not line.endswith("\n"):
- raise ConnectionError("Bad message %r" % (line))
-
- return line
-
+ async def invoke(self, msg):
async def proc():
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode("utf-8"))
- await self.writer.drain()
-
- l = await get_line()
-
- m = json.loads(l)
- if m and "chunk-stream" in m:
- lines = []
- while True:
- l = (await get_line()).rstrip("\n")
- if not l:
- break
- lines.append(l)
-
- m = json.loads("".join(lines))
-
- return m
+ await self.socket.send_message(msg)
+ return await self.socket.recv_message()
return await self._send_wrapper(proc)
async def ping(self):
- return await self.send_message(
- {'ping': {}}
- )
+ return await self.invoke({"ping": {}})
class Client(object):
@@ -142,7 +110,7 @@ class Client(object):
# required (but harmless) with it.
asyncio.set_event_loop(self.loop)
- self._add_methods('connect_tcp', 'ping')
+ self._add_methods("connect_tcp", "ping")
@abc.abstractmethod
def _get_async_client(self):
diff --git a/lib/bb/asyncrpc/connection.py b/lib/bb/asyncrpc/connection.py
new file mode 100644
index 000000000..c4fd24754
--- /dev/null
+++ b/lib/bb/asyncrpc/connection.py
@@ -0,0 +1,95 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import asyncio
+import itertools
+import json
+from .exceptions import ClientError, ConnectionClosedError
+
+
+# The Python async server defaults to a 64K receive buffer, so we hardcode our
+# maximum chunk size. It would be better if the client and server reported to
+# each other what the maximum chunk sizes were, but that will slow down the
+# connection setup with a round trip delay so I'd rather not do that unless it
+# is necessary
+DEFAULT_MAX_CHUNK = 32 * 1024
+
+
+def chunkify(msg, max_chunk):
+ if len(msg) < max_chunk - 1:
+ yield "".join((msg, "\n"))
+ else:
+ yield "".join((json.dumps({"chunk-stream": None}), "\n"))
+
+ args = [iter(msg)] * (max_chunk - 1)
+ for m in map("".join, itertools.zip_longest(*args, fillvalue="")):
+ yield "".join(itertools.chain(m, "\n"))
+ yield "\n"
+
+
+class StreamConnection(object):
+ def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
+ self.reader = reader
+ self.writer = writer
+ self.timeout = timeout
+ self.max_chunk = max_chunk
+
+ @property
+ def address(self):
+ return self.writer.get_extra_info("peername")
+
+ async def send_message(self, msg):
+ for c in chunkify(json.dumps(msg), self.max_chunk):
+ self.writer.write(c.encode("utf-8"))
+ await self.writer.drain()
+
+ async def recv_message(self):
+ l = await self.recv()
+
+ m = json.loads(l)
+ if not m:
+ return m
+
+ if "chunk-stream" in m:
+ lines = []
+ while True:
+ l = await self.recv()
+ if not l:
+ break
+ lines.append(l)
+
+ m = json.loads("".join(lines))
+
+ return m
+
+ async def send(self, msg):
+ self.writer.write(("%s\n" % msg).encode("utf-8"))
+ await self.writer.drain()
+
+ async def recv(self):
+ if self.timeout < 0:
+ line = await self.reader.readline()
+ else:
+ try:
+ line = await asyncio.wait_for(self.reader.readline(), self.timeout)
+ except asyncio.TimeoutError:
+ raise ConnectionError("Timed out waiting for data")
+
+ if not line:
+ raise ConnectionClosedError("Connection closed")
+
+ line = line.decode("utf-8")
+
+ if not line.endswith("\n"):
+ raise ConnectionError("Bad message %r" % (line))
+
+ return line.rstrip()
+
+ async def close(self):
+ self.reader = None
+ if self.writer is not None:
+ self.writer.close()
+ self.writer = None
diff --git a/lib/bb/asyncrpc/exceptions.py b/lib/bb/asyncrpc/exceptions.py
new file mode 100644
index 000000000..a8942b4f0
--- /dev/null
+++ b/lib/bb/asyncrpc/exceptions.py
@@ -0,0 +1,17 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+
+class ClientError(Exception):
+ pass
+
+
+class ServerError(Exception):
+ pass
+
+
+class ConnectionClosedError(Exception):
+ pass
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index d2de4891b..3e0d0632c 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -12,241 +12,248 @@ import signal
import socket
import sys
import multiprocessing
-from . import chunkify, DEFAULT_MAX_CHUNK
-
-
-class ClientError(Exception):
- pass
-
-
-class ServerError(Exception):
- pass
+from .connection import StreamConnection
+from .exceptions import ClientError, ServerError, ConnectionClosedError
class AsyncServerConnection(object):
- def __init__(self, reader, writer, proto_name, logger):
- self.reader = reader
- self.writer = writer
+ # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
+ # return message will be automatically be sent back to the client
+ NO_RESPONSE = object()
+
+ def __init__(self, socket, proto_name, logger):
+ self.socket = socket
self.proto_name = proto_name
- self.max_chunk = DEFAULT_MAX_CHUNK
self.handlers = {
- 'chunk-stream': self.handle_chunk,
- 'ping': self.handle_ping,
+ "ping": self.handle_ping,
}
self.logger = logger
+ async def close(self):
+ await self.socket.close()
+
async def process_requests(self):
try:
- self.addr = self.writer.get_extra_info('peername')
- self.logger.debug('Client %r connected' % (self.addr,))
+ self.logger.info("Client %r connected" % (self.socket.address,))
# Read protocol and version
- client_protocol = await self.reader.readline()
+ client_protocol = await self.socket.recv()
if not client_protocol:
return
- (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
+ (client_proto_name, client_proto_version) = client_protocol.split()
if client_proto_name != self.proto_name:
- self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name))
+ self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
return
- self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
+ self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
if not self.validate_proto_version():
- self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
+ self.logger.debug(
+ "Rejecting invalid protocol version %s" % (client_proto_version)
+ )
return
# Read headers. Currently, no headers are implemented, so look for
# an empty line to signal the end of the headers
while True:
- line = await self.reader.readline()
- if not line:
- return
-
- line = line.decode('utf-8').rstrip()
- if not line:
+ header = await self.socket.recv()
+ if not header:
break
# Handle messages
while True:
- d = await self.read_message()
+ d = await self.socket.recv_message()
if d is None:
break
- await self.dispatch_message(d)
- await self.writer.drain()
- except ClientError as e:
+ response = await self.dispatch_message(d)
+ if response is not self.NO_RESPONSE:
+ await self.socket.send_message(response)
+
+ except ConnectionClosedError as e:
+ self.logger.info(str(e))
+ except (ClientError, ConnectionError) as e:
self.logger.error(str(e))
finally:
- self.writer.close()
+ await self.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- self.logger.debug('Handling %s' % k)
- await self.handlers[k](msg[k])
- return
+ self.logger.debug("Handling %s" % k)
+ return await self.handlers[k](msg[k])
raise ClientError("Unrecognized command %r" % msg)
- def write_message(self, msg):
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode('utf-8'))
+ async def handle_ping(self, request):
+ return {"alive": True}
- async def read_message(self):
- l = await self.reader.readline()
- if not l:
- return None
- try:
- message = l.decode('utf-8')
+class StreamServer(object):
+ def __init__(self, handler, logger):
+ self.handler = handler
+ self.logger = logger
+ self.closed = False
- if not message.endswith('\n'):
- return None
+ async def handle_stream_client(self, reader, writer):
+ # writer.transport.set_write_buffer_limits(0)
+ socket = StreamConnection(reader, writer, -1)
+ if self.closed:
+ await socket.close()
+ return
+
+ await self.handler(socket)
+
+ async def stop(self):
+ self.closed = True
+
+
+class TCPStreamServer(StreamServer):
+ def __init__(self, host, port, handler, logger):
+ super().__init__(handler, logger)
+ self.host = host
+ self.port = port
+
+ def start(self, loop):
+ self.server = loop.run_until_complete(
+ asyncio.start_server(self.handle_stream_client, self.host, self.port)
+ )
+
+ for s in self.server.sockets:
+ self.logger.debug("Listening on %r" % (s.getsockname(),))
+ # Newer python does this automatically. Do it manually here for
+ # maximum compatibility
+ s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+ s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
+
+ # Enable keep alives. This prevents broken client connections
+ # from persisting on the server for long periods of time.
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
+ name = self.server.sockets[0].getsockname()
+ if self.server.sockets[0].family == socket.AF_INET6:
+ self.address = "[%s]:%d" % (name[0], name[1])
+ else:
+ self.address = "%s:%d" % (name[0], name[1])
+
+ return [self.server.wait_closed()]
+
+ async def stop(self):
+ await super().stop()
+ self.server.close()
+
+ def cleanup(self):
+ pass
- return json.loads(message)
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- self.logger.error('Bad message from client: %r' % message)
- raise e
- async def handle_chunk(self, request):
- lines = []
- try:
- while True:
- l = await self.reader.readline()
- l = l.rstrip(b"\n").decode("utf-8")
- if not l:
- break
- lines.append(l)
+class UnixStreamServer(StreamServer):
+ def __init__(self, path, handler, logger):
+ super().__init__(handler, logger)
+ self.path = path
- msg = json.loads(''.join(lines))
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- self.logger.error('Bad message from client: %r' % lines)
- raise e
+ def start(self, loop):
+ cwd = os.getcwd()
+ try:
+ # Work around path length limits in AF_UNIX
+ os.chdir(os.path.dirname(self.path))
+ self.server = loop.run_until_complete(
+ asyncio.start_unix_server(
+ self.handle_stream_client, os.path.basename(self.path)
+ )
+ )
+ finally:
+ os.chdir(cwd)
- if 'chunk-stream' in msg:
- raise ClientError("Nested chunks are not allowed")
+ self.logger.debug("Listening on %r" % self.path)
+ self.address = "unix://%s" % os.path.abspath(self.path)
+ return [self.server.wait_closed()]
- await self.dispatch_message(msg)
+ async def stop(self):
+ await super().stop()
+ self.server.close()
- async def handle_ping(self, request):
- response = {'alive': True}
- self.write_message(response)
+ def cleanup(self):
+ os.unlink(self.path)
class AsyncServer(object):
def __init__(self, logger):
- self._cleanup_socket = None
self.logger = logger
- self.start = None
- self.address = None
self.loop = None
+ self.run_tasks = []
def start_tcp_server(self, host, port):
- def start_tcp():
- self.server = self.loop.run_until_complete(
- asyncio.start_server(self.handle_client, host, port)
- )
-
- for s in self.server.sockets:
- self.logger.debug('Listening on %r' % (s.getsockname(),))
- # Newer python does this automatically. Do it manually here for
- # maximum compatibility
- s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
- s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
-
- # Enable keep alives. This prevents broken client connections
- # from persisting on the server for long periods of time.
- s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
-
- name = self.server.sockets[0].getsockname()
- if self.server.sockets[0].family == socket.AF_INET6:
- self.address = "[%s]:%d" % (name[0], name[1])
- else:
- self.address = "%s:%d" % (name[0], name[1])
-
- self.start = start_tcp
+ self.server = TCPStreamServer(host, port, self._client_handler, self.logger)
def start_unix_server(self, path):
- def cleanup():
- os.unlink(path)
-
- def start_unix():
- cwd = os.getcwd()
- try:
- # Work around path length limits in AF_UNIX
- os.chdir(os.path.dirname(path))
- self.server = self.loop.run_until_complete(
- asyncio.start_unix_server(self.handle_client, os.path.basename(path))
- )
- finally:
- os.chdir(cwd)
-
- self.logger.debug('Listening on %r' % path)
+ self.server = UnixStreamServer(path, self._client_handler, self.logger)
- self._cleanup_socket = cleanup
- self.address = "unix://%s" % os.path.abspath(path)
-
- self.start = start_unix
-
- @abc.abstractmethod
- def accept_client(self, reader, writer):
- pass
-
- async def handle_client(self, reader, writer):
- # writer.transport.set_write_buffer_limits(0)
+ async def _client_handler(self, socket):
try:
- client = self.accept_client(reader, writer)
+ client = self.accept_client(socket)
await client.process_requests()
except Exception as e:
import traceback
- self.logger.error('Error from client: %s' % str(e), exc_info=True)
+
+ self.logger.error("Error from client: %s" % str(e), exc_info=True)
traceback.print_exc()
- writer.close()
- self.logger.debug('Client disconnected')
+ await socket.close()
+ self.logger.debug("Client disconnected")
- def run_loop_forever(self):
- try:
- self.loop.run_forever()
- except KeyboardInterrupt:
- pass
+ @abc.abstractmethod
+ def accept_client(self, socket):
+ pass
+
+ async def stop(self):
+ self.logger.debug("Stopping server")
+ await self.server.stop()
+
+ def start(self):
+ tasks = self.server.start(self.loop)
+ self.address = self.server.address
+ return tasks
def signal_handler(self):
self.logger.debug("Got exit signal")
- self.loop.stop()
+ self.loop.create_task(self.stop())
- def _serve_forever(self):
+ def _serve_forever(self, tasks):
try:
self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
+ self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
+ self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
- self.run_loop_forever()
- self.server.close()
+ self.loop.run_until_complete(asyncio.gather(*tasks))
- self.loop.run_until_complete(self.server.wait_closed())
- self.logger.debug('Server shutting down')
+ self.logger.debug("Server shutting down")
finally:
- if self._cleanup_socket is not None:
- self._cleanup_socket()
+ self.server.cleanup()
def serve_forever(self):
"""
Serve requests in the current process
"""
+ self._create_loop()
+ tasks = self.start()
+ self._serve_forever(tasks)
+ self.loop.close()
+
+ def _create_loop(self):
# Create loop and override any loop that may have existed in
# a parent process. It is possible that the usecases of
# serve_forever might be constrained enough to allow using
# get_event_loop here, but better safe than sorry for now.
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
- self.start()
- self._serve_forever()
def serve_as_process(self, *, prefunc=None, args=()):
"""
Serve requests in a child process
"""
+
def run(queue):
# Create loop and override any loop that may have existed
# in a parent process. Without doing this and instead
@@ -259,18 +266,19 @@ class AsyncServer(object):
# more general, though, as any potential use of asyncio in
# Cooker could create a loop that needs to replaced in this
# new process.
- self.loop = asyncio.new_event_loop()
- asyncio.set_event_loop(self.loop)
+ self._create_loop()
try:
- self.start()
+ self.address = None
+ tasks = self.start()
finally:
+ # Always put the server address to wake up the parent task
queue.put(self.address)
queue.close()
if prefunc is not None:
prefunc(self, *args)
- self._serve_forever()
+ self._serve_forever(tasks)
if sys.version_info >= (3, 6):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 9cb3fd57a..3a4018353 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -15,13 +15,6 @@ UNIX_PREFIX = "unix://"
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
UNIHASH_TABLE_DEFINITION = (
("method", "TEXT NOT NULL", "UNIQUE"),
("taskhash", "TEXT NOT NULL", "UNIQUE"),
@@ -102,20 +95,6 @@ def parse_address(addr):
return (ADDR_TYPE_TCP, (host, int(port)))
-def chunkify(msg, max_chunk):
- if len(msg) < max_chunk - 1:
- yield ''.join((msg, "\n"))
- else:
- yield ''.join((json.dumps({
- 'chunk-stream': None
- }), "\n"))
-
- args = [iter(msg)] * (max_chunk - 1)
- for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
- yield ''.join(itertools.chain(m, "\n"))
- yield "\n"
-
-
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
from . import server
db = setup_database(dbname, sync=sync)
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index f676d267f..5f7d22ab1 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -28,24 +28,24 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
async def send_stream(self, msg):
async def proc():
- self.writer.write(("%s\n" % msg).encode("utf-8"))
- await self.writer.drain()
- l = await self.reader.readline()
- if not l:
- raise ConnectionError("Connection closed")
- return l.decode("utf-8").rstrip()
+ await self.socket.send(msg)
+ return await self.socket.recv()
return await self._send_wrapper(proc)
async def _set_mode(self, new_mode):
+ async def stream_to_normal():
+ await self.socket.send("END")
+ return await self.socket.recv()
+
if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
- r = await self.send_stream("END")
+ r = await self._send_wrapper(stream_to_normal)
if r != "ok":
- raise ConnectionError("Bad response from server %r" % r)
+ raise ConnectionError("Unable to transition to normal mode: Bad response from server %r" % r)
elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
- r = await self.send_message({"get-stream": None})
+ r = await self.invoke({"get-stream": None})
if r != "ok":
- raise ConnectionError("Bad response from server %r" % r)
+ raise ConnectionError("Unable to transition to stream mode: Bad response from server %r" % r)
elif new_mode != self.mode:
raise Exception(
"Undefined mode transition %r -> %r" % (self.mode, new_mode)
@@ -67,7 +67,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
m["method"] = method
m["outhash"] = outhash
m["unihash"] = unihash
- return await self.send_message({"report": m})
+ return await self.invoke({"report": m})
async def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
await self._set_mode(self.MODE_NORMAL)
@@ -75,39 +75,39 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
m["taskhash"] = taskhash
m["method"] = method
m["unihash"] = unihash
- return await self.send_message({"report-equiv": m})
+ return await self.invoke({"report-equiv": m})
async def get_taskhash(self, method, taskhash, all_properties=False):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message(
+ return await self.invoke(
{"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
)
async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message(
+ return await self.invoke(
{"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method, "with_unihash": with_unihash}}
)
async def get_stats(self):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"get-stats": None})
+ return await self.invoke({"get-stats": None})
async def reset_stats(self):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"reset-stats": None})
+ return await self.invoke({"reset-stats": None})
async def backfill_wait(self):
await self._set_mode(self.MODE_NORMAL)
- return (await self.send_message({"backfill-wait": None}))["tasks"]
+ return (await self.invoke({"backfill-wait": None}))["tasks"]
async def remove(self, where):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"remove": {"where": where}})
+ return await self.invoke({"remove": {"where": where}})
async def clean_unused(self, max_age):
await self._set_mode(self.MODE_NORMAL)
- return await self.send_message({"clean-unused": {"max_age_seconds": max_age}})
+ return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
class Client(bb.asyncrpc.Client):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 45bf476bf..13b754805 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -165,8 +165,8 @@ class ServerCursor(object):
class ServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
- super().__init__(reader, writer, 'OEHASHEQUIV', logger)
+ def __init__(self, socket, db, request_stats, backfill_queue, upstream, read_only):
+ super().__init__(socket, 'OEHASHEQUIV', logger)
self.db = db
self.request_stats = request_stats
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
@@ -209,12 +209,11 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if k in msg:
logger.debug('Handling %s' % k)
if 'stream' in k:
- await self.handlers[k](msg[k])
+ return await self.handlers[k](msg[k])
else:
with self.request_stats.start_sample() as self.request_sample, \
self.request_sample.measure():
- await self.handlers[k](msg[k])
- return
+ return await self.handlers[k](msg[k])
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
@@ -224,9 +223,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
fetch_all = request.get('all', False)
with closing(self.db.cursor()) as cursor:
- d = await self.get_unihash(cursor, method, taskhash, fetch_all)
-
- self.write_message(d)
+ return await self.get_unihash(cursor, method, taskhash, fetch_all)
async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
d = None
@@ -274,9 +271,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
with_unihash = request.get("with_unihash", True)
with closing(self.db.cursor()) as cursor:
- d = await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
-
- self.write_message(d)
+ return await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
async def get_outhash(self, cursor, method, outhash, taskhash, with_unihash=True):
d = None
@@ -334,14 +329,14 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
)
async def handle_get_stream(self, request):
- self.write_message('ok')
+ await self.socket.send_message("ok")
while True:
upstream = None
- l = await self.reader.readline()
+ l = await self.socket.recv()
if not l:
- return
+ break
try:
# This inner loop is very sensitive and must be as fast as
@@ -352,10 +347,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
request_measure = self.request_sample.measure()
request_measure.start()
- l = l.decode('utf-8').rstrip()
if l == 'END':
- self.writer.write('ok\n'.encode('utf-8'))
- return
+ break
(method, taskhash) = l.split()
#logger.debug('Looking up %s %s' % (method, taskhash))
@@ -366,29 +359,30 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
cursor.close()
if row is not None:
- msg = ('%s\n' % row['unihash']).encode('utf-8')
+ msg = row['unihash']
#logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
- msg = ("%s\n" % upstream).encode("utf-8")
+ msg = upstream
else:
- msg = "\n".encode("utf-8")
+ msg = ""
else:
- msg = '\n'.encode('utf-8')
+ msg = ""
- self.writer.write(msg)
+ await self.socket.send(msg)
finally:
request_measure.end()
self.request_sample.end()
- await self.writer.drain()
-
# Post to the backfill queue after writing the result to minimize
# the turn around time on a request
if upstream is not None:
await self.backfill_queue.put((method, taskhash))
+ await self.socket.send("ok")
+ return self.NO_RESPONSE
+
async def handle_report(self, data):
with closing(self.db.cursor()) as cursor:
outhash_data = {
@@ -468,7 +462,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
'unihash': unihash,
}
- self.write_message(d)
+ return d
async def handle_equivreport(self, data):
with closing(self.db.cursor()) as cursor:
@@ -491,30 +485,28 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
- self.write_message(d)
+ return d
async def handle_get_stats(self, request):
- d = {
+ return {
'requests': self.request_stats.todict(),
}
- self.write_message(d)
-
async def handle_reset_stats(self, request):
d = {
'requests': self.request_stats.todict(),
}
self.request_stats.reset()
- self.write_message(d)
+ return d
async def handle_backfill_wait(self, request):
d = {
'tasks': self.backfill_queue.qsize(),
}
await self.backfill_queue.join()
- self.write_message(d)
+ return d
async def handle_remove(self, request):
condition = request["where"]
@@ -541,7 +533,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
self.db.commit()
- self.write_message({"count": count})
+ return {"count": count}
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
@@ -558,7 +550,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
)
count = cursor.rowcount
- self.write_message({"count": count})
+ return {"count": count}
def query_equivalent(self, cursor, method, taskhash):
# This is part of the inner loop and must be as fast as possible
@@ -583,41 +575,33 @@ class Server(bb.asyncrpc.AsyncServer):
self.db = db
self.upstream = upstream
self.read_only = read_only
+ self.backfill_queue = None
- def accept_client(self, reader, writer):
- return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
+ def accept_client(self, socket):
+ return ServerClient(socket, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
- @contextmanager
- def _backfill_worker(self):
- async def backfill_worker_task():
- client = await create_async_client(self.upstream)
- try:
- while True:
- item = await self.backfill_queue.get()
- if item is None:
- self.backfill_queue.task_done()
- break
- method, taskhash = item
- await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ async def backfill_worker_task(self):
+ client = await create_async_client(self.upstream)
+ try:
+ while True:
+ item = await self.backfill_queue.get()
+ if item is None:
self.backfill_queue.task_done()
- finally:
- await client.close()
+ break
+ method, taskhash = item
+ await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ self.backfill_queue.task_done()
+ finally:
+ await client.close()
- async def join_worker(worker):
+ def start(self):
+ tasks = super().start()
+ if self.upstream:
+ self.backfill_queue = asyncio.Queue()
+ tasks += [self.backfill_worker_task()]
+ return tasks
+
+ async def stop(self):
+ if self.backfill_queue is not None:
await self.backfill_queue.put(None)
- await worker
-
- if self.upstream is not None:
- worker = asyncio.ensure_future(backfill_worker_task())
- try:
- yield
- finally:
- self.loop.run_until_complete(join_worker(worker))
- else:
- yield
-
- def run_loop_forever(self):
- self.backfill_queue = asyncio.Queue()
-
- with self._backfill_worker():
- super().run_loop_forever()
+ await super().stop()
diff --git a/lib/prserv/client.py b/lib/prserv/client.py
index 69ab7a4ac..6b81356fa 100644
--- a/lib/prserv/client.py
+++ b/lib/prserv/client.py
@@ -14,28 +14,28 @@ class PRAsyncClient(bb.asyncrpc.AsyncClient):
super().__init__('PRSERVICE', '1.0', logger)
async def getPR(self, version, pkgarch, checksum):
- response = await self.send_message(
+ response = await self.invoke(
{'get-pr': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum}}
)
if response:
return response['value']
async def importone(self, version, pkgarch, checksum, value):
- response = await self.send_message(
+ response = await self.invoke(
{'import-one': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum, 'value': value}}
)
if response:
return response['value']
async def export(self, version, pkgarch, checksum, colinfo):
- response = await self.send_message(
+ response = await self.invoke(
{'export': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum, 'colinfo': colinfo}}
)
if response:
return (response['metainfo'], response['datainfo'])
async def is_readonly(self):
- response = await self.send_message(
+ response = await self.invoke(
{'is-readonly': {}}
)
if response:
diff --git a/lib/prserv/serv.py b/lib/prserv/serv.py
index c686b2065..ea7933164 100644
--- a/lib/prserv/serv.py
+++ b/lib/prserv/serv.py
@@ -20,8 +20,8 @@ PIDPREFIX = "/tmp/PRServer_%s_%s.pid"
singleton = None
class PRServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, reader, writer, table, read_only):
- super().__init__(reader, writer, 'PRSERVICE', logger)
+ def __init__(self, socket, table, read_only):
+ super().__init__(socket, 'PRSERVICE', logger)
self.handlers.update({
'get-pr': self.handle_get_pr,
'import-one': self.handle_import_one,
@@ -36,12 +36,12 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
async def dispatch_message(self, msg):
try:
- await super().dispatch_message(msg)
+ return await super().dispatch_message(msg)
except:
self.table.sync()
raise
-
- self.table.sync_if_dirty()
+ else:
+ self.table.sync_if_dirty()
async def handle_get_pr(self, request):
version = request['version']
@@ -57,7 +57,7 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
except sqlite3.Error as exc:
logger.error(str(exc))
- self.write_message(response)
+ return response
async def handle_import_one(self, request):
response = None
@@ -71,7 +71,7 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
if value is not None:
response = {'value': value}
- self.write_message(response)
+ return response
async def handle_export(self, request):
version = request['version']
@@ -85,12 +85,10 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
logger.error(str(exc))
metainfo = datainfo = None
- response = {'metainfo': metainfo, 'datainfo': datainfo}
- self.write_message(response)
+ return {'metainfo': metainfo, 'datainfo': datainfo}
async def handle_is_readonly(self, request):
- response = {'readonly': self.read_only}
- self.write_message(response)
+ return {'readonly': self.read_only}
class PRServer(bb.asyncrpc.AsyncServer):
def __init__(self, dbfile, read_only=False):
@@ -99,20 +97,23 @@ class PRServer(bb.asyncrpc.AsyncServer):
self.table = None
self.read_only = read_only
- def accept_client(self, reader, writer):
- return PRServerClient(reader, writer, self.table, self.read_only)
+ def accept_client(self, socket):
+ return PRServerClient(socket, self.table, self.read_only)
- def _serve_forever(self):
+ def start(self):
+ tasks = super().start()
self.db = prserv.db.PRData(self.dbfile, read_only=self.read_only)
self.table = self.db["PRMAIN"]
logger.info("Started PRServer with DBfile: %s, Address: %s, PID: %s" %
(self.dbfile, self.address, str(os.getpid())))
- super()._serve_forever()
+ return tasks
+ async def stop(self):
self.table.sync_if_dirty()
self.db.disconnect()
+ await super().stop()
def signal_handler(self):
super().signal_handler()