Skip to content

Commit

Permalink
Documenting steps to help with #31 (#44)
Browse files Browse the repository at this point in the history
* Adding my understanding of process and functions via comments

* Documented for the availability query, some todos and refactors.

Code can be heavily optimised.

* Comments added for generic distro
  • Loading branch information
prquinlan authored Jan 23, 2025
1 parent 666caea commit 44a2534
Showing 1 changed file with 84 additions and 5 deletions.
89 changes: 84 additions & 5 deletions src/hutch_bunny/core/query_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import hutch_bunny.core.settings as settings
from hutch_bunny.core.constants import DISTRIBUTION_TYPE_FILE_NAMES_MAP


# Class for availability queries
class AvailibilityQuerySolver:
subqueries = list()
concept_table_map = {
Expand Down Expand Up @@ -67,11 +67,25 @@ def __init__(self, db_manager: SyncDBManager, query: AvailabilityQuery) -> None:
self.db_manager = db_manager
self.query = query

""" Function that takes all the concept IDs in the cohort definition and looks them up in the OMOP database
to extract the concept_id and domain, placing these in a dictionary for lookup during other query building.
Although the query payload will tell you where the OMOP concept is from (based on the RQUEST OMOP version), this is
a safer method, as we know concepts can move between tables depending on the vocab.
Therefore this helps to account for a difference between the Bunny vocab version and the RQUEST OMOP version.
#TODO: this does not cover the possible scenario where the local vocab model says the concept
should be based in one table but it is actually present in another
"""
def _find_concepts(self) -> dict:
concept_ids = set()
for group in self.query.cohort.groups:
for rule in group.rules:
concept_ids.add(int(rule.value))

concept_query = (
# order must be .concept_id, .domain_id
select(Concept.concept_id, Concept.domain_id)
Expand All @@ -86,15 +100,39 @@ def _find_concepts(self) -> dict:
}
return concept_dict

""" Function for taking the JSON query from RQUEST and creating the required query to run against the OMOP database.
The RQUEST API spec can have multiple groups in each query, and then a condition between the groups.
Each group can have conditional logic (AND/OR) within the group.
Each concept can be either an inclusion or an exclusion criterion.
Each concept can have an age range set, meaning that the event with concept X occurred when
the person was within a certain age range. - #TODO - not sure this is implemented here
"""
def _solve_rules(self) -> None:
"""Find all rows that match the rules' criteria."""

#get the list of concepts to build the query constraints
concepts = self._find_concepts()

# This is related to the logic within a group. This is used in the subsequent for loop to determine how
# the merge should be applied.
merge_method = lambda x: "inner" if x == "AND" else "outer"

# iterate through all the groups specified in the query
for group in self.query.cohort.groups:

# todo - refactor the variable name concept as it is misleading. It is not the concept but actually the domain of the concept:
# this passes in the concept ID but gets back the domain related to that concept.
concept = concepts.get(group.rules[0].value)

concept_table = self.concept_table_map.get(concept)
boolean_rule_col = self.boolean_rule_map.get(concept)
numeric_rule_col = self.numeric_rule_map.get(concept)

# within the query, check whether a range was specified for this rule (both min_value and max_value are set)
if (
group.rules[0].min_value is not None
and group.rules[0].max_value is not None
Expand All @@ -114,6 +152,10 @@ def _solve_rules(self) -> None:
main_df = pd.read_sql_query(
sql=stmnt, con=self.db_manager.engine.connect()
)

# the next two ifs are basically switching between equals and not equals. These could be merged with a simple
# switch for the operator.

elif group.rules[0].operator == "=":
stmnt = (
select(concept_table.person_id)
Expand All @@ -132,11 +174,25 @@ def _solve_rules(self) -> None:
main_df = pd.read_sql_query(
sql=stmnt, con=self.db_manager.engine.connect()
)

"""
Now that the main_df dataframe has been populated, the subsequent queries are created and merged into
the main_df dataframe. That is why the first concept above is hard coded as accessing index 0 and why the for
loop below starts at index 1. The queries are almost identical to the above, with exactly the same logic, but
in order to facilitate the merging, a label is created on person id, so that the newly created data frame
can be merged with main_df via unique keys.
"""

for i, rule in enumerate(group.rules[1:], start=1):

# todo - refactor the variable name concept as it is misleading. It is not the concept but actually the domain of the concept:
# this passes in the concept ID but gets back the domain related to that concept.
concept = concepts.get(rule.value)

concept_table = self.concept_table_map.get(concept)
boolean_rule_col = self.boolean_rule_map.get(concept)
numeric_rule_col = self.numeric_rule_map.get(concept)

if rule.min_value is not None and rule.max_value is not None:
# numeric rule
stmnt = (
Expand Down Expand Up @@ -192,14 +248,25 @@ def _solve_rules(self) -> None:
left_on="person_id",
right_on=f"person_id_{i}",
)
# subqueries therefore contain the results for each group within the cohort definition.
self.subqueries.append(main_df)

"""
This is the start of the process that begins to run the queries.
(1) call _solve_rules, which takes each group and adds its results to the subqueries list
(2) this function then iterates through the list of groups to resolve the logic (AND/OR) between groups
"""
def solve_query(self) -> int:
"""Merge the groups and return the number of rows that matched all criteria."""
#resolve within the group
self._solve_rules()

merge_method = lambda x: "inner" if x == "AND" else "outer"

# seed the merged dataframe with the results of the first group
group0_df = self.subqueries[0]
group0_df.rename({"person_id": "person_id_0"}, inplace=True, axis=1)

#for the next, rename columns to give a unique key, then merge based on the merge_method value
for i, df in enumerate(self.subqueries[1:], start=1):
df.rename({"person_id": f"person_id_{i}"}, inplace=True, axis=1)
group0_df = group0_df.merge(
Expand All @@ -216,8 +283,9 @@ class BaseDistributionQuerySolver:
def solve_query(self) -> Tuple[str, int]:
raise NotImplementedError


# class for distribution queries
class CodeDistributionQuerySolver(BaseDistributionQuerySolver):
#todo - can the following be placed somewhere once, as it's repeated for all classes handling queries
allowed_domains_map = {
"Condition": ConditionOccurrence,
"Ethnicity": Person,
Expand All @@ -238,6 +306,8 @@ class CodeDistributionQuerySolver(BaseDistributionQuerySolver):
"Observation": Observation.observation_concept_id,
"Procedure": ProcedureOccurrence.procedure_concept_id,
}

# this one is unique for this resolver
output_cols = [
"BIOBANK",
"CODE",
Expand Down Expand Up @@ -275,15 +345,23 @@ def solve_query(self) -> Tuple[str, int]:
concepts = list()
categories = list()
biobanks = list()

#todo - rename k, as this is a domain id that is being used
for k in self.allowed_domains_map:

# get the right table and column based on the domain
table = self.allowed_domains_map[k]
concept_col = self.domain_concept_id_map[k]

# gets a list of all concepts within this given table and their respective counts
stmnt = select(func.count(table.person_id), concept_col).group_by(
concept_col
)
res = pd.read_sql(stmnt, self.db_manager.engine.connect())
counts.extend(res.iloc[:, 0])
concepts.extend(res.iloc[:, 1])

# repeat the same category and collection once for each result row received
categories.extend([k] * len(res))
biobanks.extend([self.query.collection] * len(res))

Expand All @@ -294,6 +372,7 @@ def solve_query(self) -> Tuple[str, int]:
df["BIOBANK"] = biobanks

# Get descriptions
#todo - not sure why this cannot be included in the SQL output above; it would need a join to the concept table
concept_query = select(Concept.concept_id, Concept.concept_name).where(
Concept.concept_id.in_(concepts)
)
Expand All @@ -310,7 +389,7 @@ def solve_query(self) -> Tuple[str, int]:

return os.linesep.join(results), len(df)


#todo - I *think* the only difference between this one and generic is that the allowed_domain list is different. Could we not just have the one class and functions that have this passed in?
class DemographicsDistributionQuerySolver(BaseDistributionQuerySolver):
allowed_domains_map = {
"Gender": Person,
Expand Down

0 comments on commit 44a2534

Please sign in to comment.