aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lib/bb/asyncrpc/serv.py21
-rw-r--r--lib/bb/cooker.py3
-rw-r--r--lib/hashserv/tests.py54
3 files changed, 64 insertions, 14 deletions
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index ef20cb71d..4084f300d 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -9,6 +9,7 @@ import os
import signal
import socket
import sys
+import multiprocessing
from . import chunkify, DEFAULT_MAX_CHUNK
@@ -201,12 +202,14 @@ class AsyncServer(object):
pass
def signal_handler(self):
+ self.logger.debug("Got exit signal")
self.loop.stop()
def serve_forever(self):
asyncio.set_event_loop(self.loop)
try:
self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
self.run_loop_forever()
self.server.close()
@@ -221,3 +224,21 @@ class AsyncServer(object):
if self._cleanup_socket is not None:
self._cleanup_socket()
+
+ def serve_as_process(self, *, prefunc=None, args=()):
+ def run():
+ if prefunc is not None:
+ prefunc(self, *args)
+ self.serve_forever()
+
+ # Temporarily block SIGTERM. The server process will inherit this
+ # block which will ensure it doesn't receive the SIGTERM until the
+ # handler is ready for it
+ mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
+ try:
+ self.process = multiprocessing.Process(target=run)
+ self.process.start()
+
+ return self.process
+ finally:
+ signal.pthread_sigmask(signal.SIG_SETMASK, mask)
diff --git a/lib/bb/cooker.py b/lib/bb/cooker.py
index 39e10e613..b2d69c28c 100644
--- a/lib/bb/cooker.py
+++ b/lib/bb/cooker.py
@@ -390,8 +390,7 @@ class BBCooker:
dbfile = (self.data.getVar("PERSISTENT_DIR") or self.data.getVar("CACHE")) + "/hashserv.db"
self.hashservaddr = "unix://%s/hashserve.sock" % self.data.getVar("TOPDIR")
self.hashserv = hashserv.create_server(self.hashservaddr, dbfile, sync=False)
- self.hashserv.process = multiprocessing.Process(target=self.hashserv.serve_forever)
- self.hashserv.process.start()
+ self.hashserv.serve_as_process()
self.data.setVar("BB_HASHSERVE", self.hashservaddr)
self.databuilder.origdata.setVar("BB_HASHSERVE", self.hashservaddr)
self.databuilder.data.setVar("BB_HASHSERVE", self.hashservaddr)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index e2b762dbf..e851535c5 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -15,28 +15,32 @@ import tempfile
import threading
import unittest
import socket
+import time
+import signal
-def _run_server(server, idx):
- # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
- # format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
+def server_prefunc(server, idx):
+ logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
+ format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
+ server.logger.debug("Running server %d" % idx)
sys.stdout = open('bbhashserv-%d.log' % idx, 'w')
sys.stderr = sys.stdout
- server.serve_forever()
-
class HashEquivalenceTestSetup(object):
METHOD = 'TestMethod'
server_index = 0
- def start_server(self, dbpath=None, upstream=None, read_only=False):
+ def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc):
self.server_index += 1
if dbpath is None:
dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
- def cleanup_thread(thread):
- thread.terminate()
- thread.join()
+ def cleanup_server(server):
+ if server.process.exitcode is not None:
+ return
+
+ server.process.terminate()
+ server.process.join()
server = create_server(self.get_server_addr(self.server_index),
dbpath,
@@ -44,9 +48,8 @@ class HashEquivalenceTestSetup(object):
read_only=read_only)
server.dbpath = dbpath
- server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index))
- server.thread.start()
- self.addCleanup(cleanup_thread, server.thread)
+ server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
+ self.addCleanup(cleanup_server, server)
def cleanup_client(client):
client.close()
@@ -283,6 +286,33 @@ class HashEquivalenceCommonTests(object):
self.assertClientGetHash(self.client, taskhash2, None)
+ def test_slow_server_start(self):
+ """
+ Ensures that the server will exit correctly even if it gets a SIGTERM
+ before entering the main loop
+ """
+
+ event = multiprocessing.Event()
+
+ def prefunc(server, idx):
+ nonlocal event
+ server_prefunc(server, idx)
+ event.wait()
+
+ def do_nothing(signum, frame):
+ pass
+
+ old_signal = signal.signal(signal.SIGTERM, do_nothing)
+ self.addCleanup(signal.signal, signal.SIGTERM, old_signal)
+
+ _, server = self.start_server(prefunc=prefunc)
+ server.process.terminate()
+ time.sleep(30)
+ event.set()
+ server.process.join(300)
+ self.assertIsNotNone(server.process.exitcode, "Server did not exit in a timely manner!")
+
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)