From 6b515e48f4c43647971d4b518b751365a40bbaee Mon Sep 17 00:00:00 2001 From: Nikolai Kondrashov Date: Wed, 10 Jan 2024 18:54:04 +0200 Subject: [PATCH] main: Support specifying db spec explicitly --- main.py | 43 ++++++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/main.py b/main.py index 509d5d39..78512cda 100644 --- a/main.py +++ b/main.py @@ -28,6 +28,9 @@ # Maximum time for pulling maximum amount of submissions from the queue LOAD_QUEUE_TIMEOUT_SEC = float(os.environ["KCIDB_LOAD_QUEUE_TIMEOUT_SEC"]) +# The specification for the database submissions should be loaded into +DATABASE = os.environ["KCIDB_DATABASE"] + # Minimum time between loading submissions into the database DATABASE_LOAD_PERIOD = datetime.timedelta( seconds=int(os.environ["KCIDB_DATABASE_LOAD_PERIOD_SEC"]) @@ -102,9 +105,14 @@ def get_smtp_password(): return _SMTP_PASSWORD -def get_load_queue_subscriber(): +def get_load_queue_subscriber(database): """ Create or get the cached subscriber object for the submission queue. + + Args: + database: The specification for the database the data from the queue + will be loaded into. Used to retrieve the accepted I/O + schema version. """ # It's alright, pylint: disable=global-statement global _LOAD_QUEUE_SUBSCRIBER @@ -113,7 +121,7 @@ def get_load_queue_subscriber(): PROJECT_ID, os.environ["KCIDB_LOAD_QUEUE_TOPIC"], os.environ["KCIDB_LOAD_QUEUE_SUBSCRIPTION"], - schema=get_db_client().get_schema()[1] + schema=get_db_client(database).get_schema()[1] ) return _LOAD_QUEUE_SUBSCRIBER @@ -150,24 +158,37 @@ def get_db_credentials(): atexit.register(os.remove, pgpass_filename) -def get_db_client(): - """Create or retrieve the cached database client.""" +def get_db_client(database): + """ + Create or retrieve the cached database client. + + Args: + database: The specification for the database the client should + connect to. + """ # It's alright, pylint: disable=global-statement global _DB_CLIENT if _DB_CLIENT is None: # Get the credentials get_db_credentials() # Create the client - _DB_CLIENT = kcidb.db.Client(os.environ["KCIDB_DATABASE"]) + _DB_CLIENT = kcidb.db.Client(database) + assert _DB_CLIENT.database == database return _DB_CLIENT -def get_oo_client(): - """Create or retrieve the cached OO database client.""" +def get_oo_client(database): + """ + Create or retrieve the cached OO database client. + + Args: + database: The specification for the database the client should + connect to. + """ # It's alright, pylint: disable=global-statement global _OO_CLIENT if _OO_CLIENT is None: - _OO_CLIENT = kcidb.oo.Client(get_db_client()) + _OO_CLIENT = kcidb.oo.Client(get_db_client(database)) return _OO_CLIENT @@ -255,8 +276,8 @@ def kcidb_load_queue(event, context): if it stayed unmodified for at least DATABASE_LOAD_PERIOD. """ # pylint: disable=too-many-locals - subscriber = get_load_queue_subscriber() - db_client = get_db_client() + subscriber = get_load_queue_subscriber(DATABASE) + db_client = get_db_client(DATABASE) io_schema = db_client.get_schema()[1] publisher = get_updated_queue_publisher() # Do nothing, if updated recently @@ -336,7 +357,7 @@ def kcidb_spool_notifications(event, context): Spool notifications about objects matching patterns arriving from a Pub Sub subscription """ - oo_client = get_oo_client() + oo_client = get_oo_client(DATABASE) spool_client = get_spool_client() # Reset the ORM cache oo_client.reset_cache()