From d20b2a34237763844d4501a77cab59e4df6ae962 Mon Sep 17 00:00:00 2001 From: Jon Couldridge Date: Fri, 24 Jan 2025 10:05:08 +0000 Subject: [PATCH] short db driver expansion now in a function and used by tests (#53) --- src/hutch_bunny/core/setting_database.py | 28 +++++++++++++------ tests/test_demographics_distribution_query.py | 5 +++- tests/test_return.py | 5 +++- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/hutch_bunny/core/setting_database.py b/src/hutch_bunny/core/setting_database.py index 229d5e9..26337da 100644 --- a/src/hutch_bunny/core/setting_database.py +++ b/src/hutch_bunny/core/setting_database.py @@ -4,6 +4,23 @@ import hutch_bunny.core.settings as settings +def expand_short_drivers(drivername: str): + """ + Expand unqualified "short" db driver names when necessary so we can override sqlalchemy + e.g. when using psycopg3, expand `postgresql` explicitly rather than use sqlalchemy's default of psycopg2 + """ + + if drivername == "postgresql": + return settings.DEFAULT_POSTGRES_DRIVER + + if drivername == "mssql": + return settings.DEFAULT_MSSQL_DRIVER + + # Add other explicit driver qualification as needed ... + + return drivername + + def setting_database(logger: Logger): logger.info("Setting up database connection...") @@ -24,17 +41,10 @@ def setting_database(logger: Logger): exit() else: datasource_db_port = environ.get("DATASOURCE_DB_PORT") - datasource_db_drivername = environ.get( - "DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER + datasource_db_drivername = expand_short_drivers( + environ.get("DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER) ) - # expand postgres to a full default driver, so we can override sqlalchemy - if datasource_db_drivername == "postgresql": - datasource_db_drivername = settings.DEFAULT_POSTGRES_DRIVER - - if datasource_db_drivername == "mssql": - datasource_db_drivername = settings.DEFAULT_MSSQL_DRIVER - try: db_manager = SyncDBManager( username=environ.get("DATASOURCE_DB_USERNAME"), diff --git a/tests/test_demographics_distribution_query.py b/tests/test_demographics_distribution_query.py index a06f5a4..03da36b 100644 --- a/tests/test_demographics_distribution_query.py +++ b/tests/test_demographics_distribution_query.py @@ -6,6 +6,7 @@ from dotenv import load_dotenv import os import hutch_bunny.core.settings as settings +import hutch_bunny.core.setting_database as db_settings load_dotenv() @@ -23,7 +24,9 @@ def db_manager(): host=os.getenv("DATASOURCE_DB_HOST"), port=(int(datasource_db_port) if datasource_db_port is not None else None), database=os.getenv("DATASOURCE_DB_DATABASE"), - drivername=os.getenv("DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER), + drivername=db_settings.expand_short_drivers( + os.getenv("DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER) + ), schema=os.getenv("DATASOURCE_DB_SCHEMA"), ) diff --git a/tests/test_return.py b/tests/test_return.py index dbaa999..3594854 100644 --- a/tests/test_return.py +++ b/tests/test_return.py @@ -11,6 +11,7 @@ from dotenv import load_dotenv import os import hutch_bunny.core.settings as settings +import hutch_bunny.core.setting_database as db_settings load_dotenv() @@ -28,7 +29,9 @@ def db_manager(): host=os.getenv("DATASOURCE_DB_HOST"), port=(int(datasource_db_port) if datasource_db_port is not None else None), database=os.getenv("DATASOURCE_DB_DATABASE"), - drivername=os.getenv("DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER), + drivername=db_settings.expand_short_drivers( + os.getenv("DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER) + ), schema=os.getenv("DATASOURCE_DB_SCHEMA"), )