aboutsummaryrefslogtreecommitdiffstats
path: root/lib/hashserv
diff options
context:
space:
mode:
Diffstat (limited to 'lib/hashserv')
-rw-r--r--lib/hashserv/tests.py54
1 files changed, 42 insertions, 12 deletions
diff --git a/lib/hashserv/tests.py b/lib/hashserv/tests.py
index e2b762dbf..e851535c5 100644
--- a/lib/hashserv/tests.py
+++ b/lib/hashserv/tests.py
@@ -15,28 +15,32 @@ import tempfile
import threading
import unittest
import socket
+import time
+import signal
-def _run_server(server, idx):
- # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
- # format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
+def server_prefunc(server, idx):
+ logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
+ format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
+ server.logger.debug("Running server %d" % idx)
sys.stdout = open('bbhashserv-%d.log' % idx, 'w')
sys.stderr = sys.stdout
- server.serve_forever()
-
class HashEquivalenceTestSetup(object):
METHOD = 'TestMethod'
server_index = 0
- def start_server(self, dbpath=None, upstream=None, read_only=False):
+ def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc):
self.server_index += 1
if dbpath is None:
dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
- def cleanup_thread(thread):
- thread.terminate()
- thread.join()
+ def cleanup_server(server):
+ if server.process.exitcode is not None:
+ return
+
+ server.process.terminate()
+ server.process.join()
server = create_server(self.get_server_addr(self.server_index),
dbpath,
@@ -44,9 +48,8 @@ class HashEquivalenceTestSetup(object):
read_only=read_only)
server.dbpath = dbpath
- server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index))
- server.thread.start()
- self.addCleanup(cleanup_thread, server.thread)
+ server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
+ self.addCleanup(cleanup_server, server)
def cleanup_client(client):
client.close()
@@ -283,6 +286,33 @@ class HashEquivalenceCommonTests(object):
self.assertClientGetHash(self.client, taskhash2, None)
+ def test_slow_server_start(self):
+ """
+ Ensures that the server will exit correctly even if it gets a SIGTERM
+ before entering the main loop
+ """
+
+ event = multiprocessing.Event()
+
+ def prefunc(server, idx):
+ nonlocal event
+ server_prefunc(server, idx)
+ event.wait()
+
+ def do_nothing(signum, frame):
+ pass
+
+ old_signal = signal.signal(signal.SIGTERM, do_nothing)
+ self.addCleanup(signal.signal, signal.SIGTERM, old_signal)
+
+ _, server = self.start_server(prefunc=prefunc)
+ server.process.terminate()
+ time.sleep(30)
+ event.set()
+ server.process.join(300)
+ self.assertIsNotNone(server.process.exitcode, "Server did not exit in a timely manner!")
+
+
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)