diff options
-rw-r--r-- | lib/bb/asyncrpc/serv.py | 21 | ||||
-rw-r--r-- | lib/bb/cooker.py | 3 | ||||
-rw-r--r-- | lib/hashserv/tests.py | 54 |
3 files changed, 64 insertions, 14 deletions
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py index ef20cb71d..4084f300d 100644 --- a/lib/bb/asyncrpc/serv.py +++ b/lib/bb/asyncrpc/serv.py @@ -9,6 +9,7 @@ import os import signal import socket import sys +import multiprocessing from . import chunkify, DEFAULT_MAX_CHUNK @@ -201,12 +202,14 @@ class AsyncServer(object): pass def signal_handler(self): + self.logger.debug("Got exit signal") self.loop.stop() def serve_forever(self): asyncio.set_event_loop(self.loop) try: self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) + signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM]) self.run_loop_forever() self.server.close() @@ -221,3 +224,21 @@ class AsyncServer(object): if self._cleanup_socket is not None: self._cleanup_socket() + + def serve_as_process(self, *, prefunc=None, args=()): + def run(): + if prefunc is not None: + prefunc(self, *args) + self.serve_forever() + + # Temporarily block SIGTERM. The server process will inherit this + # block which will ensure it doesn't receive the SIGTERM until the + # handler is ready for it + mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM]) + try: + self.process = multiprocessing.Process(target=run) + self.process.start() + + return self.process + finally: + signal.pthread_sigmask(signal.SIG_SETMASK, mask) diff --git a/lib/bb/cooker.py b/lib/bb/cooker.py index 39e10e613..b2d69c28c 100644 --- a/lib/bb/cooker.py +++ b/lib/bb/cooker.py @@ -390,8 +390,7 @@ class BBCooker: dbfile = (self.data.getVar("PERSISTENT_DIR") or self.data.getVar("CACHE")) + "/hashserv.db" self.hashservaddr = "unix://%s/hashserve.sock" % self.data.getVar("TOPDIR") self.hashserv = hashserv.create_server(self.hashservaddr, dbfile, sync=False) - self.hashserv.process = multiprocessing.Process(target=self.hashserv.serve_forever) - self.hashserv.process.start() + self.hashserv.serve_as_process() self.data.setVar("BB_HASHSERVE", self.hashservaddr) self.databuilder.origdata.setVar("BB_HASHSERVE", self.hashservaddr) self.databuilder.data.setVar("BB_HASHSERVE", self.hashservaddr) diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py index e2b762dbf..e851535c5 100644 --- a/lib/hashserv/tests.py +++ b/lib/hashserv/tests.py @@ -15,28 +15,32 @@ import tempfile import threading import unittest import socket +import time +import signal -def _run_server(server, idx): - # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', - # format='%(levelname)s %(filename)s:%(lineno)d %(message)s') +def server_prefunc(server, idx): + logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', + format='%(levelname)s %(filename)s:%(lineno)d %(message)s') + server.logger.debug("Running server %d" % idx) sys.stdout = open('bbhashserv-%d.log' % idx, 'w') sys.stderr = sys.stdout - server.serve_forever() - class HashEquivalenceTestSetup(object): METHOD = 'TestMethod' server_index = 0 - def start_server(self, dbpath=None, upstream=None, read_only=False): + def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc): self.server_index += 1 if dbpath is None: dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) - def cleanup_thread(thread): - thread.terminate() - thread.join() + def cleanup_server(server): + if server.process.exitcode is not None: + return + + server.process.terminate() + server.process.join() server = create_server(self.get_server_addr(self.server_index), dbpath, @@ -44,9 +48,8 @@ class HashEquivalenceTestSetup(object): read_only=read_only) server.dbpath = dbpath - server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index)) - server.thread.start() - self.addCleanup(cleanup_thread, server.thread) + server.serve_as_process(prefunc=prefunc, args=(self.server_index,)) + self.addCleanup(cleanup_server, server) def cleanup_client(client): client.close() @@ -283,6 +286,33 @@ class HashEquivalenceCommonTests(object): self.assertClientGetHash(self.client, taskhash2, None) + def test_slow_server_start(self): + """ + Ensures that the server will exit correctly even if it gets a SIGTERM + before entering the main loop + """ + + event = multiprocessing.Event() + + def prefunc(server, idx): + nonlocal event + server_prefunc(server, idx) + event.wait() + + def do_nothing(signum, frame): + pass + + old_signal = signal.signal(signal.SIGTERM, do_nothing) + self.addCleanup(signal.signal, signal.SIGTERM, old_signal) + + _, server = self.start_server(prefunc=prefunc) + server.process.terminate() + time.sleep(30) + event.set() + server.process.join(300) + self.assertIsNotNone(server.process.exitcode, "Server did not exit in a timely manner!") + + class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): def get_server_addr(self, server_idx): return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx) |