diff options
Diffstat (limited to 'lib/bb/asyncrpc/client.py')
-rw-r--r-- | lib/bb/asyncrpc/client.py | 276 |
1 files changed, 222 insertions, 54 deletions
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py index 4cdad9ac3..a350b4fb1 100644 --- a/lib/bb/asyncrpc/client.py +++ b/lib/bb/asyncrpc/client.py @@ -1,4 +1,6 @@ # +# Copyright BitBake Contributors +# # SPDX-License-Identifier: GPL-2.0-only # @@ -7,46 +9,127 @@ import asyncio import json import os import socket -from . import chunkify, DEFAULT_MAX_CHUNK - +import sys +import re +import contextlib +from threading import Thread +from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK +from .exceptions import ConnectionClosedError, InvokeError + +UNIX_PREFIX = "unix://" +WS_PREFIX = "ws://" +WSS_PREFIX = "wss://" + +ADDR_TYPE_UNIX = 0 +ADDR_TYPE_TCP = 1 +ADDR_TYPE_WS = 2 + +def parse_address(addr): + if addr.startswith(UNIX_PREFIX): + return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],)) + elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX): + return (ADDR_TYPE_WS, (addr,)) + else: + m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr) + if m is not None: + host = m.group("host") + port = m.group("port") + else: + host, port = addr.split(":") + + return (ADDR_TYPE_TCP, (host, int(port))) class AsyncClient(object): - def __init__(self, proto_name, proto_version, logger): - self.reader = None - self.writer = None + def __init__( + self, + proto_name, + proto_version, + logger, + timeout=30, + server_headers=False, + headers={}, + ): + self.socket = None self.max_chunk = DEFAULT_MAX_CHUNK self.proto_name = proto_name self.proto_version = proto_version self.logger = logger + self.timeout = timeout + self.needs_server_headers = server_headers + self.server_headers = {} + self.headers = headers async def connect_tcp(self, address, port): async def connect_sock(): - return await asyncio.open_connection(address, port) + reader, writer = await asyncio.open_connection(address, port) + return StreamConnection(reader, writer, self.timeout, self.max_chunk) self._connect_sock = connect_sock async def connect_unix(self, path): async def connect_sock(): - return await asyncio.open_unix_connection(path) + # AF_UNIX has path length issues so chdir here to workaround + cwd = os.getcwd() + try: + os.chdir(os.path.dirname(path)) + # The socket must be opened synchronously so that CWD doesn't get + # changed out from underneath us so we pass as a sock into asyncio + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) + sock.connect(os.path.basename(path)) + finally: + os.chdir(cwd) + reader, writer = await asyncio.open_unix_connection(sock=sock) + return StreamConnection(reader, writer, self.timeout, self.max_chunk) + + self._connect_sock = connect_sock + + async def connect_websocket(self, uri): + import websockets + + async def connect_sock(): + websocket = await websockets.connect(uri, ping_interval=None) + return WebsocketConnection(websocket, self.timeout) self._connect_sock = connect_sock async def setup_connection(self): - s = '%s %s\n\n' % (self.proto_name, self.proto_version) - self.writer.write(s.encode("utf-8")) - await self.writer.drain() + # Send headers + await self.socket.send("%s %s" % (self.proto_name, self.proto_version)) + await self.socket.send( + "needs-headers: %s" % ("true" if self.needs_server_headers else "false") + ) + for k, v in self.headers.items(): + await self.socket.send("%s: %s" % (k, v)) + + # End of headers + await self.socket.send("") + + self.server_headers = {} + if self.needs_server_headers: + while True: + line = await self.socket.recv() + if not line: + # End headers + break + tag, value = line.split(":", 1) + self.server_headers[tag.lower()] = value.strip() + + async def get_header(self, tag, default): + await self.connect() + return self.server_headers.get(tag, default) async def connect(self): - if self.reader is None or self.writer is None: - (self.reader, self.writer) = await self._connect_sock() + if self.socket is None: + self.socket = await self._connect_sock() await self.setup_connection() - async def close(self): - self.reader = None + async def disconnect(self): + if self.socket is not None: + await self.socket.close() + self.socket = None - if self.writer is not None: - self.writer.close() - self.writer = None + async def close(self): + await self.disconnect() async def _send_wrapper(self, proc): count = 0 @@ -57,6 +140,7 @@ class AsyncClient(object): except ( OSError, ConnectionError, + ConnectionClosedError, json.JSONDecodeError, UnicodeDecodeError, ) as e: @@ -68,40 +152,27 @@ class AsyncClient(object): await self.close() count += 1 - async def send_message(self, msg): - async def get_line(): - line = await self.reader.readline() - if not line: - raise ConnectionError("Connection closed") - - line = line.decode("utf-8") - - if not line.endswith("\n"): - raise ConnectionError("Bad message %r" % msg) - - return line + def check_invoke_error(self, msg): + if isinstance(msg, dict) and "invoke-error" in msg: + raise InvokeError(msg["invoke-error"]["message"]) + async def invoke(self, msg): async def proc(): - for c in chunkify(json.dumps(msg), self.max_chunk): - self.writer.write(c.encode("utf-8")) - await self.writer.drain() - - l = await get_line() + await self.socket.send_message(msg) + return await self.socket.recv_message() - m = json.loads(l) - if m and "chunk-stream" in m: - lines = [] - while True: - l = (await get_line()).rstrip("\n") - if not l: - break - lines.append(l) + result = await self._send_wrapper(proc) + self.check_invoke_error(result) + return result - m = json.loads("".join(lines)) + async def ping(self): + return await self.invoke({"ping": {}}) - return m + async def __aenter__(self): + return self - return await self._send_wrapper(proc) + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() class Client(object): @@ -109,7 +180,17 @@ class Client(object): self.client = self._get_async_client() self.loop = asyncio.new_event_loop() - self._add_methods('connect_tcp', 'close') + # Override any pre-existing loop. + # Without this, the PR server export selftest triggers a hang + # when running with Python 3.7. The drawback is that there is + # potential for issues if the PR and hash equiv (or some new) + # clients need to both be instantiated in the same process. + # This should be revisited if/when Python 3.9 becomes the + # minimum required version for BitBake, as it seems not + # required (but harmless) with it. + asyncio.set_event_loop(self.loop) + + self._add_methods("connect_tcp", "ping") @abc.abstractmethod def _get_async_client(self): @@ -127,14 +208,8 @@ class Client(object): setattr(self, m, self._get_downcall_wrapper(downcall)) def connect_unix(self, path): - # AF_UNIX has path length issues so chdir here to workaround - cwd = os.getcwd() - try: - os.chdir(os.path.dirname(path)) - self.loop.run_until_complete(self.client.connect_unix(os.path.basename(path))) - self.loop.run_until_complete(self.client.connect()) - finally: - os.chdir(cwd) + self.loop.run_until_complete(self.client.connect_unix(path)) + self.loop.run_until_complete(self.client.connect()) @property def max_chunk(self): @@ -143,3 +218,96 @@ class Client(object): @max_chunk.setter def max_chunk(self, value): self.client.max_chunk = value + + def disconnect(self): + self.loop.run_until_complete(self.client.close()) + + def close(self): + if self.loop: + self.loop.run_until_complete(self.client.close()) + if sys.version_info >= (3, 6): + self.loop.run_until_complete(self.loop.shutdown_asyncgens()) + self.loop.close() + self.loop = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + return False + + +class ClientPool(object): + def __init__(self, max_clients): + self.avail_clients = [] + self.num_clients = 0 + self.max_clients = max_clients + self.loop = None + self.client_condition = None + + @abc.abstractmethod + async def _new_client(self): + raise NotImplementedError("Must be implemented in derived class") + + def close(self): + if self.client_condition: + self.client_condition = None + + if self.loop: + self.loop.run_until_complete(self.__close_clients()) + self.loop.run_until_complete(self.loop.shutdown_asyncgens()) + self.loop.close() + self.loop = None + + def run_tasks(self, tasks): + if not self.loop: + self.loop = asyncio.new_event_loop() + + thread = Thread(target=self.__thread_main, args=(tasks,)) + thread.start() + thread.join() + + @contextlib.asynccontextmanager + async def get_client(self): + async with self.client_condition: + if self.avail_clients: + client = self.avail_clients.pop() + elif self.num_clients < self.max_clients: + self.num_clients += 1 + client = await self._new_client() + else: + while not self.avail_clients: + await self.client_condition.wait() + client = self.avail_clients.pop() + + try: + yield client + finally: + async with self.client_condition: + self.avail_clients.append(client) + self.client_condition.notify() + + def __thread_main(self, tasks): + async def process_task(task): + async with self.get_client() as client: + await task(client) + + asyncio.set_event_loop(self.loop) + if not self.client_condition: + self.client_condition = asyncio.Condition() + tasks = [process_task(t) for t in tasks] + self.loop.run_until_complete(asyncio.gather(*tasks)) + + async def __close_clients(self): + for c in self.avail_clients: + await c.close() + self.avail_clients = [] + self.num_clients = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + return False |