aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJoshua Watt <jpewhacker@gmail.com>2021-07-22 11:19:37 -0500
committerRichard Purdie <richard.purdie@linuxfoundation.org>2021-07-27 09:27:31 +0100
commitef2865efa98ad20823267364f2159d8d8c931400 (patch)
tree688e576e9154375cd614f97c44ed0a1e25af77b3
parentb31f1853d7fcb8b8f8885ca513a0021a5d0301e6 (diff)
downloadbitbake-ef2865efa98ad20823267364f2159d8d8c931400.tar.gz
bitbake: asyncrpc: Catch early SIGTERM
If the SIGTERM signal is sent to an asyncrpc server before it has installed the SIGTERM handler in the main loop, it may miss the signal, which can cause the calling process to wait forever on the join(). To resolve this, the calling process should mask off SIGTERM before forking the server process, and the server should unmask the signal only after the handler is installed. To simplify the usage of the server, a new helper function called serve_as_process() is added to do this automatically and correctly. Thanks: Scott Murray <scott.murray@konsulko.com> for helping debug Signed-off-by: Joshua Watt <JPEWhacker@gmail.com> Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
-rw-r--r--lib/bb/asyncrpc/serv.py21
-rw-r--r--lib/bb/cooker.py3
-rw-r--r--lib/hashserv/tests.py54
3 files changed, 64 insertions, 14 deletions
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index ef20cb71d..4084f300d 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -9,6 +9,7 @@ import os
import signal
import socket
import sys
+import multiprocessing
from . import chunkify, DEFAULT_MAX_CHUNK
@@ -201,12 +202,14 @@ class AsyncServer(object):
pass
def signal_handler(self):
+ self.logger.debug("Got exit signal")
self.loop.stop()
def serve_forever(self):
asyncio.set_event_loop(self.loop)
try:
self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
self.run_loop_forever()
self.server.close()
@@ -221,3 +224,21 @@ class AsyncServer(object):
if self._cleanup_socket is not None:
self._cleanup_socket()
+
+ def serve_as_process(self, *, prefunc=None, args=()):
+ def run():
+ if prefunc is not None:
+ prefunc(self, *args)
+ self.serve_forever()
+
+ # Temporarily block SIGTERM. The server process will inherit this
+ # block which will ensure it doesn't receive the SIGTERM until the
+ # handler is ready for it
+ mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
+ try:
+ self.process = multiprocessing.Process(target=run)
+ self.process.start()
+
+ return self.process
+ finally:
+ signal.pthread_sigmask(signal.SIG_SETMASK, mask)
diff --git a/lib/bb/cooker.py b/lib/bb/cooker.py
index 39e10e613..b2d69c28c 100644
--- a/lib/bb/cooker.py
+++ b/lib/bb/cooker.py
@@ -390,8 +390,7 @@ class BBCooker:
dbfile = (self.data.getVar("PERSISTENT_DIR") or self.data.getVar("CACHE")) + "/hashserv.db"
self.hashservaddr = "unix://%s/hashserve.sock" % self.data.getVar("TOPDIR")
self.hashserv = hashserv.create_server(self.hashservaddr, dbfile, sync=False)
- self.hashserv.process = multiprocessing.Process(target=self.hashserv.serve_forever)
- self.hashserv.process.start()
+ self.hashserv.serve_as_process()
self.data.setVar("BB_HASHSERVE", self.hashservaddr)
self.databuilder.origdata.setVar("BB_HASHSERVE", self.hashservaddr)
self.databuilder.data.setVar("BB_HASHSERVE", self.hashservaddr)
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index e2b762dbf..e851535c5 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -15,28 +15,32 @@ import tempfile
import threading
import unittest
import socket
+import time
+import signal
-def _run_server(server, idx):
- # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
- # format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
+def server_prefunc(server, idx):
+ logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
+ format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
+ server.logger.debug("Running server %d" % idx)
sys.stdout = open('bbhashserv-%d.log' % idx, 'w')
sys.stderr = sys.stdout
- server.serve_forever()
-
class HashEquivalenceTestSetup(object):
METHOD = 'TestMethod'
server_index = 0
- def start_server(self, dbpath=None, upstream=None, read_only=False):
+ def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc):
self.server_index += 1
if dbpath is None:
dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
- def cleanup_thread(thread):
- thread.terminate()
- thread.join()
+ def cleanup_server(server):
+ if server.process.exitcode is not None:
+ return
+
+ server.process.terminate()
+ server.process.join()
server = create_server(self.get_server_addr(self.server_index),
dbpath,
@@ -44,9 +48,8 @@ class HashEquivalenceTestSetup(object):
read_only=read_only)
server.dbpath = dbpath
- server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index))
- server.thread.start()
- self.addCleanup(cleanup_thread, server.thread)
+ server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
+ self.addCleanup(cleanup_server, server)
def cleanup_client(client):
client.close()
@@ -283,6 +286,33 @@ class HashEquivalenceCommonTests(object):
self.assertClientGetHash(self.client, taskhash2, None)
+ def test_slow_server_start(self):
+ """
+ Ensures that the server will exit correctly even if it gets a SIGTERM
+ before entering the main loop
+ """
+
+ event = multiprocessing.Event()
+
+ def prefunc(server, idx):
+ nonlocal event
+ server_prefunc(server, idx)
+ event.wait()
+
+ def do_nothing(signum, frame):
+ pass
+
+ old_signal = signal.signal(signal.SIGTERM, do_nothing)
+ self.addCleanup(signal.signal, signal.SIGTERM, old_signal)
+
+ _, server = self.start_server(prefunc=prefunc)
+ server.process.terminate()
+ time.sleep(30)
+ event.set()
+ server.process.join(300)
+ self.assertIsNotNone(server.process.exitcode, "Server did not exit in a timely manner!")
+
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)