Skip to content

Commit

Permalink
Merge branch 'main' into dist_queries
Browse files Browse the repository at this point in the history
  • Loading branch information
vpnu committed Jan 24, 2025
2 parents 8d42a8d + b975255 commit 92473b9
Show file tree
Hide file tree
Showing 14 changed files with 289 additions and 101 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"python.testing.pytestArgs": [
"test"
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
|-|-|-|
| ![Python][python-badge] | [![Bunny Docker Images][docker-badge]][bunny-containers] | [![Bunny Docs][docs-badge]][bunny-docs] |

An HDR UK Cohort Discovery Task Resolver.
A Cohort Discovery Task Resolver.

Fetches and resolves Availability and Distribution Queries against an OMOP-CDM database.
Fetches and resolves Availability and Distribution Queries against an OMOP CDM database.

[hutch-logo]: https://raw.githubusercontent.com/HDRUK/hutch/main/assets/Hutch%20splash%20bg.svg
[hutch-repo]: https://github.com/health-informatics-uon/hutch
Expand Down
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,10 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[dependency-groups]
dev = ["ruff>=0.8.6", "pytest>=8.3.4"]
dev = [
"ruff>=0.8.6",
"pytest>=8.3.4",
"pandas-stubs>=2.2.3.241126",
"mypy>=1.14.1",
"types-requests>=2.32.0.20241016",
]
2 changes: 1 addition & 1 deletion src/hutch_bunny/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from hutch_bunny.core.obfuscation import get_results_modifiers_from_str
from hutch_bunny.core.results_modifiers import get_results_modifiers_from_str
from hutch_bunny.core.execute_query import execute_query
from hutch_bunny.core.rquest_dto.result import RquestResult
from hutch_bunny.core.parser import parser
Expand Down
4 changes: 2 additions & 2 deletions src/hutch_bunny/core/execute_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from hutch_bunny.core import query_solvers
from hutch_bunny.core.rquest_dto.query import AvailabilityQuery, DistributionQuery
from hutch_bunny.core.obfuscation import (
apply_filters_v2,
apply_filters,
)
from hutch_bunny.core.rquest_dto.result import RquestResult

Expand Down Expand Up @@ -50,7 +50,7 @@ def execute_query(
result = query_solvers.solve_availability(
db_manager=db_manager, query=query
)
result.count = apply_filters_v2(result.count, results_modifiers)
result.count = apply_filters(result.count, results_modifiers)
return result
except TypeError as te: # raised if the distribution query json format is wrong
logger.error(str(te), exc_info=True)
Expand Down
85 changes: 7 additions & 78 deletions src/hutch_bunny/core/obfuscation.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,8 @@
import json
import os
import requests
from typing import Union


def get_results_modifiers(activity_source_id: int, timeout: float = 30.0) -> list:
    """Get the results modifiers for a given activity source.

    Args:
        activity_source_id (int): The activity source ID.
        timeout (float): Seconds to wait for the Manager to respond
            before raising. Defaults to 30.0.

    Returns:
        list: The modifiers for the given activity source.

    Raises:
        HTTPError: raised when this function can't get the results modifiers.
        requests.Timeout: raised when the Manager does not respond in time.
    """
    # MANAGER_VERIFY_SSL of "0" disables TLS verification; any other integer
    # string enables it. NOTE(review): requests also accepts a CA-bundle path
    # for `verify` — confirm the env var is only ever an integer flag.
    res = requests.get(
        f"{os.getenv('MANAGER_URL')}/api/activitysources/{activity_source_id}/resultsmodifiers",
        verify=int(os.getenv("MANAGER_VERIFY_SSL", 1)),
        # Without a timeout, requests can block forever on an unresponsive host.
        timeout=timeout,
    )
    res.raise_for_status()
    modifiers = res.json()
    return modifiers


def get_results_modifiers_from_str(params: str) -> list:
    """Parse a JSON string into a list of results-modifier parameter dicts.

    Args:
        params (str): JSON text expected to encode a list of parameter
            objects describing results modifiers.

    Raises:
        ValueError: If the decoded JSON is not a list.

    Returns:
        list: The decoded list of results-modifier parameter dicts.
    """
    decoded = json.loads(params)
    if isinstance(decoded, list):
        return decoded
    raise ValueError(
        f"{get_results_modifiers_from_str.__name__} requires a JSON list"
    )


def low_number_suppression(
value: Union[int, float], threshold: int = 10
) -> Union[int, float]:
def low_number_suppression(value: int | float, threshold: int = 10) -> int | float:
"""Suppress values that fall below a given threshold.
Args:
value (Union[int, float]): The value to evaluate.
value (int | float): The value to evaluate.
threshold (int): The threshold to beat.
Returns:
Expand All @@ -67,11 +17,11 @@ def low_number_suppression(
return value if value > threshold else 0


def rounding(value: Union[int, float], nearest: int = 10) -> int:
def rounding(value: int | float, nearest: int = 10) -> int:
"""Round the value to the nearest base number, e.g. 10.
Args:
value (Union[int, float]): The value to be rounded
value (int | float): The value to be rounded
nearest (int, optional): Round value to this base. Defaults to 10.
Returns:
Expand All @@ -86,36 +36,15 @@ def rounding(value: Union[int, float], nearest: int = 10) -> int:
return nearest * round(value / nearest)


def apply_filters(value: Union[int, float], filters: list) -> Union[int, float]:
    """Run every Manager-supplied filter over the value, in order.

    Each filter descriptor names an action via ``type.id`` and supplies
    keyword arguments via ``parameters``. Unrecognised actions are skipped.

    Args:
        value (Union[int, float]): The value to be filtered.
        filters (list): The filters applied to the value.

    Returns:
        Union[int, float]: The filtered value.
    """
    known_actions = {
        "Low Number Suppression": low_number_suppression,
        "Rounding": rounding,
    }
    filtered = value
    for spec in filters:
        action = known_actions.get(spec["type"]["id"])
        if action is not None:
            filtered = action(filtered, **spec["parameters"])
        if filtered == 0:
            # A zero result is final — skip any remaining filters.
            break
    return filtered


def apply_filters_v2(value: Union[int, float], filters: list) -> Union[int, float]:
def apply_filters(value: int | float, filters: list) -> int | float:
"""Iterate over a list of filters and apply them to the supplied value.
Args:
value (Union[int, float]): The value to be filtered.
value (int | float): The value to be filtered.
filters (list): The filters applied to the value.
Returns:
Union[int, float]: The filtered value.
int | float: The filtered value.
"""
actions = {"Low Number Suppression": low_number_suppression, "Rounding": rounding}
result = value
Expand Down
89 changes: 84 additions & 5 deletions src/hutch_bunny/core/query_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import hutch_bunny.core.settings as settings
from hutch_bunny.core.constants import DISTRIBUTION_TYPE_FILE_NAMES_MAP


# Class for availability queries
class AvailibilityQuerySolver:
subqueries = list()
concept_table_map = {
Expand Down Expand Up @@ -67,11 +67,25 @@ def __init__(self, db_manager: SyncDBManager, query: AvailabilityQuery) -> None:
self.db_manager = db_manager
self.query = query

""" Function that takes all the concept IDs in the cohort definition, looks them up in the OMOP database
to extract the concept_id and domain and place this within a dictionary for lookup during other query building.
Although the query payload will tell you where the OMOP concept is from (based on the RQUEST OMOP version), this is
a safer method as we know concepts can move between tables based on a vocab.
Therefore this helps to account for a difference between the Bunny vocab version and the RQUEST OMOP version.
#TODO: this does not cover the scenario that is possible to occur where the local vocab model may say the concept
should be based in one table but it is actually present in another
"""
def _find_concepts(self) -> dict:
concept_ids = set()
for group in self.query.cohort.groups:
for rule in group.rules:
concept_ids.add(int(rule.value))

concept_query = (
# order must be .concept_id, .domain_id
select(Concept.concept_id, Concept.domain_id)
Expand All @@ -86,15 +100,39 @@ def _find_concepts(self) -> dict:
}
return concept_dict

""" Function for taking the JSON query from RQUEST and creating the required query to run against the OMOP database.
RQUEST API spec can have multiple groups in each query, and then a condition between the groups.
Each group can have conditional logic AND/OR within the group
Each concept can either be an inclusion or exclusion criteria.
Each concept can have an age set, so it is that this event with concept X occurred when
the person was between a certain age. - #TODO - not sure this is implemented here
"""
def _solve_rules(self) -> None:
"""Find all rows that match the rules' criteria."""

#get the list of concepts to build the query constraints
concepts = self._find_concepts()

# This is related to the logic within a group. This is used in the subsequent for loop to determine how
# the merge should be applied.
merge_method = lambda x: "inner" if x == "AND" else "outer"

# iterate through all the groups specified in the query
for group in self.query.cohort.groups:

# todo - refactor variable name concept as this is misleading. It is not the concept but actually the domain of the concept
# this passes in the conceptID of but gets back the domain related to that concept.
concept = concepts.get(group.rules[0].value)

concept_table = self.concept_table_map.get(concept)
boolean_rule_col = self.boolean_rule_map.get(concept)
numeric_rule_col = self.numeric_rule_map.get(concept)

#within the query, if a range was specified, which is currently
if (
group.rules[0].min_value is not None
and group.rules[0].max_value is not None
Expand All @@ -114,6 +152,10 @@ def _solve_rules(self) -> None:
main_df = pd.read_sql_query(
sql=stmnt, con=self.db_manager.engine.connect()
)

# the next two ifs are basically switching between equals and not equals. These could be merged with a simple
# switch for the operator.

elif group.rules[0].operator == "=":
stmnt = (
select(concept_table.person_id)
Expand All @@ -132,11 +174,25 @@ def _solve_rules(self) -> None:
main_df = pd.read_sql_query(
sql=stmnt, con=self.db_manager.engine.connect()
)

"""
Now that the main_df dataframe has been populated, the subsequent queries are created and merged into
main_df dataframe. That is why above the first concept is hard coded as accessing index 0 and why the for
loop below if start at index 1. The queries are almost identical to the above, exact same logic but
in order to facilitate the merging, a label is created on person id, so that the newly created data frame
can be merged with main_df via unique keys.
"""

for i, rule in enumerate(group.rules[1:], start=1):

# todo - refactor variable name concept as this is misleading. It is not the concept but actually the domain of the concept
# this passes in the conceptID of but gets back the domain related to that concept.
concept = concepts.get(rule.value)

concept_table = self.concept_table_map.get(concept)
boolean_rule_col = self.boolean_rule_map.get(concept)
numeric_rule_col = self.numeric_rule_map.get(concept)

if rule.min_value is not None and rule.max_value is not None:
# numeric rule
stmnt = (
Expand Down Expand Up @@ -192,14 +248,25 @@ def _solve_rules(self) -> None:
left_on="person_id",
right_on=f"person_id_{i}",
)
# subqueries therefore contain the results for each group within the cohort definition.
self.subqueries.append(main_df)

"""
This is the start of the process that begins to run the queries.
(1) call solve_rules that takes each group and adds those results to the sub_queries list
(2) this function then iterates through the list of groups to resolve the logic (AND/OR) between groups
"""
def solve_query(self) -> int:
"""Merge the groups and return the number of rows that matched all criteria."""
#resolve within the group
self._solve_rules()

merge_method = lambda x: "inner" if x == "AND" else "outer"

#seed the dataframe with the first
group0_df = self.subqueries[0]
group0_df.rename({"person_id": "person_id_0"}, inplace=True, axis=1)

#for the next, rename columns to give a unique key, then merge based on the merge_method value
for i, df in enumerate(self.subqueries[1:], start=1):
df.rename({"person_id": f"person_id_{i}"}, inplace=True, axis=1)
group0_df = group0_df.merge(
Expand All @@ -216,8 +283,9 @@ class BaseDistributionQuerySolver:
def solve_query(self) -> Tuple[str, int]:
raise NotImplementedError


# Class for distribution queries
class CodeDistributionQuerySolver(BaseDistributionQuerySolver):
#todo - can the following be placed somewhere once as its repeated for all classes handling queries
allowed_domains_map = {
"Condition": ConditionOccurrence,
"Ethnicity": Person,
Expand All @@ -238,6 +306,8 @@ class CodeDistributionQuerySolver(BaseDistributionQuerySolver):
"Observation": Observation.observation_concept_id,
"Procedure": ProcedureOccurrence.procedure_concept_id,
}

# this one is unique for this resolver
output_cols = [
"BIOBANK",
"CODE",
Expand Down Expand Up @@ -275,15 +345,23 @@ def solve_query(self) -> Tuple[str, int]:
concepts = list()
categories = list()
biobanks = list()

#todo - rename k, as this is a domain id that is being used
for k in self.allowed_domains_map:

# get the right table and column based on the domain
table = self.allowed_domains_map[k]
concept_col = self.domain_concept_id_map[k]

# gets a list of all concepts within this given table and their respective counts
stmnt = select(func.count(table.person_id), concept_col).group_by(
concept_col
)
res = pd.read_sql(stmnt, self.db_manager.engine.connect())
counts.extend(res.iloc[:, 0])
concepts.extend(res.iloc[:, 1])

# add the same category and collection if, for the number of results received
categories.extend([k] * len(res))
biobanks.extend([self.query.collection] * len(res))

Expand All @@ -294,6 +372,7 @@ def solve_query(self) -> Tuple[str, int]:
df["BIOBANK"] = biobanks

# Get descriptions
#todo - not sure why this can be included in the SQL output above, it would need a join to the concept table
concept_query = select(Concept.concept_id, Concept.concept_name).where(
Concept.concept_id.in_(concepts)
)
Expand All @@ -311,7 +390,7 @@ def solve_query(self) -> Tuple[str, int]:

return os.linesep.join(results), len(df)


# TODO: I *think* the only difference between this one and the generic solver is that the allowed_domains map differs. Could we not have a single class with that map passed in?
class DemographicsDistributionQuerySolver(BaseDistributionQuerySolver):
allowed_domains_map = {
"Gender": Person,
Expand Down
Loading

0 comments on commit 92473b9

Please sign in to comment.