From 44a253417fbd89c656c397e2dde44fc3f0debe9f Mon Sep 17 00:00:00 2001 From: Phil Quinlan Date: Thu, 23 Jan 2025 10:02:18 +0000 Subject: [PATCH 1/4] Documenting steps to help with #31 (#44) * Adding my understanding of process and functions via comments * Documented for the availability query, some todos and refactors. Code can be heavily optimised. * Comments added for generic distro --- src/hutch_bunny/core/query_solvers.py | 89 +++++++++++++++++++++++++-- 1 file changed, 84 insertions(+), 5 deletions(-) diff --git a/src/hutch_bunny/core/query_solvers.py b/src/hutch_bunny/core/query_solvers.py index 2248eb9..8085432 100644 --- a/src/hutch_bunny/core/query_solvers.py +++ b/src/hutch_bunny/core/query_solvers.py @@ -25,7 +25,7 @@ import hutch_bunny.core.settings as settings from hutch_bunny.core.constants import DISTRIBUTION_TYPE_FILE_NAMES_MAP - +# Class for availability queries class AvailibilityQuerySolver: subqueries = list() concept_table_map = { @@ -67,11 +67,25 @@ def __init__(self, db_manager: SyncDBManager, query: AvailabilityQuery) -> None: self.db_manager = db_manager self.query = query + """ Function that takes all the concept IDs in the cohort defintion, looks them up in the OMOP database + to extract the concept_id and domain and place this within a dictionary for lookup during other query building + + Although the query payload will tell you where the OMOP concept is from (based on the RQUEST OMOP version, this is + a safer method as we know concepts can move between tables based on a vocab. + + Therefore this helps to account for a difference between the Bunny vocab version and the RQUEST OMOP version. + + + #TODO: this does not cover the scenario that is possible to occur where the local vocab model may say the concept + should be based in one table but it is actually present in another + + """ def _find_concepts(self) -> dict: concept_ids = set() for group in self.query.cohort.groups: for rule in group.rules: concept_ids.add(int(rule.value)) + concept_query = ( # order must be .concept_id, .domain_id select(Concept.concept_id, Concept.domain_id) @@ -86,15 +100,39 @@ def _find_concepts(self) -> dict: } return concept_dict + """ Function for taking the JSON query from RQUEST and creating the required query to run against the OMOP database. + + RQUEST API spec can have multiple groups in each query, and then a condition between the groups. + + Each group can have conditional logic AND/OR within the group + + Each concept can either be an inclusion or exclusion criteria. + + Each concept can have an age set, so it is that this event with concept X occurred when + the person was between a certain age. - #TODO - not sure this is implemented here + + """ def _solve_rules(self) -> None: - """Find all rows that match the rules' criteria.""" + + #get the list of concepts to build the query constraints concepts = self._find_concepts() + + # This is related to the logic within a group. This is used in the subsequent for loop to determine how + # the merge should be applied. merge_method = lambda x: "inner" if x == "AND" else "outer" + + # iterate through all the groups specified in the query for group in self.query.cohort.groups: + + # todo - refactor variable name concept as this is misleading. It is not the concept but actually the domain of the concept + # this passes in the conceptID of but gets back the domain related to that concept. concept = concepts.get(group.rules[0].value) + concept_table = self.concept_table_map.get(concept) boolean_rule_col = self.boolean_rule_map.get(concept) numeric_rule_col = self.numeric_rule_map.get(concept) + + #within the query, if a range was specified, which is currently if ( group.rules[0].min_value is not None and group.rules[0].max_value is not None @@ -114,6 +152,10 @@ def _solve_rules(self) -> None: main_df = pd.read_sql_query( sql=stmnt, con=self.db_manager.engine.connect() ) + + # the next two ifs are basically switching between equals and not equals. These could be merged with a simple + # switch for the operator. + elif group.rules[0].operator == "=": stmnt = ( select(concept_table.person_id) @@ -132,11 +174,25 @@ def _solve_rules(self) -> None: main_df = pd.read_sql_query( sql=stmnt, con=self.db_manager.engine.connect() ) + + """ + Now that the main_df dataframe has been populated, the subsequent queries are created and merged into + main_df dataframe. That is why above the first concept is hard coded as accessing index 0 and why the for + loop below if start at index 1. The queries are almost identical to the above, exact same logic but + in order to facilitate the merging, a label is created on person id, so that the newly created data frame + can be merged with main_df via unique keys. + """ + for i, rule in enumerate(group.rules[1:], start=1): + + # todo - refactor variable name concept as this is misleading. It is not the concept but actually the domain of the concept + # this passes in the conceptID of but gets back the domain related to that concept. concept = concepts.get(rule.value) + concept_table = self.concept_table_map.get(concept) boolean_rule_col = self.boolean_rule_map.get(concept) numeric_rule_col = self.numeric_rule_map.get(concept) + if rule.min_value is not None and rule.max_value is not None: # numeric rule stmnt = ( @@ -192,14 +248,25 @@ def _solve_rules(self) -> None: left_on="person_id", right_on=f"person_id_{i}", ) + # subqueries therefore contain the results for each group within the cohort definition. self.subqueries.append(main_df) + """ + This is the start of the process that begins to run the queries. + (1) call solve_rules that takes each group and adds those results to the sub_queries list + (2) this function then iterates through the list of groups to resolve the logic (AND/OR) between groups + """ def solve_query(self) -> int: - """Merge the groups and return the number of rows that matched all criteria.""" + #resolve within the group self._solve_rules() + merge_method = lambda x: "inner" if x == "AND" else "outer" + + #seed the dataframe with the first group0_df = self.subqueries[0] group0_df.rename({"person_id": "person_id_0"}, inplace=True, axis=1) + + #for the next, rename columns to give a unique key, then merge based on the merge_method value for i, df in enumerate(self.subqueries[1:], start=1): df.rename({"person_id": f"person_id_{i}"}, inplace=True, axis=1) group0_df = group0_df.merge( @@ -216,8 +283,9 @@ class BaseDistributionQuerySolver: def solve_query(self) -> Tuple[str, int]: raise NotImplementedError - +# class for distriubtion queries class CodeDistributionQuerySolver(BaseDistributionQuerySolver): + #todo - can the following be placed somewhere once as its repeated for all classes handling queries allowed_domains_map = { "Condition": ConditionOccurrence, "Ethnicity": Person, @@ -238,6 +306,8 @@ class CodeDistributionQuerySolver(BaseDistributionQuerySolver): "Observation": Observation.observation_concept_id, "Procedure": ProcedureOccurrence.procedure_concept_id, } + + # this one is unique for this resolver output_cols = [ "BIOBANK", "CODE", @@ -275,15 +345,23 @@ def solve_query(self) -> Tuple[str, int]: concepts = list() categories = list() biobanks = list() + + #todo - rename k, as this is a domain id that is being used for k in self.allowed_domains_map: + + # get the right table and column based on the domain table = self.allowed_domains_map[k] concept_col = self.domain_concept_id_map[k] + + # gets a list of all concepts within this given table and their respective counts stmnt = select(func.count(table.person_id), concept_col).group_by( concept_col ) res = pd.read_sql(stmnt, self.db_manager.engine.connect()) counts.extend(res.iloc[:, 0]) concepts.extend(res.iloc[:, 1]) + + # add the same category and collection if, for the number of results received categories.extend([k] * len(res)) biobanks.extend([self.query.collection] * len(res)) @@ -294,6 +372,7 @@ def solve_query(self) -> Tuple[str, int]: df["BIOBANK"] = biobanks # Get descriptions + #todo - not sure why this can be included in the SQL output above, it would need a join to the concept table concept_query = select(Concept.concept_id, Concept.concept_name).where( Concept.concept_id.in_(concepts) ) @@ -310,7 +389,7 @@ def solve_query(self) -> Tuple[str, int]: return os.linesep.join(results), len(df) - +#todo - i *think* the only diference between this one and generic is that the allowed_domain list is different. Could we not just have the one class and functions that have this passed in? class DemographicsDistributionQuerySolver(BaseDistributionQuerySolver): allowed_domains_map = { "Gender": Person, From 88e279883a4e65e37748d3d3f5c315f300f192f4 Mon Sep 17 00:00:00 2001 From: Andy Rae Date: Thu, 23 Jan 2025 19:07:03 +0000 Subject: [PATCH 2/4] Add type stubs (#51) * Add typed libraries * Fix tests directory * Update readme --- .vscode/settings.json | 2 +- README.md | 4 +-- pyproject.toml | 8 ++++- uv.lock | 68 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 4 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index fd9f527..1be6a53 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,6 @@ { "python.testing.pytestArgs": [ - "test" + "tests" ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true diff --git a/README.md b/README.md index b77ab58..03c334f 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,9 @@ |-|-|-| | ![Python][python-badge] | [![Bunny Docker Images][docker-badge]][bunny-containers] | [![Bunny Docs][docs-badge]][bunny-docs] | -An HDR UK Cohort Discovery Task Resolver. +A Cohort Discovery Task Resolver. -Fetches and resolves Availability and Distribution Queries against an OMOP-CDM database. +Fetches and resolves Availability and Distribution Queries against an OMOP CDM database. [hutch-logo]: https://raw.githubusercontent.com/HDRUK/hutch/main/assets/Hutch%20splash%20bg.svg [hutch-repo]: https://github.com/health-informatics-uon/hutch diff --git a/pyproject.toml b/pyproject.toml index fc1f535..7d5597f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,4 +38,10 @@ requires = ["hatchling"] build-backend = "hatchling.build" [dependency-groups] -dev = ["ruff>=0.8.6", "pytest>=8.3.4"] +dev = [ + "ruff>=0.8.6", + "pytest>=8.3.4", + "pandas-stubs>=2.2.3.241126", + "mypy>=1.14.1", + "types-requests>=2.32.0.20241016", +] diff --git a/uv.lock b/uv.lock index 00fb1ad..efc0c75 100644 --- a/uv.lock +++ b/uv.lock @@ -58,8 +58,11 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "mypy" }, + { name = "pandas-stubs" }, { name = "pytest" }, { name = "ruff" }, + { name = "types-requests" }, ] [package.metadata] @@ -76,8 +79,11 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "mypy", specifier = ">=1.14.1" }, + { name = "pandas-stubs", specifier = ">=2.2.3.241126" }, { name = "pytest", specifier = ">=8.3.4" }, { name = "ruff", specifier = ">=0.8.6" }, + { name = "types-requests", specifier = ">=2.32.0.20241016" }, ] [[package]] @@ -98,6 +104,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, ] +[[package]] +name = "mypy" +version = "1.14.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mypy-extensions" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/eb/2c92d8ea1e684440f54fa49ac5d9a5f19967b7b472a281f419e69a8d228e/mypy-1.14.1.tar.gz", hash = "sha256:7ec88144fe9b510e8475ec2f5f251992690fcf89ccb4500b214b4226abcd32d6", size = 3216051 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/15/bb6a686901f59222275ab228453de741185f9d54fecbaacec041679496c6/mypy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:92c3ed5afb06c3a8e188cb5da4984cab9ec9a77ba956ee419c68a388b4595255", size = 11252097 }, + { url = "https://files.pythonhosted.org/packages/f8/b3/8b0f74dfd072c802b7fa368829defdf3ee1566ba74c32a2cb2403f68024c/mypy-1.14.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:dbec574648b3e25f43d23577309b16534431db4ddc09fda50841f1e34e64ed34", size = 10239728 }, + { url = "https://files.pythonhosted.org/packages/c5/9b/4fd95ab20c52bb5b8c03cc49169be5905d931de17edfe4d9d2986800b52e/mypy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8c6d94b16d62eb3e947281aa7347d78236688e21081f11de976376cf010eb31a", size = 11924965 }, + { url = "https://files.pythonhosted.org/packages/56/9d/4a236b9c57f5d8f08ed346914b3f091a62dd7e19336b2b2a0d85485f82ff/mypy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d4b19b03fdf54f3c5b2fa474c56b4c13c9dbfb9a2db4370ede7ec11a2c5927d9", size = 12867660 }, + { url = "https://files.pythonhosted.org/packages/40/88/a61a5497e2f68d9027de2bb139c7bb9abaeb1be1584649fa9d807f80a338/mypy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:0c911fde686394753fff899c409fd4e16e9b294c24bfd5e1ea4675deae1ac6fd", size = 12969198 }, + { url = "https://files.pythonhosted.org/packages/54/da/3d6fc5d92d324701b0c23fb413c853892bfe0e1dbe06c9138037d459756b/mypy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:8b21525cb51671219f5307be85f7e646a153e5acc656e5cebf64bfa076c50107", size = 9885276 }, + { url = "https://files.pythonhosted.org/packages/a0/b5/32dd67b69a16d088e533962e5044e51004176a9952419de0370cdaead0f8/mypy-1.14.1-py3-none-any.whl", hash = "sha256:b66a60cc4073aeb8ae00057f9c1f64d49e90f918fbcef9a977eb121da8b8f1d1", size = 2752905 }, +] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 }, +] + [[package]] name = "numpy" version = "2.2.1" @@ -162,6 +196,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/5f/b38085618b950b79d2d9164a711c52b10aefc0ae6833b96f626b7021b2ed/pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a", size = 13098436 }, ] +[[package]] +name = "pandas-stubs" +version = "2.2.3.241126" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "types-pytz" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/86/93c545d149c3e1fe1c4c55478cc3a69859d0ea3467e1d9892e9eb28cb1e7/pandas_stubs-2.2.3.241126.tar.gz", hash = "sha256:cf819383c6d9ae7d4dabf34cd47e1e45525bb2f312e6ad2939c2c204cb708acd", size = 104204 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/ab/ed42acf15bab2e86e5c49fad4aa038315233c4c2d22f41b49faa4d837516/pandas_stubs-2.2.3.241126-py3-none-any.whl", hash = "sha256:74aa79c167af374fe97068acc90776c0ebec5266a6e5c69fe11e9c2cf51f2267", size = 158280 }, +] + [[package]] name = "pluggy" version = "1.5.0" @@ -354,6 +401,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/aa/cce7a726e314fbeae9fed8ed5d6bd4af19796286d7ee82f120e3633da74c/trino-0.331.0-py3-none-any.whl", hash = "sha256:4f909e6c2966d23917e2538bc7f342d5dcc6e512102811fb1e53bdaf15bd49e3", size = 53771 }, ] +[[package]] +name = "types-pytz" +version = "2024.2.0.20241221" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/54/26/516311b02b5a215e721155fb65db8a965d061372e388d6125ebce8d674b0/types_pytz-2024.2.0.20241221.tar.gz", hash = "sha256:06d7cde9613e9f7504766a0554a270c369434b50e00975b3a4a0f6eed0f2c1a9", size = 10213 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/db/c92ca6920cccd9c2998b013601542e2ac5e59bc805bcff94c94ad254b7df/types_pytz-2024.2.0.20241221-py3-none-any.whl", hash = "sha256:8fc03195329c43637ed4f593663df721fef919b60a969066e22606edf0b53ad5", size = 10008 }, +] + +[[package]] +name = "types-requests" +version = "2.32.0.20241016" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fa/3c/4f2a430c01a22abd49a583b6b944173e39e7d01b688190a5618bd59a2e22/types-requests-2.32.0.20241016.tar.gz", hash = "sha256:0d9cad2f27515d0e3e3da7134a1b6f28fb97129d86b867f24d9c726452634d95", size = 18065 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/01/485b3026ff90e5190b5e24f1711522e06c79f4a56c8f4b95848ac072e20f/types_requests-2.32.0.20241016-py3-none-any.whl", hash = "sha256:4195d62d6d3e043a4eaaf08ff8a62184584d2e8684e9d2aa178c7915a7da3747", size = 15836 }, +] + [[package]] name = "typing-extensions" version = "4.12.2" From d20b2a34237763844d4501a77cab59e4df6ae962 Mon Sep 17 00:00:00 2001 From: Jon Couldridge Date: Fri, 24 Jan 2025 10:05:08 +0000 Subject: [PATCH 3/4] short db driver expansion now in a function and used by tests (#53) --- src/hutch_bunny/core/setting_database.py | 28 +++++++++++++------ tests/test_demographics_distribution_query.py | 5 +++- tests/test_return.py | 5 +++- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/hutch_bunny/core/setting_database.py b/src/hutch_bunny/core/setting_database.py index 229d5e9..26337da 100644 --- a/src/hutch_bunny/core/setting_database.py +++ b/src/hutch_bunny/core/setting_database.py @@ -4,6 +4,23 @@ import hutch_bunny.core.settings as settings +def expand_short_drivers(drivername: str): + """ + Expand unqualified "short" db driver names when necessary so we can override sqlalchemy + e.g. when using psycopg3, expand `postgresql` explicitly rather than use sqlalchemy's default of psycopg2 + """ + + if drivername == "postgresql": + return settings.DEFAULT_POSTGRES_DRIVER + + if drivername == "mssql": + return settings.DEFAULT_MSSQL_DRIVER + + # Add other explicit driver qualification as needed ... + + return drivername + + def setting_database(logger: Logger): logger.info("Setting up database connection...") @@ -24,17 +41,10 @@ def setting_database(logger: Logger): exit() else: datasource_db_port = environ.get("DATASOURCE_DB_PORT") - datasource_db_drivername = environ.get( - "DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER + datasource_db_drivername = expand_short_drivers( + environ.get("DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER) ) - # expand postgres to a full default driver, so we can override sqlalchemy - if datasource_db_drivername == "postgresql": - datasource_db_drivername = settings.DEFAULT_POSTGRES_DRIVER - - if datasource_db_drivername == "mssql": - datasource_db_drivername = settings.DEFAULT_MSSQL_DRIVER - try: db_manager = SyncDBManager( username=environ.get("DATASOURCE_DB_USERNAME"), diff --git a/tests/test_demographics_distribution_query.py b/tests/test_demographics_distribution_query.py index a06f5a4..03da36b 100644 --- a/tests/test_demographics_distribution_query.py +++ b/tests/test_demographics_distribution_query.py @@ -6,6 +6,7 @@ from dotenv import load_dotenv import os import hutch_bunny.core.settings as settings +import hutch_bunny.core.setting_database as db_settings load_dotenv() @@ -23,7 +24,9 @@ def db_manager(): host=os.getenv("DATASOURCE_DB_HOST"), port=(int(datasource_db_port) if datasource_db_port is not None else None), database=os.getenv("DATASOURCE_DB_DATABASE"), - drivername=os.getenv("DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER), + drivername=db_settings.expand_short_drivers( + os.getenv("DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER) + ), schema=os.getenv("DATASOURCE_DB_SCHEMA"), ) diff --git a/tests/test_return.py b/tests/test_return.py index dbaa999..3594854 100644 --- a/tests/test_return.py +++ b/tests/test_return.py @@ -11,6 +11,7 @@ from dotenv import load_dotenv import os import hutch_bunny.core.settings as settings +import hutch_bunny.core.setting_database as db_settings load_dotenv() @@ -28,7 +29,9 @@ def db_manager(): host=os.getenv("DATASOURCE_DB_HOST"), port=(int(datasource_db_port) if datasource_db_port is not None else None), database=os.getenv("DATASOURCE_DB_DATABASE"), - drivername=os.getenv("DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER), + drivername=db_settings.expand_short_drivers( + os.getenv("DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER) + ), schema=os.getenv("DATASOURCE_DB_SCHEMA"), ) From b97525552e397177cd994c016956a9ca1d36e83b Mon Sep 17 00:00:00 2001 From: Andy Rae Date: Fri, 24 Jan 2025 10:10:49 +0000 Subject: [PATCH 4/4] Test obfuscation (#52) * Rename apply_filters to remove v2 * Delete unused get modifiers * Add obfuscation tests * Extract result_modifiers code from obfuscation * Remove union types * Add further tests --------- Co-authored-by: Jon Couldridge --- src/hutch_bunny/cli.py | 2 +- src/hutch_bunny/core/execute_query.py | 4 +- src/hutch_bunny/core/obfuscation.py | 85 ++--------------------- src/hutch_bunny/core/results_modifiers.py | 24 +++++++ tests/__init__.py | 0 tests/test_obfuscation.py | 66 ++++++++++++++++++ 6 files changed, 100 insertions(+), 81 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_obfuscation.py diff --git a/src/hutch_bunny/cli.py b/src/hutch_bunny/cli.py index 44d21de..9ee9194 100644 --- a/src/hutch_bunny/cli.py +++ b/src/hutch_bunny/cli.py @@ -1,5 +1,5 @@ import json -from hutch_bunny.core.obfuscation import get_results_modifiers_from_str +from hutch_bunny.core.results_modifiers import get_results_modifiers_from_str from hutch_bunny.core.execute_query import execute_query from hutch_bunny.core.rquest_dto.result import RquestResult from hutch_bunny.core.parser import parser diff --git a/src/hutch_bunny/core/execute_query.py b/src/hutch_bunny/core/execute_query.py index 54f62fa..f98e20e 100644 --- a/src/hutch_bunny/core/execute_query.py +++ b/src/hutch_bunny/core/execute_query.py @@ -3,7 +3,7 @@ from hutch_bunny.core import query_solvers from hutch_bunny.core.rquest_dto.query import AvailabilityQuery, DistributionQuery from hutch_bunny.core.obfuscation import ( - apply_filters_v2, + apply_filters, ) from hutch_bunny.core.rquest_dto.result import RquestResult @@ -50,7 +50,7 @@ def execute_query( result = query_solvers.solve_availability( db_manager=db_manager, query=query ) - result.count = apply_filters_v2(result.count, results_modifiers) + result.count = apply_filters(result.count, results_modifiers) return result except TypeError as te: # raised if the distribution query json format is wrong logger.error(str(te), exc_info=True) diff --git a/src/hutch_bunny/core/obfuscation.py b/src/hutch_bunny/core/obfuscation.py index 5d2795f..c98a50c 100644 --- a/src/hutch_bunny/core/obfuscation.py +++ b/src/hutch_bunny/core/obfuscation.py @@ -1,58 +1,8 @@ -import json -import os -import requests -from typing import Union - - -def get_results_modifiers(activity_source_id: int) -> list: - """Get the results modifiers for a given activity source. - - Args: - activity_source_id (int): The acivity source ID. - - Returns: - list: The modifiers for the given activity source. - - Raises: - HTTPError: raised when this function can't get the results modifiers. - """ - res = requests.get( - f"{os.getenv('MANAGER_URL')}/api/activitysources/{activity_source_id}/resultsmodifiers", - verify=int(os.getenv("MANAGER_VERIFY_SSL", 1)), - ) - res.raise_for_status() - modifiers = res.json() - return modifiers - - -def get_results_modifiers_from_str(params: str) -> list: - """Deserialise a JSON list containing results modifiers - - Args: - params (str): - The JSON string containing list of parameter objects for results modifiers - - Raises: - ValueError: The parsed string does not produce a list - - Returns: - list: The list of parameter dicts of results modifiers - """ - deserialised_params = json.loads(params) - if not isinstance(deserialised_params, list): - raise ValueError( - f"{get_results_modifiers_from_str.__name__} requires a JSON list" - ) - return deserialised_params - - -def low_number_suppression( - value: Union[int, float], threshold: int = 10 -) -> Union[int, float]: +def low_number_suppression(value: int | float, threshold: int = 10) -> int | float: """Suppress values that fall below a given threshold. Args: - value (Union[int, float]): The value to evaluate. + value (int | float): The value to evaluate. threshold (int): The threshold to beat. Returns: @@ -67,11 +17,11 @@ def low_number_suppression( return value if value > threshold else 0 -def rounding(value: Union[int, float], nearest: int = 10) -> int: +def rounding(value: int | float, nearest: int = 10) -> int: """Round the value to the nearest base number, e.g. 10. Args: - value (Union[int, float]): The value to be rounded + value (int | float): The value to be rounded nearest (int, optional): Round value to this base. Defaults to 10. Returns: @@ -86,36 +36,15 @@ def rounding(value: Union[int, float], nearest: int = 10) -> int: return nearest * round(value / nearest) -def apply_filters(value: Union[int, float], filters: list) -> Union[int, float]: - """Iterate over a list of filters from the Manager and apply them to the - supplied value. - - Args: - value (Union[int, float]): The value to be filtered. - filters (list): The filters applied to the value. - - Returns: - Union[int, float]: The filtered value. - """ - actions = {"Low Number Suppression": low_number_suppression, "Rounding": rounding} - result = value - for f in filters: - if action := actions.get(f["type"]["id"]): - result = action(result, **f["parameters"]) - if result == 0: - break # don't apply any more filters - return result - - -def apply_filters_v2(value: Union[int, float], filters: list) -> Union[int, float]: +def apply_filters(value: int | float, filters: list) -> int | float: """Iterate over a list of filters and apply them to the supplied value. Args: - value (Union[int, float]): The value to be filtered. + value (int | float): The value to be filtered. filters (list): The filters applied to the value. Returns: - Union[int, float]: The filtered value. + int | float: The filtered value. """ actions = {"Low Number Suppression": low_number_suppression, "Rounding": rounding} result = value diff --git a/src/hutch_bunny/core/results_modifiers.py b/src/hutch_bunny/core/results_modifiers.py index 76b6e17..52bac71 100644 --- a/src/hutch_bunny/core/results_modifiers.py +++ b/src/hutch_bunny/core/results_modifiers.py @@ -1,3 +1,6 @@ +import json + + def results_modifiers( low_number_suppression_threshold: int, rounding_target: int, @@ -18,3 +21,24 @@ def results_modifiers( } ) return results_modifiers + + +def get_results_modifiers_from_str(params: str) -> list: + """Deserialise a JSON list containing results modifiers + + Args: + params (str): + The JSON string containing list of parameter objects for results modifiers + + Raises: + ValueError: The parsed string does not produce a list + + Returns: + list: The list of parameter dicts of results modifiers + """ + deserialised_params = json.loads(params) + if not isinstance(deserialised_params, list): + raise ValueError( + f"{get_results_modifiers_from_str.__name__} requires a JSON list" + ) + return deserialised_params diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_obfuscation.py b/tests/test_obfuscation.py new file mode 100644 index 0000000..edd72b2 --- /dev/null +++ b/tests/test_obfuscation.py @@ -0,0 +1,66 @@ +from hutch_bunny.core.obfuscation import ( + apply_filters, + low_number_suppression, + rounding, +) + + +def test_low_number_suppression(): + # Test that the threshold is applied + assert low_number_suppression(99, threshold=100) == 0 + assert low_number_suppression(100, threshold=100) == 0 + assert low_number_suppression(101, threshold=100) == 101 + + # Test that the threshold can be set to 0 + assert low_number_suppression(1, threshold=0) == 1 + + # Test negative threshold is ignored + assert low_number_suppression(1, threshold=-5) == 1 + + +def test_rounding(): + # Test default nearest + assert rounding(9) == 10 + + # Test rounding is applied + assert rounding(123, nearest=100) == 100 + assert rounding(123, nearest=10) == 120 + assert rounding(123, nearest=1) == 123 + + # Test rounding is applied the boundary + assert rounding(150, nearest=100) == 200 + + +def test_apply_filters_rounding(): + # Test rounding only + filters = [{"id": "Rounding", "nearest": 100}] + assert apply_filters(123, filters=filters) == 100 + + +def test_apply_filters_low_number_suppression(): + # Test low number suppression only + filters = [{"id": "Low Number Suppression", "threshold": 100}] + assert apply_filters(123, filters=filters) == 123 + + +def test_apply_filters_combined(): + # Test both filters + filters = [ + {"id": "Low Number Suppression", "threshold": 100}, + {"id": "Rounding", "nearest": 100}, + ] + assert apply_filters(123, filters=filters) == 100 + + +def test_apply_filters_combined_leak(): + # Test that putting the rounding filter first can leak the low number suppression filter + filters = [ + {"id": "Rounding", "nearest": 100}, + {"id": "Low Number Suppression", "threshold": 70}, + ] + assert apply_filters(60, filters=filters) == 100 + + +def test_apply_filters_combined_empty_filter(): + # Test that an empty filter list returns the original value + assert apply_filters(9, []) == 9