diff options
Diffstat (limited to 'lib/prserv')
-rw-r--r-- | lib/prserv/__init__.py | 99 | ||||
-rw-r--r-- | lib/prserv/client.py | 72 | ||||
-rw-r--r-- | lib/prserv/db.py | 427 | ||||
-rw-r--r-- | lib/prserv/serv.py | 684 | ||||
-rw-r--r-- | lib/prserv/tests.py | 386 |
5 files changed, 1097 insertions, 571 deletions
diff --git a/lib/prserv/__init__.py b/lib/prserv/__init__.py index 9961040b5..a817b03c1 100644 --- a/lib/prserv/__init__.py +++ b/lib/prserv/__init__.py @@ -1,18 +1,95 @@ # +# Copyright BitBake Contributors +# # SPDX-License-Identifier: GPL-2.0-only # -__version__ = "1.0.0" -import os, time -import sys,logging +__version__ = "2.0.0" + +import logging +logger = logging.getLogger("BitBake.PRserv") + +from bb.asyncrpc.client import parse_address, ADDR_TYPE_UNIX, ADDR_TYPE_WS + +def create_server(addr, dbpath, upstream=None, read_only=False): + from . import serv + + s = serv.PRServer(dbpath, upstream=upstream, read_only=read_only) + host, port = addr.split(":") + s.start_tcp_server(host, int(port)) + + return s + +def increase_revision(ver): + """Take a revision string such as "1" or "1.2.3" or even a number and increase its last number + This fails if the last number is not an integer""" + + fields=str(ver).split('.') + last = fields[-1] + + try: + val = int(last) + except Exception as e: + logger.critical("Unable to increase revision value %s: %s" % (ver, e)) + raise e + + return ".".join(fields[0:-1] + list(str(val + 1))) + +def _revision_greater_or_equal(rev1, rev2): + """Compares x.y.z revision numbers, using integer comparison + Returns True if rev1 is greater or equal to rev2""" + + fields1 = rev1.split(".") + fields2 = rev2.split(".") + l1 = len(fields1) + l2 = len(fields2) + + for i in range(l1): + val1 = int(fields1[i]) + if i < l2: + val2 = int(fields2[i]) + if val2 < val1: + return True + elif val2 > val1: + return False + else: + return True + return True + +def revision_smaller(rev1, rev2): + """Compares x.y.z revision numbers, using integer comparison + Returns True if rev1 is strictly smaller than rev2""" + return not(_revision_greater_or_equal(rev1, rev2)) + +def revision_greater(rev1, rev2): + """Compares x.y.z revision numbers, using integer comparison + Returns True if rev1 is strictly greater than rev2""" + return _revision_greater_or_equal(rev1, rev2) and (rev1 != rev2) + +def create_client(addr): + from . import client + + c = client.PRClient() + + try: + (typ, a) = parse_address(addr) + c.connect_tcp(*a) + return c + except Exception as e: + c.close() + raise e + +async def create_async_client(addr): + from . import client + + c = client.PRAsyncClient() -def init_logger(logfile, loglevel): - numeric_level = getattr(logging, loglevel.upper(), None) - if not isinstance(numeric_level, int): - raise ValueError('Invalid log level: %s' % loglevel) - FORMAT = '%(asctime)-15s %(message)s' - logging.basicConfig(level=numeric_level, filename=logfile, format=FORMAT) + try: + (typ, a) = parse_address(addr) + await c.connect_tcp(*a) + return c -class NotFoundError(Exception): - pass + except Exception as e: + await c.close() + raise e diff --git a/lib/prserv/client.py b/lib/prserv/client.py new file mode 100644 index 000000000..9f5794c43 --- /dev/null +++ b/lib/prserv/client.py @@ -0,0 +1,72 @@ +# +# Copyright BitBake Contributors +# +# SPDX-License-Identifier: GPL-2.0-only +# + +import logging +import bb.asyncrpc +from . import create_async_client + +logger = logging.getLogger("BitBake.PRserv") + +class PRAsyncClient(bb.asyncrpc.AsyncClient): + def __init__(self): + super().__init__("PRSERVICE", "1.0", logger) + + async def getPR(self, version, pkgarch, checksum, history=False): + response = await self.invoke( + {"get-pr": {"version": version, "pkgarch": pkgarch, "checksum": checksum, "history": history}} + ) + if response: + return response["value"] + + async def test_pr(self, version, pkgarch, checksum, history=False): + response = await self.invoke( + {"test-pr": {"version": version, "pkgarch": pkgarch, "checksum": checksum, "history": history}} + ) + if response: + return response["value"] + + async def test_package(self, version, pkgarch): + response = await self.invoke( + {"test-package": {"version": version, "pkgarch": pkgarch}} + ) + if response: + return response["value"] + + async def max_package_pr(self, version, pkgarch): + response = await self.invoke( + {"max-package-pr": {"version": version, "pkgarch": pkgarch}} + ) + if response: + return response["value"] + + async def importone(self, version, pkgarch, checksum, value): + response = await self.invoke( + {"import-one": {"version": version, "pkgarch": pkgarch, "checksum": checksum, "value": value}} + ) + if response: + return response["value"] + + async def export(self, version, pkgarch, checksum, colinfo, history=False): + response = await self.invoke( + {"export": {"version": version, "pkgarch": pkgarch, "checksum": checksum, "colinfo": colinfo, "history": history}} + ) + if response: + return (response["metainfo"], response["datainfo"]) + + async def is_readonly(self): + response = await self.invoke( + {"is-readonly": {}} + ) + if response: + return response["readonly"] + +class PRClient(bb.asyncrpc.Client): + def __init__(self): + super().__init__() + self._add_methods("getPR", "test_pr", "test_package", "max_package_pr", "importone", "export", "is_readonly") + + def _get_async_client(self): + return PRAsyncClient() diff --git a/lib/prserv/db.py b/lib/prserv/db.py index cb2a2461e..2da493ddf 100644 --- a/lib/prserv/db.py +++ b/lib/prserv/db.py @@ -1,4 +1,6 @@ # +# Copyright BitBake Contributors +# # SPDX-License-Identifier: GPL-2.0-only # @@ -6,19 +8,13 @@ import logging import os.path import errno import prserv -import time +import sqlite3 -try: - import sqlite3 -except ImportError: - from pysqlite2 import dbapi2 as sqlite3 +from contextlib import closing +from . import increase_revision, revision_greater, revision_smaller logger = logging.getLogger("BitBake.PRserv") -sqlversion = sqlite3.sqlite_version_info -if sqlversion[0] < 3 or (sqlversion[0] == 3 and sqlversion[1] < 3): - raise Exception("sqlite3 version 3.3.0 or later is required.") - # # "No History" mode - for a given query tuple (version, pkgarch, checksum), # the returned value will be the largest among all the values of the same @@ -27,212 +23,232 @@ if sqlversion[0] < 3 or (sqlversion[0] == 3 and sqlversion[1] < 3): # "History" mode - Return a new higher value for previously unseen query # tuple (version, pkgarch, checksum), otherwise return historical value. # Value can decrement if returning to a previous build. -# class PRTable(object): - def __init__(self, conn, table, nohist): + def __init__(self, conn, table, read_only): self.conn = conn - self.nohist = nohist - self.dirty = False - if nohist: - self.table = "%s_nohist" % table - else: - self.table = "%s_hist" % table - - self._execute("CREATE TABLE IF NOT EXISTS %s \ - (version TEXT NOT NULL, \ - pkgarch TEXT NOT NULL, \ - checksum TEXT NOT NULL, \ - value INTEGER, \ - PRIMARY KEY (version, pkgarch, checksum));" % self.table) - - def _execute(self, *query): - """Execute a query, waiting to acquire a lock if necessary""" - start = time.time() - end = start + 20 - while True: - try: - return self.conn.execute(*query) - except sqlite3.OperationalError as exc: - if 'is locked' in str(exc) and end > time.time(): - continue - raise exc - - def sync(self): - self.conn.commit() - self._execute("BEGIN EXCLUSIVE TRANSACTION") - - def sync_if_dirty(self): - if self.dirty: - self.sync() - self.dirty = False - - def _getValueHist(self, version, pkgarch, checksum): - data=self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=?;" % self.table, - (version, pkgarch, checksum)) - row=data.fetchone() - if row is not None: - return row[0] - else: - #no value found, try to insert - try: - self._execute("INSERT INTO %s VALUES (?, ?, ?, (select ifnull(max(value)+1,0) from %s where version=? AND pkgarch=?));" - % (self.table,self.table), - (version,pkgarch, checksum,version, pkgarch)) - except sqlite3.IntegrityError as exc: - logger.error(str(exc)) - - self.dirty = True - - data=self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=?;" % self.table, - (version, pkgarch, checksum)) + self.read_only = read_only + self.table = table + + # Creating the table even if the server is read-only. + # This avoids a race condition if a shared database + # is accessed by a read-only server first. + + with closing(self.conn.cursor()) as cursor: + cursor.execute("CREATE TABLE IF NOT EXISTS %s \ + (version TEXT NOT NULL, \ + pkgarch TEXT NOT NULL, \ + checksum TEXT NOT NULL, \ + value TEXT, \ + PRIMARY KEY (version, pkgarch, checksum, value));" % self.table) + self.conn.commit() + + def _extremum_value(self, rows, is_max): + value = None + + for row in rows: + current_value = row[0] + if value is None: + value = current_value + else: + if is_max: + is_new_extremum = revision_greater(current_value, value) + else: + is_new_extremum = revision_smaller(current_value, value) + if is_new_extremum: + value = current_value + return value + + def _max_value(self, rows): + return self._extremum_value(rows, True) + + def _min_value(self, rows): + return self._extremum_value(rows, False) + + def test_package(self, version, pkgarch): + """Returns whether the specified package version is found in the database for the specified architecture""" + + # Just returns the value if found or None otherwise + with closing(self.conn.cursor()) as cursor: + data=cursor.execute("SELECT value FROM %s WHERE version=? AND pkgarch=?;" % self.table, + (version, pkgarch)) row=data.fetchone() if row is not None: - return row[0] + return True else: - raise prserv.NotFoundError - - def _getValueNohist(self, version, pkgarch, checksum): - data=self._execute("SELECT value FROM %s \ - WHERE version=? AND pkgarch=? AND checksum=? AND \ - value >= (select max(value) from %s where version=? AND pkgarch=?);" - % (self.table, self.table), - (version, pkgarch, checksum, version, pkgarch)) - row=data.fetchone() - if row is not None: - return row[0] - else: - #no value found, try to insert - try: - self._execute("INSERT OR REPLACE INTO %s VALUES (?, ?, ?, (select ifnull(max(value)+1,0) from %s where version=? AND pkgarch=?));" - % (self.table,self.table), - (version, pkgarch, checksum, version, pkgarch)) - except sqlite3.IntegrityError as exc: - logger.error(str(exc)) - self.conn.rollback() - - self.dirty = True - - data=self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=?;" % self.table, - (version, pkgarch, checksum)) + return False + + def test_checksum_value(self, version, pkgarch, checksum, value): + """Returns whether the specified value is found in the database for the specified package, architecture and checksum""" + + with closing(self.conn.cursor()) as cursor: + data=cursor.execute("SELECT value FROM %s WHERE version=? AND pkgarch=? and checksum=? and value=?;" % self.table, + (version, pkgarch, checksum, value)) row=data.fetchone() if row is not None: - return row[0] + return True else: - raise prserv.NotFoundError + return False - def getValue(self, version, pkgarch, checksum): - if self.nohist: - return self._getValueNohist(version, pkgarch, checksum) - else: - return self._getValueHist(version, pkgarch, checksum) - - def _importHist(self, version, pkgarch, checksum, value): - val = None - data = self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=?;" % self.table, - (version, pkgarch, checksum)) - row = data.fetchone() - if row is not None: - val=row[0] + def test_value(self, version, pkgarch, value): + """Returns whether the specified value is found in the database for the specified package and architecture""" + + # Just returns the value if found or None otherwise + with closing(self.conn.cursor()) as cursor: + data=cursor.execute("SELECT value FROM %s WHERE version=? AND pkgarch=? and value=?;" % self.table, + (version, pkgarch, value)) + row=data.fetchone() + if row is not None: + return True + else: + return False + + + def find_package_max_value(self, version, pkgarch): + """Returns the greatest value for (version, pkgarch), or None if not found. Doesn't create a new value""" + + with closing(self.conn.cursor()) as cursor: + data = cursor.execute("SELECT value FROM %s where version=? AND pkgarch=?;" % (self.table), + (version, pkgarch)) + rows = data.fetchall() + value = self._max_value(rows) + return value + + def find_value(self, version, pkgarch, checksum, history=False): + """Returns the value for the specified checksum if found or None otherwise.""" + + if history: + return self.find_min_value(version, pkgarch, checksum) else: - #no value found, try to insert - try: - self._execute("INSERT INTO %s VALUES (?, ?, ?, ?);" % (self.table), + return self.find_max_value(version, pkgarch, checksum) + + + def _find_extremum_value(self, version, pkgarch, checksum, is_max): + """Returns the maximum (if is_max is True) or minimum (if is_max is False) value + for (version, pkgarch, checksum), or None if not found. Doesn't create a new value""" + + with closing(self.conn.cursor()) as cursor: + data = cursor.execute("SELECT value FROM %s where version=? AND pkgarch=? AND checksum=?;" % (self.table), + (version, pkgarch, checksum)) + rows = data.fetchall() + return self._extremum_value(rows, is_max) + + def find_max_value(self, version, pkgarch, checksum): + return self._find_extremum_value(version, pkgarch, checksum, True) + + def find_min_value(self, version, pkgarch, checksum): + return self._find_extremum_value(version, pkgarch, checksum, False) + + def find_new_subvalue(self, version, pkgarch, base): + """Take and increase the greatest "<base>.y" value for (version, pkgarch), or return "<base>.0" if not found. + This doesn't store a new value.""" + + with closing(self.conn.cursor()) as cursor: + data = cursor.execute("SELECT value FROM %s where version=? AND pkgarch=? AND value LIKE '%s.%%';" % (self.table, base), + (version, pkgarch)) + rows = data.fetchall() + value = self._max_value(rows) + + if value is not None: + return increase_revision(value) + else: + return base + ".0" + + def store_value(self, version, pkgarch, checksum, value): + """Store value in the database""" + + if not self.read_only and not self.test_checksum_value(version, pkgarch, checksum, value): + with closing(self.conn.cursor()) as cursor: + cursor.execute("INSERT INTO %s VALUES (?, ?, ?, ?);" % (self.table), (version, pkgarch, checksum, value)) - except sqlite3.IntegrityError as exc: - logger.error(str(exc)) + self.conn.commit() - self.dirty = True + def _get_value(self, version, pkgarch, checksum, history): - data = self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=?;" % self.table, - (version, pkgarch, checksum)) - row = data.fetchone() - if row is not None: - val = row[0] - return val + max_value = self.find_package_max_value(version, pkgarch) - def _importNohist(self, version, pkgarch, checksum, value): - try: - #try to insert - self._execute("INSERT INTO %s VALUES (?, ?, ?, ?);" % (self.table), - (version, pkgarch, checksum,value)) - except sqlite3.IntegrityError as exc: - #already have the record, try to update - try: - self._execute("UPDATE %s SET value=? WHERE version=? AND pkgarch=? AND checksum=? AND value<?" - % (self.table), - (value,version,pkgarch,checksum,value)) - except sqlite3.IntegrityError as exc: - logger.error(str(exc)) - - self.dirty = True - - data = self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=? AND value>=?;" % self.table, - (version,pkgarch,checksum,value)) - row=data.fetchone() - if row is not None: - return row[0] + if max_value is None: + # version, pkgarch completely unknown. Return initial value. + return "0" + + value = self.find_value(version, pkgarch, checksum, history) + + if value is None: + # version, pkgarch found but not checksum. Create a new value from the maximum one + return increase_revision(max_value) + + if history: + return value + + # "no history" mode - If the value is not the maximum value for the package, need to increase it. + if max_value > value: + return increase_revision(max_value) else: - return None + return value + + def get_value(self, version, pkgarch, checksum, history): + value = self._get_value(version, pkgarch, checksum, history) + if not self.read_only: + self.store_value(version, pkgarch, checksum, value) + return value def importone(self, version, pkgarch, checksum, value): - if self.nohist: - return self._importNohist(version, pkgarch, checksum, value) - else: - return self._importHist(version, pkgarch, checksum, value) + self.store_value(version, pkgarch, checksum, value) + return value - def export(self, version, pkgarch, checksum, colinfo): + def export(self, version, pkgarch, checksum, colinfo, history=False): metainfo = {} - #column info - if colinfo: - metainfo['tbl_name'] = self.table - metainfo['core_ver'] = prserv.__version__ - metainfo['col_info'] = [] - data = self._execute("PRAGMA table_info(%s);" % self.table) + with closing(self.conn.cursor()) as cursor: + #column info + if colinfo: + metainfo["tbl_name"] = self.table + metainfo["core_ver"] = prserv.__version__ + metainfo["col_info"] = [] + data = cursor.execute("PRAGMA table_info(%s);" % self.table) + for row in data: + col = {} + col["name"] = row["name"] + col["type"] = row["type"] + col["notnull"] = row["notnull"] + col["dflt_value"] = row["dflt_value"] + col["pk"] = row["pk"] + metainfo["col_info"].append(col) + + #data info + datainfo = [] + + if history: + sqlstmt = "SELECT * FROM %s as T1 WHERE 1=1 " % self.table + else: + sqlstmt = "SELECT T1.version, T1.pkgarch, T1.checksum, T1.value FROM %s as T1, \ + (SELECT version, pkgarch, max(value) as maxvalue FROM %s GROUP BY version, pkgarch) as T2 \ + WHERE T1.version=T2.version AND T1.pkgarch=T2.pkgarch AND T1.value=T2.maxvalue " % (self.table, self.table) + sqlarg = [] + where = "" + if version: + where += "AND T1.version=? " + sqlarg.append(str(version)) + if pkgarch: + where += "AND T1.pkgarch=? " + sqlarg.append(str(pkgarch)) + if checksum: + where += "AND T1.checksum=? " + sqlarg.append(str(checksum)) + + sqlstmt += where + ";" + + if len(sqlarg): + data = cursor.execute(sqlstmt, tuple(sqlarg)) + else: + data = cursor.execute(sqlstmt) for row in data: - col = {} - col['name'] = row['name'] - col['type'] = row['type'] - col['notnull'] = row['notnull'] - col['dflt_value'] = row['dflt_value'] - col['pk'] = row['pk'] - metainfo['col_info'].append(col) - - #data info - datainfo = [] - - if self.nohist: - sqlstmt = "SELECT T1.version, T1.pkgarch, T1.checksum, T1.value FROM %s as T1, \ - (SELECT version,pkgarch,max(value) as maxvalue FROM %s GROUP BY version,pkgarch) as T2 \ - WHERE T1.version=T2.version AND T1.pkgarch=T2.pkgarch AND T1.value=T2.maxvalue " % (self.table, self.table) - else: - sqlstmt = "SELECT * FROM %s as T1 WHERE 1=1 " % self.table - sqlarg = [] - where = "" - if version: - where += "AND T1.version=? " - sqlarg.append(str(version)) - if pkgarch: - where += "AND T1.pkgarch=? " - sqlarg.append(str(pkgarch)) - if checksum: - where += "AND T1.checksum=? " - sqlarg.append(str(checksum)) - - sqlstmt += where + ";" - - if len(sqlarg): - data = self._execute(sqlstmt, tuple(sqlarg)) - else: - data = self._execute(sqlstmt) - for row in data: - if row['version']: - col = {} - col['version'] = row['version'] - col['pkgarch'] = row['pkgarch'] - col['checksum'] = row['checksum'] - col['value'] = row['value'] - datainfo.append(col) + if row["version"]: + col = {} + col["version"] = row["version"] + col["pkgarch"] = row["pkgarch"] + col["checksum"] = row["checksum"] + col["value"] = row["value"] + datainfo.append(col) return (metainfo, datainfo) def dump_db(self, fd): @@ -240,41 +256,46 @@ class PRTable(object): for line in self.conn.iterdump(): writeCount = writeCount + len(line) + 1 fd.write(line) - fd.write('\n') + fd.write("\n") return writeCount class PRData(object): """Object representing the PR database""" - def __init__(self, filename, nohist=True): + def __init__(self, filename, read_only=False): self.filename=os.path.abspath(filename) - self.nohist=nohist + self.read_only = read_only #build directory hierarchy try: os.makedirs(os.path.dirname(self.filename)) except OSError as e: if e.errno != errno.EEXIST: raise e - self.connection=sqlite3.connect(self.filename, isolation_level="EXCLUSIVE", check_same_thread = False) + uri = "file:%s%s" % (self.filename, "?mode=ro" if self.read_only else "") + logger.debug("Opening PRServ database '%s'" % (uri)) + self.connection=sqlite3.connect(uri, uri=True) self.connection.row_factory=sqlite3.Row - self.connection.execute("pragma synchronous = off;") - self.connection.execute("PRAGMA journal_mode = MEMORY;") + self.connection.execute("PRAGMA synchronous = OFF;") + self.connection.execute("PRAGMA journal_mode = WAL;") + self.connection.commit() self._tables={} def disconnect(self): + self.connection.commit() self.connection.close() - def __getitem__(self,tblname): + def __getitem__(self, tblname): if not isinstance(tblname, str): raise TypeError("tblname argument must be a string, not '%s'" % type(tblname)) if tblname in self._tables: return self._tables[tblname] else: - tableobj = self._tables[tblname] = PRTable(self.connection, tblname, self.nohist) + tableobj = self._tables[tblname] = PRTable(self.connection, tblname, self.read_only) return tableobj def __delitem__(self, tblname): if tblname in self._tables: del self._tables[tblname] logger.info("drop table %s" % (tblname)) - self.connection.execute("DROP TABLE IF EXISTS %s;" % tblname) + self.connection.execute("DROP TABLE IF EXISTS %s;" % tblname) + self.connection.commit() diff --git a/lib/prserv/serv.py b/lib/prserv/serv.py index 25dcf8a0e..e17588630 100644 --- a/lib/prserv/serv.py +++ b/lib/prserv/serv.py @@ -1,354 +1,326 @@ # +# Copyright BitBake Contributors +# # SPDX-License-Identifier: GPL-2.0-only # import os,sys,logging import signal, time -from xmlrpc.server import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler -import threading -import queue import socket import io import sqlite3 -import bb.server.xmlrpcclient import prserv import prserv.db import errno -import select +from . import create_async_client, revision_smaller, increase_revision +import bb.asyncrpc logger = logging.getLogger("BitBake.PRserv") -if sys.hexversion < 0x020600F0: - print("Sorry, python 2.6 or later is required.") - sys.exit(1) +PIDPREFIX = "/tmp/PRServer_%s_%s.pid" +singleton = None -class Handler(SimpleXMLRPCRequestHandler): - def _dispatch(self,method,params): +class PRServerClient(bb.asyncrpc.AsyncServerConnection): + def __init__(self, socket, server): + super().__init__(socket, "PRSERVICE", server.logger) + self.server = server + + self.handlers.update({ + "get-pr": self.handle_get_pr, + "test-pr": self.handle_test_pr, + "test-package": self.handle_test_package, + "max-package-pr": self.handle_max_package_pr, + "import-one": self.handle_import_one, + "export": self.handle_export, + "is-readonly": self.handle_is_readonly, + }) + + def validate_proto_version(self): + return (self.proto_version == (1, 0)) + + async def dispatch_message(self, msg): try: - value=self.server.funcs[method](*params) + return await super().dispatch_message(msg) except: - import traceback - traceback.print_exc() raise - return value -PIDPREFIX = "/tmp/PRServer_%s_%s.pid" -singleton = None + async def handle_test_pr(self, request): + '''Finds the PR value corresponding to the request. If not found, returns None and doesn't insert a new value''' + version = request["version"] + pkgarch = request["pkgarch"] + checksum = request["checksum"] + history = request["history"] + value = self.server.table.find_value(version, pkgarch, checksum, history) + return {"value": value} -class PRServer(SimpleXMLRPCServer): - def __init__(self, dbfile, logfile, interface, daemon=True): - ''' constructor ''' - try: - SimpleXMLRPCServer.__init__(self, interface, - logRequests=False, allow_none=True) - except socket.error: - ip=socket.gethostbyname(interface[0]) - port=interface[1] - msg="PR Server unable to bind to %s:%s\n" % (ip, port) - sys.stderr.write(msg) - raise PRServiceConfigError + async def handle_test_package(self, request): + '''Tells whether there are entries for (version, pkgarch) in the db. Returns True or False''' + version = request["version"] + pkgarch = request["pkgarch"] - self.dbfile=dbfile - self.daemon=daemon - self.logfile=logfile - self.working_thread=None - self.host, self.port = self.socket.getsockname() - self.pidfile=PIDPREFIX % (self.host, self.port) - - self.register_function(self.getPR, "getPR") - self.register_function(self.quit, "quit") - self.register_function(self.ping, "ping") - self.register_function(self.export, "export") - self.register_function(self.dump_db, "dump_db") - self.register_function(self.importone, "importone") - self.register_introspection_functions() - - self.quitpipein, self.quitpipeout = os.pipe() - - self.requestqueue = queue.Queue() - self.handlerthread = threading.Thread(target = self.process_request_thread) - self.handlerthread.daemon = False - - def process_request_thread(self): - """Same as in BaseServer but as a thread. - - In addition, exception handling is done here. - - """ - iter_count = 1 - # 60 iterations between syncs or sync if dirty every ~30 seconds - iterations_between_sync = 60 - - bb.utils.set_process_name("PRServ Handler") - - while not self.quitflag: - try: - (request, client_address) = self.requestqueue.get(True, 30) - except queue.Empty: - self.table.sync_if_dirty() - continue - if request is None: - continue - try: - self.finish_request(request, client_address) - self.shutdown_request(request) - iter_count = (iter_count + 1) % iterations_between_sync - if iter_count == 0: - self.table.sync_if_dirty() - except: - self.handle_error(request, client_address) - self.shutdown_request(request) - self.table.sync() - self.table.sync_if_dirty() - - def sigint_handler(self, signum, stack): - if self.table: - self.table.sync() - - def sigterm_handler(self, signum, stack): - if self.table: - self.table.sync() - self.quit() - self.requestqueue.put((None, None)) - - def process_request(self, request, client_address): - self.requestqueue.put((request, client_address)) - - def export(self, version=None, pkgarch=None, checksum=None, colinfo=True): - try: - return self.table.export(version, pkgarch, checksum, colinfo) - except sqlite3.Error as exc: - logger.error(str(exc)) - return None - - def dump_db(self): - """ - Returns a script (string) that reconstructs the state of the - entire database at the time this function is called. The script - language is defined by the backing database engine, which is a - function of server configuration. - Returns None if the database engine does not support dumping to - script or if some other error is encountered in processing. - """ - buff = io.StringIO() - try: - self.table.sync() - self.table.dump_db(buff) - return buff.getvalue() - except Exception as exc: - logger.error(str(exc)) - return None - finally: - buff.close() + value = self.server.table.test_package(version, pkgarch) + return {"value": value} - def importone(self, version, pkgarch, checksum, value): - return self.table.importone(version, pkgarch, checksum, value) + async def handle_max_package_pr(self, request): + '''Finds the greatest PR value for (version, pkgarch) in the db. Returns None if no entry was found''' + version = request["version"] + pkgarch = request["pkgarch"] - def ping(self): - return not self.quitflag + value = self.server.table.find_package_max_value(version, pkgarch) + return {"value": value} - def getinfo(self): - return (self.host, self.port) + async def handle_get_pr(self, request): + version = request["version"] + pkgarch = request["pkgarch"] + checksum = request["checksum"] + history = request["history"] - def getPR(self, version, pkgarch, checksum): - try: - return self.table.getValue(version, pkgarch, checksum) - except prserv.NotFoundError: - logger.error("can not find value for (%s, %s)",version, checksum) - return None - except sqlite3.Error as exc: - logger.error(str(exc)) - return None - - def quit(self): - self.quitflag=True - os.write(self.quitpipeout, b"q") - os.close(self.quitpipeout) - return - - def work_forever(self,): - self.quitflag = False - # This timeout applies to the poll in TCPServer, we need the select - # below to wake on our quit pipe closing. We only ever call into handle_request - # if there is data there. - self.timeout = 0.01 - - bb.utils.set_process_name("PRServ") - - # DB connection must be created after all forks - self.db = prserv.db.PRData(self.dbfile) - self.table = self.db["PRMAIN"] + if self.upstream_client is None: + value = self.server.table.get_value(version, pkgarch, checksum, history) + return {"value": value} - logger.info("Started PRServer with DBfile: %s, IP: %s, PORT: %s, PID: %s" % - (self.dbfile, self.host, self.port, str(os.getpid()))) - - self.handlerthread.start() - while not self.quitflag: - ready = select.select([self.fileno(), self.quitpipein], [], [], 30) - if self.quitflag: - break - if self.fileno() in ready[0]: - self.handle_request() - self.handlerthread.join() - self.db.disconnect() - logger.info("PRServer: stopping...") - self.server_close() - os.close(self.quitpipein) - return + # We have an upstream server. + # Check whether the local server already knows the requested configuration. + # If the configuration is a new one, the generated value we will add will + # depend on what's on the upstream server. That's why we're calling find_value() + # instead of get_value() directly. - def start(self): - if self.daemon: - pid = self.daemonize() - else: - pid = self.fork() - self.pid = pid + value = self.server.table.find_value(version, pkgarch, checksum, history) + upstream_max = await self.upstream_client.max_package_pr(version, pkgarch) - # Ensure both the parent sees this and the child from the work_forever log entry above - logger.info("Started PRServer with DBfile: %s, IP: %s, PORT: %s, PID: %s" % - (self.dbfile, self.host, self.port, str(pid))) + if value is not None: - def delpid(self): - os.remove(self.pidfile) + # The configuration is already known locally. - def daemonize(self): - """ - See Advanced Programming in the UNIX, Sec 13.3 - """ - try: - pid = os.fork() - if pid > 0: - os.waitpid(pid, 0) - #parent return instead of exit to give control - return pid - except OSError as e: - raise Exception("%s [%d]" % (e.strerror, e.errno)) - - os.setsid() - """ - fork again to make sure the daemon is not session leader, - which prevents it from acquiring controlling terminal - """ - try: - pid = os.fork() - if pid > 0: #parent - os._exit(0) - except OSError as e: - raise Exception("%s [%d]" % (e.strerror, e.errno)) + if history: + value = self.server.table.get_value(version, pkgarch, checksum, history) + else: + existing_value = value + # In "no history", we need to make sure the value doesn't decrease + # and is at least greater than the maximum upstream value + # and the maximum local value - self.cleanup_handles() - os._exit(0) + local_max = self.server.table.find_package_max_value(version, pkgarch) + if revision_smaller(value, local_max): + value = increase_revision(local_max) + + if revision_smaller(value, upstream_max): + # Ask upstream whether it knows the checksum + upstream_value = await self.upstream_client.test_pr(version, pkgarch, checksum) + if upstream_value is None: + # Upstream doesn't have our checksum, let create a new one + value = upstream_max + ".0" + else: + # Fine to take the same value as upstream + value = upstream_max + + if not value == existing_value and not self.server.read_only: + self.server.table.store_value(version, pkgarch, checksum, value) + + return {"value": value} + + # The configuration is a new one for the local server + # Let's ask the upstream server whether it knows it + + known_upstream = await self.upstream_client.test_package(version, pkgarch) + + if not known_upstream: + + # The package is not known upstream, must be a local-only package + # Let's compute the PR number using the local-only method + + value = self.server.table.get_value(version, pkgarch, checksum, history) + return {"value": value} + + # The package is known upstream, let's ask the upstream server + # whether it knows our new output hash + + value = await self.upstream_client.test_pr(version, pkgarch, checksum) + + if value is not None: + + # Upstream knows this output hash, let's store it and use it too. + + if not self.server.read_only: + self.server.table.store_value(version, pkgarch, checksum, value) + # If the local server is read only, won't be able to store the new + # value in the database and will have to keep asking the upstream server + return {"value": value} + + # The output hash doesn't exist upstream, get the most recent number from upstream (x) + # Then, we want to have a new PR value for the local server: x.y + + upstream_max = await self.upstream_client.max_package_pr(version, pkgarch) + # Here we know that the package is known upstream, so upstream_max can't be None + subvalue = self.server.table.find_new_subvalue(version, pkgarch, upstream_max) + + if not self.server.read_only: + self.server.table.store_value(version, pkgarch, checksum, subvalue) + + return {"value": subvalue} + + async def process_requests(self): + if self.server.upstream is not None: + self.upstream_client = await create_async_client(self.server.upstream) + else: + self.upstream_client = None - def fork(self): - try: - pid = os.fork() - if pid > 0: - self.socket.close() # avoid ResourceWarning in parent - return pid - except OSError as e: - raise Exception("%s [%d]" % (e.strerror, e.errno)) - - bb.utils.signal_on_parent_exit("SIGTERM") - self.cleanup_handles() - os._exit(0) - - def cleanup_handles(self): - signal.signal(signal.SIGINT, self.sigint_handler) - signal.signal(signal.SIGTERM, self.sigterm_handler) - os.chdir("/") - - sys.stdout.flush() - sys.stderr.flush() - - # We could be called from a python thread with io.StringIO as - # stdout/stderr or it could be 'real' unix fd forking where we need - # to physically close the fds to prevent the program launching us from - # potentially hanging on a pipe. Handle both cases. - si = open('/dev/null', 'r') - try: - os.dup2(si.fileno(),sys.stdin.fileno()) - except (AttributeError, io.UnsupportedOperation): - sys.stdin = si - so = open(self.logfile, 'a+') try: - os.dup2(so.fileno(),sys.stdout.fileno()) - except (AttributeError, io.UnsupportedOperation): - sys.stdout = so + await super().process_requests() + finally: + if self.upstream_client is not None: + await self.upstream_client.close() + + async def handle_import_one(self, request): + response = None + if not self.server.read_only: + version = request["version"] + pkgarch = request["pkgarch"] + checksum = request["checksum"] + value = request["value"] + + value = self.server.table.importone(version, pkgarch, checksum, value) + if value is not None: + response = {"value": value} + + return response + + async def handle_export(self, request): + version = request["version"] + pkgarch = request["pkgarch"] + checksum = request["checksum"] + colinfo = request["colinfo"] + history = request["history"] + try: - os.dup2(so.fileno(),sys.stderr.fileno()) - except (AttributeError, io.UnsupportedOperation): - sys.stderr = so - - # Clear out all log handlers prior to the fork() to avoid calling - # event handlers not part of the PRserver - for logger_iter in logging.Logger.manager.loggerDict.keys(): - logging.getLogger(logger_iter).handlers = [] - - # Ensure logging makes it to the logfile - streamhandler = logging.StreamHandler() - streamhandler.setLevel(logging.DEBUG) - formatter = bb.msg.BBLogFormatter("%(levelname)s: %(message)s") - streamhandler.setFormatter(formatter) - logger.addHandler(streamhandler) - - # write pidfile - pid = str(os.getpid()) - with open(self.pidfile, 'w') as pf: - pf.write("%s\n" % pid) - - self.work_forever() - self.delpid() + (metainfo, datainfo) = self.server.table.export(version, pkgarch, checksum, colinfo, history) + except sqlite3.Error as exc: + self.logger.error(str(exc)) + metainfo = datainfo = None -class PRServSingleton(object): - def __init__(self, dbfile, logfile, interface): + return {"metainfo": metainfo, "datainfo": datainfo} + + async def handle_is_readonly(self, request): + return {"readonly": self.server.read_only} + +class PRServer(bb.asyncrpc.AsyncServer): + def __init__(self, dbfile, read_only=False, upstream=None): + super().__init__(logger) self.dbfile = dbfile - self.logfile = logfile - self.interface = interface - self.host = None - self.port = None + self.table = None + self.read_only = read_only + self.upstream = upstream + + def accept_client(self, socket): + return PRServerClient(socket, self) def start(self): - self.prserv = PRServer(self.dbfile, self.logfile, self.interface, daemon=False) - self.prserv.start() - self.host, self.port = self.prserv.getinfo() + tasks = super().start() + self.db = prserv.db.PRData(self.dbfile, read_only=self.read_only) + self.table = self.db["PRMAIN"] + + self.logger.info("Started PRServer with DBfile: %s, Address: %s, PID: %s" % + (self.dbfile, self.address, str(os.getpid()))) - def getinfo(self): - return (self.host, self.port) + if self.upstream is not None: + self.logger.info("And upstream PRServer: %s " % (self.upstream)) -class PRServerConnection(object): - def __init__(self, host, port): - if is_local_special(host, port): - host, port = singleton.getinfo() + return tasks + + async def stop(self): + self.db.disconnect() + await super().stop() + +class PRServSingleton(object): + def __init__(self, dbfile, logfile, host, port, upstream): + self.dbfile = dbfile + self.logfile = logfile self.host = host self.port = port - self.connection, self.transport = bb.server.xmlrpcclient._create_server(self.host, self.port) - - def terminate(self): - try: - logger.info("Terminating PRServer...") - self.connection.quit() - except Exception as exc: - sys.stderr.write("%s\n" % str(exc)) + self.upstream = upstream - def getPR(self, version, pkgarch, checksum): - return self.connection.getPR(version, pkgarch, checksum) + def start(self): + self.prserv = PRServer(self.dbfile, upstream=self.upstream) + self.prserv.start_tcp_server(socket.gethostbyname(self.host), self.port) + self.process = self.prserv.serve_as_process(log_level=logging.WARNING) - def ping(self): - return self.connection.ping() + if not self.prserv.address: + raise PRServiceConfigError + if not self.port: + self.port = int(self.prserv.address.rsplit(":", 1)[1]) - def export(self,version=None, pkgarch=None, checksum=None, colinfo=True): - return self.connection.export(version, pkgarch, checksum, colinfo) +def run_as_daemon(func, pidfile, logfile): + """ + See Advanced Programming in the UNIX, Sec 13.3 + """ + try: + pid = os.fork() + if pid > 0: + os.waitpid(pid, 0) + #parent return instead of exit to give control + return pid + except OSError as e: + raise Exception("%s [%d]" % (e.strerror, e.errno)) - def dump_db(self): - return self.connection.dump_db() + os.setsid() + """ + fork again to make sure the daemon is not session leader, + which prevents it from acquiring controlling terminal + """ + try: + pid = os.fork() + if pid > 0: #parent + os._exit(0) + except OSError as e: + raise Exception("%s [%d]" % (e.strerror, e.errno)) - def importone(self, version, pkgarch, checksum, value): - return self.connection.importone(version, pkgarch, checksum, value) + os.chdir("/") - def getinfo(self): - return self.host, self.port + sys.stdout.flush() + sys.stderr.flush() -def start_daemon(dbfile, host, port, logfile): + # We could be called from a python thread with io.StringIO as + # stdout/stderr or it could be 'real' unix fd forking where we need + # to physically close the fds to prevent the program launching us from + # potentially hanging on a pipe. Handle both cases. + si = open("/dev/null", "r") + try: + os.dup2(si.fileno(), sys.stdin.fileno()) + except (AttributeError, io.UnsupportedOperation): + sys.stdin = si + so = open(logfile, "a+") + try: + os.dup2(so.fileno(), sys.stdout.fileno()) + except (AttributeError, io.UnsupportedOperation): + sys.stdout = so + try: + os.dup2(so.fileno(), sys.stderr.fileno()) + except (AttributeError, io.UnsupportedOperation): + sys.stderr = so + + # Clear out all log handlers prior to the fork() to avoid calling + # event handlers not part of the PRserver + for logger_iter in logging.Logger.manager.loggerDict.keys(): + logging.getLogger(logger_iter).handlers = [] + + # Ensure logging makes it to the logfile + streamhandler = logging.StreamHandler() + streamhandler.setLevel(logging.DEBUG) + formatter = bb.msg.BBLogFormatter("%(levelname)s: %(message)s") + streamhandler.setFormatter(formatter) + logger.addHandler(streamhandler) + + # write pidfile + pid = str(os.getpid()) + with open(pidfile, "w") as pf: + pf.write("%s\n" % pid) + + func() + os.remove(pidfile) + os._exit(0) + +def start_daemon(dbfile, host, port, logfile, read_only=False, upstream=None): ip = socket.gethostbyname(host) pidfile = PIDPREFIX % (ip, port) try: @@ -362,15 +334,13 @@ def start_daemon(dbfile, host, port, logfile): % pidfile) return 1 - server = PRServer(os.path.abspath(dbfile), os.path.abspath(logfile), (ip,port)) - server.start() + dbfile = os.path.abspath(dbfile) + def daemon_main(): + server = PRServer(dbfile, read_only=read_only, upstream=upstream) + server.start_tcp_server(ip, port) + server.serve_forever() - # Sometimes, the port (i.e. localhost:0) indicated by the user does not match with - # the one the server actually is listening, so at least warn the user about it - _,rport = server.getinfo() - if port != rport: - sys.stdout.write("Server is listening at port %s instead of %s\n" - % (rport,port)) + run_as_daemon(daemon_main, pidfile, os.path.abspath(logfile)) return 0 def stop_daemon(host, port): @@ -388,37 +358,28 @@ def stop_daemon(host, port): # so at least advise the user which ports the corresponding server is listening ports = [] portstr = "" - for pf in glob.glob(PIDPREFIX % (ip,'*')): + for pf in glob.glob(PIDPREFIX % (ip, "*")): bn = os.path.basename(pf) root, _ = os.path.splitext(bn) - ports.append(root.split('_')[-1]) + ports.append(root.split("_")[-1]) if len(ports): - portstr = "Wrong port? Other ports listening at %s: %s" % (host, ' '.join(ports)) + portstr = "Wrong port? Other ports listening at %s: %s" % (host, " ".join(ports)) sys.stderr.write("pidfile %s does not exist. Daemon not running? %s\n" - % (pidfile,portstr)) + % (pidfile, portstr)) return 1 try: - PRServerConnection(ip, port).terminate() - except: - logger.critical("Stop PRService %s:%d failed" % (host,port)) - - try: - if pid: - wait_timeout = 0 - print("Waiting for pr-server to exit.") - while is_running(pid) and wait_timeout < 50: - time.sleep(0.1) - wait_timeout += 1 + if is_running(pid): + print("Sending SIGTERM to pr-server.") + os.kill(pid, signal.SIGTERM) + time.sleep(0.1) - if is_running(pid): - print("Sending SIGTERM to pr-server.") - os.kill(pid,signal.SIGTERM) - time.sleep(0.1) - - if os.path.exists(pidfile): - os.remove(pidfile) + try: + os.remove(pidfile) + except FileNotFoundError: + # The PID file might have been removed by the exiting process + pass except OSError as e: err = str(e) @@ -436,7 +397,7 @@ def is_running(pid): return True def is_local_special(host, port): - if host.strip().upper() == 'localhost'.upper() and (not port): + if (host == "localhost" or host == "127.0.0.1") and not port: return True else: return False @@ -447,7 +408,7 @@ class PRServiceConfigError(Exception): def auto_start(d): global singleton - host_params = list(filter(None, (d.getVar('PRSERV_HOST') or '').split(':'))) + host_params = list(filter(None, (d.getVar("PRSERV_HOST") or "").split(":"))) if not host_params: # Shutdown any existing PR Server auto_shutdown() @@ -456,11 +417,16 @@ def auto_start(d): if len(host_params) != 2: # Shutdown any existing PR Server auto_shutdown() - logger.critical('\n'.join(['PRSERV_HOST: incorrect format', + logger.critical("\n".join(["PRSERV_HOST: incorrect format", 'Usage: PRSERV_HOST = "<hostname>:<port>"'])) raise PRServiceConfigError - if is_local_special(host_params[0], int(host_params[1])): + host = host_params[0].strip().lower() + port = int(host_params[1]) + + upstream = d.getVar("PRSERV_UPSTREAM") or None + + if is_local_special(host, port): import bb.utils cachedir = (d.getVar("PERSISTENT_DIR") or d.getVar("CACHE")) if not cachedir: @@ -474,39 +440,43 @@ def auto_start(d): auto_shutdown() if not singleton: bb.utils.mkdirhier(cachedir) - singleton = PRServSingleton(os.path.abspath(dbfile), os.path.abspath(logfile), ("localhost",0)) + singleton = PRServSingleton(os.path.abspath(dbfile), os.path.abspath(logfile), host, port, upstream) singleton.start() if singleton: - host, port = singleton.getinfo() - else: - host = host_params[0] - port = int(host_params[1]) + host = singleton.host + port = singleton.port try: - connection = PRServerConnection(host,port) - connection.ping() - realhost, realport = connection.getinfo() - return str(realhost) + ":" + str(realport) - + ping(host, port) + return str(host) + ":" + str(port) + except Exception: logger.critical("PRservice %s:%d not available" % (host, port)) raise PRServiceConfigError def auto_shutdown(): global singleton - if singleton: - host, port = singleton.getinfo() - try: - PRServerConnection(host, port).terminate() - except: - logger.critical("Stop PRService %s:%d failed" % (host,port)) - - try: - os.waitpid(singleton.prserv.pid, 0) - except ChildProcessError: - pass + if singleton and singleton.process: + singleton.process.terminate() + singleton.process.join() singleton = None def ping(host, port): - conn=PRServerConnection(host, port) - return conn.ping() + from . import client + + with client.PRClient() as conn: + conn.connect_tcp(host, port) + return conn.ping() + +def connect(host, port): + from . import client + + global singleton + + if host.strip().lower() == "localhost" and not port: + host = "localhost" + port = singleton.port + + conn = client.PRClient() + conn.connect_tcp(host, port) + return conn diff --git a/lib/prserv/tests.py b/lib/prserv/tests.py new file mode 100644 index 000000000..8765b129f --- /dev/null +++ b/lib/prserv/tests.py @@ -0,0 +1,386 @@ +#! /usr/bin/env python3 +# +# Copyright (C) 2024 BitBake Contributors +# +# SPDX-License-Identifier: GPL-2.0-only +# + +from . import create_server, create_client, increase_revision, revision_greater, revision_smaller, _revision_greater_or_equal +import prserv.db as db +from bb.asyncrpc import InvokeError +import logging +import os +import sys +import tempfile +import unittest +import socket +import subprocess +from pathlib import Path + +THIS_DIR = Path(__file__).parent +BIN_DIR = THIS_DIR.parent.parent / "bin" + +version = "dummy-1.0-r0" +pkgarch = "core2-64" +other_arch = "aarch64" + +checksumX = "51bf8189dbe9ea81fa6dd89608bf19380c437a9cf12f6c6239887801ba4ab4f0" +checksum0 = "51bf8189dbe9ea81fa6dd89608bf19380c437a9cf12f6c6239887801ba4ab4a0" +checksum1 = "51bf8189dbe9ea81fa6dd89608bf19380c437a9cf12f6c6239887801ba4ab4a1" +checksum2 = "51bf8189dbe9ea81fa6dd89608bf19380c437a9cf12f6c6239887801ba4ab4a2" +checksum3 = "51bf8189dbe9ea81fa6dd89608bf19380c437a9cf12f6c6239887801ba4ab4a3" +checksum4 = "51bf8189dbe9ea81fa6dd89608bf19380c437a9cf12f6c6239887801ba4ab4a4" +checksum5 = "51bf8189dbe9ea81fa6dd89608bf19380c437a9cf12f6c6239887801ba4ab4a5" +checksum6 = "51bf8189dbe9ea81fa6dd89608bf19380c437a9cf12f6c6239887801ba4ab4a6" +checksum7 = "51bf8189dbe9ea81fa6dd89608bf19380c437a9cf12f6c6239887801ba4ab4a7" +checksum8 = "51bf8189dbe9ea81fa6dd89608bf19380c437a9cf12f6c6239887801ba4ab4a8" +checksum9 = "51bf8189dbe9ea81fa6dd89608bf19380c437a9cf12f6c6239887801ba4ab4a9" +checksum10 = "51bf8189dbe9ea81fa6dd89608bf19380c437a9cf12f6c6239887801ba4ab4aa" + +def server_prefunc(server, name): + logging.basicConfig(level=logging.DEBUG, filename='prserv-%s.log' % name, filemode='w', + format='%(levelname)s %(filename)s:%(lineno)d %(message)s') + server.logger.debug("Running server %s" % name) + sys.stdout = open('prserv-stdout-%s.log' % name, 'w') + sys.stderr = sys.stdout + +class PRTestSetup(object): + + def start_server(self, name, dbfile, upstream=None, read_only=False, prefunc=server_prefunc): + + def cleanup_server(server): + if server.process.exitcode is not None: + return + server.process.terminate() + server.process.join() + + server = create_server(socket.gethostbyname("localhost") + ":0", + dbfile, + upstream=upstream, + read_only=read_only) + + server.serve_as_process(prefunc=prefunc, args=(name,)) + self.addCleanup(cleanup_server, server) + + return server + + def start_client(self, server_address): + def cleanup_client(client): + client.close() + + client = create_client(server_address) + self.addCleanup(cleanup_client, client) + + return client + +class FunctionTests(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-prserv') + self.addCleanup(self.temp_dir.cleanup) + + def test_increase_revision(self): + self.assertEqual(increase_revision("1"), "2") + self.assertEqual(increase_revision("1.0"), "1.1") + self.assertEqual(increase_revision("1.1.1"), "1.1.2") + self.assertEqual(increase_revision("1.1.1.3"), "1.1.1.4") + self.assertRaises(ValueError, increase_revision, "1.a") + self.assertRaises(ValueError, increase_revision, "1.") + self.assertRaises(ValueError, increase_revision, "") + + def test_revision_greater_or_equal(self): + self.assertTrue(_revision_greater_or_equal("2", "2")) + self.assertTrue(_revision_greater_or_equal("2", "1")) + self.assertTrue(_revision_greater_or_equal("10", "2")) + self.assertTrue(_revision_greater_or_equal("1.10", "1.2")) + self.assertFalse(_revision_greater_or_equal("1.2", "1.10")) + self.assertTrue(_revision_greater_or_equal("1.10", "1")) + self.assertTrue(_revision_greater_or_equal("1.10.1", "1.10")) + self.assertFalse(_revision_greater_or_equal("1.10.1", "1.10.2")) + self.assertTrue(_revision_greater_or_equal("1.10.1", "1.10.1")) + self.assertTrue(_revision_greater_or_equal("1.10.1", "1")) + self.assertTrue(revision_greater("1.20", "1.3")) + self.assertTrue(revision_smaller("1.3", "1.20")) + + # DB tests + + def test_db(self): + dbfile = os.path.join(self.temp_dir.name, "testtable.sqlite3") + + self.db = db.PRData(dbfile) + self.table = self.db["PRMAIN"] + + self.table.store_value(version, pkgarch, checksum0, "0") + self.table.store_value(version, pkgarch, checksum1, "1") + # "No history" mode supports multiple PRs for the same checksum + self.table.store_value(version, pkgarch, checksum0, "2") + self.table.store_value(version, pkgarch, checksum2, "1.0") + + self.assertTrue(self.table.test_package(version, pkgarch)) + self.assertFalse(self.table.test_package(version, other_arch)) + + self.assertTrue(self.table.test_value(version, pkgarch, "0")) + self.assertTrue(self.table.test_value(version, pkgarch, "1")) + self.assertTrue(self.table.test_value(version, pkgarch, "2")) + + self.assertEqual(self.table.find_package_max_value(version, pkgarch), "2") + + self.assertEqual(self.table.find_min_value(version, pkgarch, checksum0), "0") + self.assertEqual(self.table.find_max_value(version, pkgarch, checksum0), "2") + + # Test history modes + self.assertEqual(self.table.find_value(version, pkgarch, checksum0, True), "0") + self.assertEqual(self.table.find_value(version, pkgarch, checksum0, False), "2") + + self.assertEqual(self.table.find_new_subvalue(version, pkgarch, "3"), "3.0") + self.assertEqual(self.table.find_new_subvalue(version, pkgarch, "1"), "1.1") + + # Revision comparison tests + self.table.store_value(version, pkgarch, checksum1, "1.3") + self.table.store_value(version, pkgarch, checksum1, "1.20") + self.assertEqual(self.table.find_min_value(version, pkgarch, checksum1), "1") + self.assertEqual(self.table.find_max_value(version, pkgarch, checksum1), "1.20") + +class PRBasicTests(PRTestSetup, unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-prserv') + self.addCleanup(self.temp_dir.cleanup) + + dbfile = os.path.join(self.temp_dir.name, "prtest-basic.sqlite3") + + self.server1 = self.start_server("basic", dbfile) + self.client1 = self.start_client(self.server1.address) + + def test_basic(self): + + # Checks on non existing configuration + + result = self.client1.test_pr(version, pkgarch, checksum0) + self.assertIsNone(result, "test_pr should return 'None' for a non existing PR") + + result = self.client1.test_package(version, pkgarch) + self.assertFalse(result, "test_package should return 'False' for a non existing PR") + + result = self.client1.max_package_pr(version, pkgarch) + self.assertIsNone(result, "max_package_pr should return 'None' for a non existing PR") + + # Add a first configuration + + result = self.client1.getPR(version, pkgarch, checksum0) + self.assertEqual(result, "0", "getPR: initial PR of a package should be '0'") + + result = self.client1.test_pr(version, pkgarch, checksum0) + self.assertEqual(result, "0", "test_pr should return '0' here, matching the result of getPR") + + result = self.client1.test_package(version, pkgarch) + self.assertTrue(result, "test_package should return 'True' for an existing PR") + + result = self.client1.max_package_pr(version, pkgarch) + self.assertEqual(result, "0", "max_package_pr should return '0' in the current test series") + + # Check that the same request gets the same value + + result = self.client1.getPR(version, pkgarch, checksum0) + self.assertEqual(result, "0", "getPR: asking for the same PR a second time in a row should return the same value.") + + # Add new configurations + + result = self.client1.getPR(version, pkgarch, checksum1) + self.assertEqual(result, "1", "getPR: second PR of a package should be '1'") + + result = self.client1.test_pr(version, pkgarch, checksum1) + self.assertEqual(result, "1", "test_pr should return '1' here, matching the result of getPR") + + result = self.client1.max_package_pr(version, pkgarch) + self.assertEqual(result, "1", "max_package_pr should return '1' in the current test series") + + result = self.client1.getPR(version, pkgarch, checksum2) + self.assertEqual(result, "2", "getPR: second PR of a package should be '2'") + + result = self.client1.test_pr(version, pkgarch, checksum2) + self.assertEqual(result, "2", "test_pr should return '2' here, matching the result of getPR") + + result = self.client1.max_package_pr(version, pkgarch) + self.assertEqual(result, "2", "max_package_pr should return '2' in the current test series") + + result = self.client1.getPR(version, pkgarch, checksum3) + self.assertEqual(result, "3", "getPR: second PR of a package should be '3'") + + result = self.client1.test_pr(version, pkgarch, checksum3) + self.assertEqual(result, "3", "test_pr should return '3' here, matching the result of getPR") + + result = self.client1.max_package_pr(version, pkgarch) + self.assertEqual(result, "3", "max_package_pr should return '3' in the current test series") + + # Ask again for the first configuration + + result = self.client1.getPR(version, pkgarch, checksum0) + self.assertEqual(result, "4", "getPR: should return '4' in this configuration") + + # Ask again with explicit "no history" mode + + result = self.client1.getPR(version, pkgarch, checksum0, False) + self.assertEqual(result, "4", "getPR: should return '4' in this configuration") + + # Ask again with explicit "history" mode. This should return the first recorded PR for checksum0 + + result = self.client1.getPR(version, pkgarch, checksum0, True) + self.assertEqual(result, "0", "getPR: should return '0' in this configuration") + + # Check again that another pkgarg resets the counters + + result = self.client1.test_pr(version, other_arch, checksum0) + self.assertIsNone(result, "test_pr should return 'None' for a non existing PR") + + result = self.client1.test_package(version, other_arch) + self.assertFalse(result, "test_package should return 'False' for a non existing PR") + + result = self.client1.max_package_pr(version, other_arch) + self.assertIsNone(result, "max_package_pr should return 'None' for a non existing PR") + + # Now add the configuration + + result = self.client1.getPR(version, other_arch, checksum0) + self.assertEqual(result, "0", "getPR: initial PR of a package should be '0'") + + result = self.client1.test_pr(version, other_arch, checksum0) + self.assertEqual(result, "0", "test_pr should return '0' here, matching the result of getPR") + + result = self.client1.test_package(version, other_arch) + self.assertTrue(result, "test_package should return 'True' for an existing PR") + + result = self.client1.max_package_pr(version, other_arch) + self.assertEqual(result, "0", "max_package_pr should return '0' in the current test series") + + result = self.client1.is_readonly() + self.assertFalse(result, "Server should not be described as 'read-only'") + +class PRUpstreamTests(PRTestSetup, unittest.TestCase): + + def setUp(self): + + self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-prserv') + self.addCleanup(self.temp_dir.cleanup) + + dbfile2 = os.path.join(self.temp_dir.name, "prtest-upstream2.sqlite3") + self.server2 = self.start_server("upstream2", dbfile2) + self.client2 = self.start_client(self.server2.address) + + dbfile1 = os.path.join(self.temp_dir.name, "prtest-upstream1.sqlite3") + self.server1 = self.start_server("upstream1", dbfile1, upstream=self.server2.address) + self.client1 = self.start_client(self.server1.address) + + dbfile0 = os.path.join(self.temp_dir.name, "prtest-local.sqlite3") + self.server0 = self.start_server("local", dbfile0, upstream=self.server1.address) + self.client0 = self.start_client(self.server0.address) + self.shared_db = dbfile0 + + def test_upstream_and_readonly(self): + + # For identical checksums, all servers should return the same PR + + result = self.client2.getPR(version, pkgarch, checksum0) + self.assertEqual(result, "0", "getPR: initial PR of a package should be '0'") + + result = self.client1.getPR(version, pkgarch, checksum0) + self.assertEqual(result, "0", "getPR: initial PR of a package should be '0' (same as upstream)") + + result = self.client0.getPR(version, pkgarch, checksum0) + self.assertEqual(result, "0", "getPR: initial PR of a package should be '0' (same as upstream)") + + # Now introduce new checksums on server1 for, same version + + result = self.client1.getPR(version, pkgarch, checksum1) + self.assertEqual(result, "0.0", "getPR: first PR of a package which has a different checksum upstream should be '0.0'") + + result = self.client1.getPR(version, pkgarch, checksum2) + self.assertEqual(result, "0.1", "getPR: second PR of a package that has a different checksum upstream should be '0.1'") + + # Now introduce checksums on server0 for, same version + + result = self.client1.getPR(version, pkgarch, checksum1) + self.assertEqual(result, "0.2", "getPR: can't decrease for known PR") + + result = self.client1.getPR(version, pkgarch, checksum2) + self.assertEqual(result, "0.3") + + result = self.client1.max_package_pr(version, pkgarch) + self.assertEqual(result, "0.3") + + result = self.client0.getPR(version, pkgarch, checksum3) + self.assertEqual(result, "0.3.0", "getPR: first PR of a package that doesn't exist upstream should be '0.3.0'") + + result = self.client0.getPR(version, pkgarch, checksum4) + self.assertEqual(result, "0.3.1", "getPR: second PR of a package that doesn't exist upstream should be '0.3.1'") + + result = self.client0.getPR(version, pkgarch, checksum3) + self.assertEqual(result, "0.3.2") + + # More upstream updates + # Here, we assume no communication between server2 and server0. server2 only impacts server0 + # after impacting server1 + + self.assertEqual(self.client2.getPR(version, pkgarch, checksum5), "1") + self.assertEqual(self.client1.getPR(version, pkgarch, checksum6), "1.0") + self.assertEqual(self.client1.getPR(version, pkgarch, checksum7), "1.1") + self.assertEqual(self.client0.getPR(version, pkgarch, checksum8), "1.1.0") + self.assertEqual(self.client0.getPR(version, pkgarch, checksum9), "1.1.1") + + # "history" mode tests + + self.assertEqual(self.client2.getPR(version, pkgarch, checksum0, True), "0") + self.assertEqual(self.client1.getPR(version, pkgarch, checksum2, True), "0.1") + self.assertEqual(self.client0.getPR(version, pkgarch, checksum3, True), "0.3.0") + + # More "no history" mode tests + + self.assertEqual(self.client2.getPR(version, pkgarch, checksum0), "2") + self.assertEqual(self.client1.getPR(version, pkgarch, checksum0), "2") # Same as upstream + self.assertEqual(self.client0.getPR(version, pkgarch, checksum0), "2") # Same as upstream + self.assertEqual(self.client1.getPR(version, pkgarch, checksum7), "3") # This could be surprising, but since the previous revision was "2", increasing it yields "3". + # We don't know how many upstream servers we have + # Start read-only server with server1 as upstream + self.server_ro = self.start_server("local-ro", self.shared_db, upstream=self.server1.address, read_only=True) + self.client_ro = self.start_client(self.server_ro.address) + + self.assertTrue(self.client_ro.is_readonly(), "Database should be described as 'read-only'") + + # Checks on non existing configurations + self.assertIsNone(self.client_ro.test_pr(version, pkgarch, checksumX)) + self.assertFalse(self.client_ro.test_package("unknown", pkgarch)) + + # Look up existing configurations + self.assertEqual(self.client_ro.getPR(version, pkgarch, checksum0), "3") # "no history" mode + self.assertEqual(self.client_ro.getPR(version, pkgarch, checksum0, True), "0") # "history" mode + self.assertEqual(self.client_ro.getPR(version, pkgarch, checksum3), "3") + self.assertEqual(self.client_ro.getPR(version, pkgarch, checksum3, True), "0.3.0") + self.assertEqual(self.client_ro.max_package_pr(version, pkgarch), "2") # normal as "3" was never saved + + # Try to insert a new value. Here this one is know upstream. + self.assertEqual(self.client_ro.getPR(version, pkgarch, checksum7), "3") + # Try to insert a completely new value. As the max upstream value is already "3", it should be "3.0" + self.assertEqual(self.client_ro.getPR(version, pkgarch, checksum10), "3.0") + # Same with another value which only exists in the upstream upstream server + # This time, as the upstream server doesn't know it, it will ask its upstream server. So that's a known one. + self.assertEqual(self.client_ro.getPR(version, pkgarch, checksum9), "3") + +class ScriptTests(unittest.TestCase): + + def setUp(self): + + self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-prserv') + self.addCleanup(self.temp_dir.cleanup) + self.dbfile = os.path.join(self.temp_dir.name, "prtest.sqlite3") + + def test_1_start_bitbake_prserv(self): + try: + subprocess.check_call([BIN_DIR / "bitbake-prserv", "--start", "-f", self.dbfile]) + except subprocess.CalledProcessError as e: + self.fail("Failed to start bitbake-prserv: %s" % e.returncode) + + def test_2_stop_bitbake_prserv(self): + try: + subprocess.check_call([BIN_DIR / "bitbake-prserv", "--stop"]) + except subprocess.CalledProcessError as e: + self.fail("Failed to stop bitbake-prserv: %s" % e.returncode) |