Skip to content

Commit

Permalink
main: Support specifying db spec explicitly
Browse files Browse the repository at this point in the history
  • Loading branch information
spbnick committed Jan 16, 2024
1 parent 3215ebb commit 6b515e4
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
# Maximum time for pulling maximum amount of submissions from the queue
LOAD_QUEUE_TIMEOUT_SEC = float(os.environ["KCIDB_LOAD_QUEUE_TIMEOUT_SEC"])

# The specification for the database submissions should be loaded into
DATABASE = os.environ["KCIDB_DATABASE"]

# Minimum time between loading submissions into the database
DATABASE_LOAD_PERIOD = datetime.timedelta(
seconds=int(os.environ["KCIDB_DATABASE_LOAD_PERIOD_SEC"])
Expand Down Expand Up @@ -102,9 +105,14 @@ def get_smtp_password():
return _SMTP_PASSWORD


def get_load_queue_subscriber():
def get_load_queue_subscriber(database):
"""
Create or get the cached subscriber object for the submission queue.
Args:
database: The specification for the database the data from the queue
will be loaded into. Used to retrieve the accepted I/O
schema version.
"""
# It's alright, pylint: disable=global-statement
global _LOAD_QUEUE_SUBSCRIBER
Expand All @@ -113,7 +121,7 @@ def get_load_queue_subscriber():
PROJECT_ID,
os.environ["KCIDB_LOAD_QUEUE_TOPIC"],
os.environ["KCIDB_LOAD_QUEUE_SUBSCRIPTION"],
schema=get_db_client().get_schema()[1]
schema=get_db_client(database).get_schema()[1]
)
return _LOAD_QUEUE_SUBSCRIBER

Expand Down Expand Up @@ -150,24 +158,37 @@ def get_db_credentials():
atexit.register(os.remove, pgpass_filename)


def get_db_client():
"""Create or retrieve the cached database client."""
def get_db_client(database):
"""
Create or retrieve the cached database client.
Args:
database: The specification for the database the client should
connect to.
"""
# It's alright, pylint: disable=global-statement
global _DB_CLIENT
if _DB_CLIENT is None:
# Get the credentials
get_db_credentials()
# Create the client
_DB_CLIENT = kcidb.db.Client(os.environ["KCIDB_DATABASE"])
_DB_CLIENT = kcidb.db.Client(database)
assert _DB_CLIENT.database == database
return _DB_CLIENT


def get_oo_client():
"""Create or retrieve the cached OO database client."""
def get_oo_client(database):
"""
Create or retrieve the cached OO database client.
Args:
database: The specification for the database the client should
connect to.
"""
# It's alright, pylint: disable=global-statement
global _OO_CLIENT
if _OO_CLIENT is None:
_OO_CLIENT = kcidb.oo.Client(get_db_client())
_OO_CLIENT = kcidb.oo.Client(get_db_client(database))
return _OO_CLIENT


Expand Down Expand Up @@ -255,8 +276,8 @@ def kcidb_load_queue(event, context):
if it stayed unmodified for at least DATABASE_LOAD_PERIOD.
"""
# pylint: disable=too-many-locals
subscriber = get_load_queue_subscriber()
db_client = get_db_client()
subscriber = get_load_queue_subscriber(DATABASE)
db_client = get_db_client(DATABASE)
io_schema = db_client.get_schema()[1]
publisher = get_updated_queue_publisher()
# Do nothing, if updated recently
Expand Down Expand Up @@ -336,7 +357,7 @@ def kcidb_spool_notifications(event, context):
Spool notifications about objects matching patterns arriving from a Pub
Sub subscription
"""
oo_client = get_oo_client()
oo_client = get_oo_client(DATABASE)
spool_client = get_spool_client()
# Reset the ORM cache
oo_client.reset_cache()
Expand Down

0 comments on commit 6b515e4

Please sign in to comment.