aboutsummaryrefslogtreecommitdiffstats
path: root/lib/bb/asyncrpc/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/bb/asyncrpc/client.py')
-rw-r--r--lib/bb/asyncrpc/client.py78
1 files changed, 23 insertions, 55 deletions
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index fa042bbe8..7f33099b6 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -10,13 +10,13 @@ import json
import os
import socket
import sys
-from . import chunkify, DEFAULT_MAX_CHUNK
+from .connection import StreamConnection, DEFAULT_MAX_CHUNK
+from .exceptions import ConnectionClosedError
class AsyncClient(object):
def __init__(self, proto_name, proto_version, logger, timeout=30):
- self.reader = None
- self.writer = None
+ self.socket = None
self.max_chunk = DEFAULT_MAX_CHUNK
self.proto_name = proto_name
self.proto_version = proto_version
@@ -25,7 +25,8 @@ class AsyncClient(object):
async def connect_tcp(self, address, port):
async def connect_sock():
- return await asyncio.open_connection(address, port)
+ reader, writer = await asyncio.open_connection(address, port)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
@@ -40,27 +41,27 @@ class AsyncClient(object):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
sock.connect(os.path.basename(path))
finally:
- os.chdir(cwd)
- return await asyncio.open_unix_connection(sock=sock)
+ os.chdir(cwd)
+ reader, writer = await asyncio.open_unix_connection(sock=sock)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
async def setup_connection(self):
- s = '%s %s\n\n' % (self.proto_name, self.proto_version)
- self.writer.write(s.encode("utf-8"))
- await self.writer.drain()
+ # Send headers
+ await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
+ # End of headers
+ await self.socket.send("")
async def connect(self):
- if self.reader is None or self.writer is None:
- (self.reader, self.writer) = await self._connect_sock()
+ if self.socket is None:
+ self.socket = await self._connect_sock()
await self.setup_connection()
async def close(self):
- self.reader = None
-
- if self.writer is not None:
- self.writer.close()
- self.writer = None
+ if self.socket is not None:
+ await self.socket.close()
+ self.socket = None
async def _send_wrapper(self, proc):
count = 0
@@ -71,6 +72,7 @@ class AsyncClient(object):
except (
OSError,
ConnectionError,
+ ConnectionClosedError,
json.JSONDecodeError,
UnicodeDecodeError,
) as e:
@@ -82,49 +84,15 @@ class AsyncClient(object):
await self.close()
count += 1
- async def send_message(self, msg):
- async def get_line():
- try:
- line = await asyncio.wait_for(self.reader.readline(), self.timeout)
- except asyncio.TimeoutError:
- raise ConnectionError("Timed out waiting for server")
-
- if not line:
- raise ConnectionError("Connection closed")
-
- line = line.decode("utf-8")
-
- if not line.endswith("\n"):
- raise ConnectionError("Bad message %r" % (line))
-
- return line
-
+ async def invoke(self, msg):
async def proc():
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode("utf-8"))
- await self.writer.drain()
-
- l = await get_line()
-
- m = json.loads(l)
- if m and "chunk-stream" in m:
- lines = []
- while True:
- l = (await get_line()).rstrip("\n")
- if not l:
- break
- lines.append(l)
-
- m = json.loads("".join(lines))
-
- return m
+ await self.socket.send_message(msg)
+ return await self.socket.recv_message()
return await self._send_wrapper(proc)
async def ping(self):
- return await self.send_message(
- {'ping': {}}
- )
+ return await self.invoke({"ping": {}})
class Client(object):
@@ -142,7 +110,7 @@ class Client(object):
# required (but harmless) with it.
asyncio.set_event_loop(self.loop)
- self._add_methods('connect_tcp', 'ping')
+ self._add_methods("connect_tcp", "ping")
@abc.abstractmethod
def _get_async_client(self):