aboutsummaryrefslogtreecommitdiffstats
path: root/lib
diff options
context:
space:
mode:
authorJoshua Watt <JPEWhacker@gmail.com>2021-08-19 12:46:41 -0400
committerRichard Purdie <richard.purdie@linuxfoundation.org>2021-08-23 08:30:16 +0100
commit8555869cde39f9e9a9ced5a3e5788209640f6d50 (patch)
tree1f64598a3e386baeabb1fdd7a11468b2d2e5cf15 /lib
parent076baf4fbd328d247508fd399866a397eb34f67e (diff)
downloadbitbake-8555869cde39f9e9a9ced5a3e5788209640f6d50.tar.gz
bitbake: asyncrpc: Defer all asyncio to child process
Reworks the async I/O API so that the async loop is only created in the child process. This requires deferring the creation of the server until the child process and a queue to transfer the bound address back to the parent process Signed-off-by: Joshua Watt <JPEWhacker@gmail.com> [small loop -> self.loop fix in serv.py] Signed-off-by: Scott Murray <scott.murray@konsulko.com> Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
Diffstat (limited to 'lib')
-rw-r--r--lib/bb/asyncrpc/serv.py118
-rw-r--r--lib/hashserv/server.py4
2 files changed, 74 insertions, 48 deletions
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index 4084f300d..45628698b 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -131,53 +131,58 @@ class AsyncServerConnection(object):
class AsyncServer(object):
- def __init__(self, logger, loop=None):
- if loop is None:
- self.loop = asyncio.new_event_loop()
- self.close_loop = True
- else:
- self.loop = loop
- self.close_loop = False
-
+ def __init__(self, logger):
self._cleanup_socket = None
self.logger = logger
+ self.start = None
+ self.address = None
+
+ @property
+ def loop(self):
+ return asyncio.get_event_loop()
def start_tcp_server(self, host, port):
- self.server = self.loop.run_until_complete(
- asyncio.start_server(self.handle_client, host, port, loop=self.loop)
- )
-
- for s in self.server.sockets:
- self.logger.debug('Listening on %r' % (s.getsockname(),))
- # Newer python does this automatically. Do it manually here for
- # maximum compatibility
- s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
- s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
-
- name = self.server.sockets[0].getsockname()
- if self.server.sockets[0].family == socket.AF_INET6:
- self.address = "[%s]:%d" % (name[0], name[1])
- else:
- self.address = "%s:%d" % (name[0], name[1])
+ def start_tcp():
+ self.server = self.loop.run_until_complete(
+ asyncio.start_server(self.handle_client, host, port)
+ )
+
+ for s in self.server.sockets:
+ self.logger.debug('Listening on %r' % (s.getsockname(),))
+ # Newer python does this automatically. Do it manually here for
+ # maximum compatibility
+ s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+ s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
+
+ name = self.server.sockets[0].getsockname()
+ if self.server.sockets[0].family == socket.AF_INET6:
+ self.address = "[%s]:%d" % (name[0], name[1])
+ else:
+ self.address = "%s:%d" % (name[0], name[1])
+
+ self.start = start_tcp
def start_unix_server(self, path):
def cleanup():
os.unlink(path)
- cwd = os.getcwd()
- try:
- # Work around path length limits in AF_UNIX
- os.chdir(os.path.dirname(path))
- self.server = self.loop.run_until_complete(
- asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
- )
- finally:
- os.chdir(cwd)
+ def start_unix():
+ cwd = os.getcwd()
+ try:
+ # Work around path length limits in AF_UNIX
+ os.chdir(os.path.dirname(path))
+ self.server = self.loop.run_until_complete(
+ asyncio.start_unix_server(self.handle_client, os.path.basename(path))
+ )
+ finally:
+ os.chdir(cwd)
- self.logger.debug('Listening on %r' % path)
+ self.logger.debug('Listening on %r' % path)
- self._cleanup_socket = cleanup
- self.address = "unix://%s" % os.path.abspath(path)
+ self._cleanup_socket = cleanup
+ self.address = "unix://%s" % os.path.abspath(path)
+
+ self.start = start_unix
@abc.abstractmethod
def accept_client(self, reader, writer):
@@ -205,8 +210,7 @@ class AsyncServer(object):
self.logger.debug("Got exit signal")
self.loop.stop()
- def serve_forever(self):
- asyncio.set_event_loop(self.loop)
+ def _serve_forever(self):
try:
self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
@@ -217,28 +221,50 @@ class AsyncServer(object):
self.loop.run_until_complete(self.server.wait_closed())
self.logger.debug('Server shutting down')
finally:
- if self.close_loop:
- if sys.version_info >= (3, 6):
- self.loop.run_until_complete(self.loop.shutdown_asyncgens())
- self.loop.close()
-
if self._cleanup_socket is not None:
self._cleanup_socket()
+ def serve_forever(self):
+ """
+ Serve requests in the current process
+ """
+ self.start()
+ self._serve_forever()
+
def serve_as_process(self, *, prefunc=None, args=()):
- def run():
+ """
+ Serve requests in a child process
+ """
+ def run(queue):
+ try:
+ self.start()
+ finally:
+ queue.put(self.address)
+ queue.close()
+
if prefunc is not None:
prefunc(self, *args)
- self.serve_forever()
+
+ self._serve_forever()
+
+ if sys.version_info >= (3, 6):
+ self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+ self.loop.close()
+
+ queue = multiprocessing.Queue()
# Temporarily block SIGTERM. The server process will inherit this
# block which will ensure it doesn't receive the SIGTERM until the
# handler is ready for it
mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
try:
- self.process = multiprocessing.Process(target=run)
+ self.process = multiprocessing.Process(target=run, args=(queue,))
self.process.start()
+ self.address = queue.get()
+ queue.close()
+ queue.join_thread()
+
return self.process
finally:
signal.pthread_sigmask(signal.SIG_SETMASK, mask)
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 8e8498973..a059e5211 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -410,11 +410,11 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
class Server(bb.asyncrpc.AsyncServer):
- def __init__(self, db, loop=None, upstream=None, read_only=False):
+ def __init__(self, db, upstream=None, read_only=False):
if upstream and read_only:
raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")
- super().__init__(logger, loop)
+ super().__init__(logger)
self.request_stats = Stats()
self.db = db