From 0bacf6551821beb8915513b120ae672ae8eb1612 Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Fri, 12 Apr 2024 09:57:09 -0600 Subject: siggen: Capture SSL environment for hashserver Now that the bitbake hash server supports SSL connections, we need to capture a few environment variables which can affect the ability to connect via SSL. Note that the variables are only put in place to affect the environment while actually invoking the server [RP: Tweak to use BB_ORIGENV as well] [RP: Tweak to handle os.environ restore correctly] Signed-off-by: Joshua Watt Signed-off-by: Richard Purdie --- lib/bb/siggen.py | 94 +++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 24 deletions(-) diff --git a/lib/bb/siggen.py b/lib/bb/siggen.py index 2a0ecf57e..8ab08ec96 100644 --- a/lib/bb/siggen.py +++ b/lib/bb/siggen.py @@ -15,6 +15,7 @@ import difflib import simplediff import json import types +from contextlib import contextmanager import bb.compress.zstd from bb.checksum import FileChecksumCache from bb import runqueue @@ -28,6 +29,14 @@ hashequiv_logger = logging.getLogger('BitBake.SigGen.HashEquiv') # The minimum version of the find_siginfo function we need find_siginfo_minversion = 2 +HASHSERV_ENVVARS = [ + "SSL_CERT_DIR", + "SSL_CERT_FILE", + "NO_PROXY", + "HTTPS_PROXY", + "HTTP_PROXY" +] + def check_siggen_version(siggen): if not hasattr(siggen, "find_siginfo_version"): bb.fatal("Siggen from metadata (OE-Core?) is too old, please update it (no version found)") @@ -537,14 +546,23 @@ class SignatureGeneratorUniHashMixIn(object): self.unihash_exists_cache = set() self.username = None self.password = None + self.env = {} + + origenv = data.getVar("BB_ORIGENV") + for e in HASHSERV_ENVVARS: + value = data.getVar(e) + if not value and origenv: + value = origenv.getVar(e) + if value: + self.env[e] = value super().__init__(data) def get_taskdata(self): - return (self.server, self.method, self.extramethod, self.max_parallel, self.username, self.password) + super().get_taskdata() + return (self.server, self.method, self.extramethod, self.max_parallel, self.username, self.password, self.env) + super().get_taskdata() def set_taskdata(self, data): - self.server, self.method, self.extramethod, self.max_parallel, self.username, self.password = data[:6] - super().set_taskdata(data[6:]) + self.server, self.method, self.extramethod, self.max_parallel, self.username, self.password, self.env = data[:7] + super().set_taskdata(data[7:]) def get_hashserv_creds(self): if self.username and self.password: @@ -555,15 +573,34 @@ class SignatureGeneratorUniHashMixIn(object): return {} + @contextmanager + def _client_env(self): + orig_env = os.environ.copy() + try: + for k, v in self.env.items(): + os.environ[k] = v + + yield + finally: + for k, v in self.env.items(): + if k in orig_env: + os.environ[k] = orig_env[k] + else: + del os.environ[k] + + @contextmanager def client(self): - if getattr(self, '_client', None) is None: - self._client = hashserv.create_client(self.server, **self.get_hashserv_creds()) - return self._client + with self._client_env(): + if getattr(self, '_client', None) is None: + self._client = hashserv.create_client(self.server, **self.get_hashserv_creds()) + yield self._client + @contextmanager def client_pool(self): - if getattr(self, '_client_pool', None) is None: - self._client_pool = hashserv.client.ClientPool(self.server, self.max_parallel, **self.get_hashserv_creds()) - return self._client_pool + with self._client_env(): + if getattr(self, '_client_pool', None) is None: + self._client_pool = hashserv.client.ClientPool(self.server, self.max_parallel, **self.get_hashserv_creds()) + yield self._client_pool def reset(self, data): self.__close_clients() @@ -574,12 +611,13 @@ class SignatureGeneratorUniHashMixIn(object): return super().exit() def __close_clients(self): - if getattr(self, '_client', None) is not None: - self._client.close() - self._client = None - if getattr(self, '_client_pool', None) is not None: - self._client_pool.close() - self._client_pool = None + with self._client_env(): + if getattr(self, '_client', None) is not None: + self._client.close() + self._client = None + if getattr(self, '_client_pool', None) is not None: + self._client_pool.close() + self._client_pool = None def get_stampfile_hash(self, tid): if tid in self.taskhash: @@ -650,11 +688,13 @@ class SignatureGeneratorUniHashMixIn(object): if self.max_parallel <= 1 or len(uncached_query) <= 1: # No parallelism required. Make the query serially with the single client - uncached_result = { - key: self.client().unihash_exists(value) for key, value in uncached_query.items() - } + with self.client() as client: + uncached_result = { + key: client.unihash_exists(value) for key, value in uncached_query.items() + } else: - uncached_result = self.client_pool().unihashes_exist(uncached_query) + with self.client_pool() as client_pool: + uncached_result = client_pool.unihashes_exist(uncached_query) for key, exists in uncached_result.items(): if exists: @@ -687,10 +727,12 @@ class SignatureGeneratorUniHashMixIn(object): if self.max_parallel <= 1 or len(queries) <= 1: # No parallelism required. Make the query serially with the single client - for tid, args in queries.items(): - query_result[tid] = self.client().get_unihash(*args) + with self.client() as client: + for tid, args in queries.items(): + query_result[tid] = client.get_unihash(*args) else: - query_result = self.client_pool().get_unihashes(queries) + with self.client_pool() as client_pool: + query_result = client_pool.get_unihashes(queries) for tid, unihash in query_result.items(): # In the absence of being able to discover a unique hash from the @@ -785,7 +827,9 @@ class SignatureGeneratorUniHashMixIn(object): if tid in self.extramethod: method = method + self.extramethod[tid] - data = self.client().report_unihash(taskhash, method, outhash, unihash, extra_data) + with self.client() as client: + data = client.report_unihash(taskhash, method, outhash, unihash, extra_data) + new_unihash = data['unihash'] if new_unihash != unihash: @@ -816,7 +860,9 @@ class SignatureGeneratorUniHashMixIn(object): if tid in self.extramethod: method = method + self.extramethod[tid] - data = self.client().report_unihash_equiv(taskhash, method, wanted_unihash, extra_data) + with self.client() as client: + data = client.report_unihash_equiv(taskhash, method, wanted_unihash, extra_data) + hashequiv_logger.verbose('Reported task %s as unihash %s to %s (%s)' % (tid, wanted_unihash, self.server, str(data))) if data is None: -- cgit 1.2.3-korg