Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test obfuscation #52

Merged
merged 7 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/hutch_bunny/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from hutch_bunny.core.obfuscation import get_results_modifiers_from_str
from hutch_bunny.core.results_modifiers import get_results_modifiers_from_str
from hutch_bunny.core.execute_query import execute_query
from hutch_bunny.core.rquest_dto.result import RquestResult
from hutch_bunny.core.parser import parser
Expand Down
4 changes: 2 additions & 2 deletions src/hutch_bunny/core/execute_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from hutch_bunny.core import query_solvers
from hutch_bunny.core.rquest_dto.query import AvailabilityQuery, DistributionQuery
from hutch_bunny.core.obfuscation import (
apply_filters_v2,
apply_filters,
)
from hutch_bunny.core.rquest_dto.result import RquestResult

Expand Down Expand Up @@ -50,7 +50,7 @@ def execute_query(
result = query_solvers.solve_availability(
db_manager=db_manager, query=query
)
result.count = apply_filters_v2(result.count, results_modifiers)
result.count = apply_filters(result.count, results_modifiers)
return result
except TypeError as te: # raised if the distribution query json format is wrong
logger.error(str(te), exc_info=True)
Expand Down
85 changes: 7 additions & 78 deletions src/hutch_bunny/core/obfuscation.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,8 @@
import json
import os
import requests
from typing import Union


def get_results_modifiers(activity_source_id: int) -> list:
    """Get the results modifiers for a given activity source.

    The Manager base URL is read from the ``MANAGER_URL`` environment
    variable and TLS verification is controlled by ``MANAGER_VERIFY_SSL``
    (an integer flag, defaulting to 1).

    Args:
        activity_source_id (int): The activity source ID.

    Returns:
        list: The parsed JSON body of the response (expected to be the list
            of modifiers for the given activity source).

    Raises:
        HTTPError: raised when this function can't get the results modifiers.
    """
    # `requests` documents `verify` as either a bool or a path to a CA bundle.
    # A bare int such as 1 is neither, so turn the 0/1 flag into a real bool.
    verify_ssl = bool(int(os.getenv("MANAGER_VERIFY_SSL", 1)))
    res = requests.get(
        f"{os.getenv('MANAGER_URL')}/api/activitysources/{activity_source_id}/resultsmodifiers",
        verify=verify_ssl,
    )
    # Turn 4xx/5xx responses into an HTTPError instead of parsing an error body.
    res.raise_for_status()
    return res.json()


def get_results_modifiers_from_str(params: str) -> list:
    """Parse a JSON string that must encode a list of results modifiers.

    Args:
        params (str):
            JSON text holding a list of parameter objects for results modifiers.

    Raises:
        ValueError: The JSON is valid but its top-level value is not a list.

    Returns:
        list: The decoded list of results-modifier parameter dicts.
    """
    parsed = json.loads(params)
    # Guard clause: only a top-level JSON array is acceptable.
    if isinstance(parsed, list):
        return parsed
    raise ValueError(
        f"{get_results_modifiers_from_str.__name__} requires a JSON list"
    )


def low_number_suppression(
value: Union[int, float], threshold: int = 10
) -> Union[int, float]:
def low_number_suppression(value: int | float, threshold: int = 10) -> int | float:
"""Suppress values that fall below a given threshold.

Args:
value (Union[int, float]): The value to evaluate.
value (int | float): The value to evaluate.
threshold (int): The threshold to beat.

Returns:
Expand All @@ -67,11 +17,11 @@ def low_number_suppression(
return value if value > threshold else 0


def rounding(value: Union[int, float], nearest: int = 10) -> int:
def rounding(value: int | float, nearest: int = 10) -> int:
"""Round the value to the nearest base number, e.g. 10.

Args:
value (Union[int, float]): The value to be rounded
value (int | float): The value to be rounded
nearest (int, optional): Round value to this base. Defaults to 10.

Returns:
Expand All @@ -86,36 +36,15 @@ def rounding(value: Union[int, float], nearest: int = 10) -> int:
return nearest * round(value / nearest)


def apply_filters(value: Union[int, float], filters: list) -> Union[int, float]:
    """Run a value through a sequence of Manager-supplied filters.

    Each filter is looked up by ``filter["type"]["id"]`` and, when recognised,
    called with the running value plus ``filter["parameters"]`` as keyword
    arguments. Unrecognised filter ids are skipped.

    Args:
        value (Union[int, float]): The value to be filtered.
        filters (list): The filters applied to the value, in order.

    Returns:
        Union[int, float]: The filtered value.
    """
    handlers = {
        "Low Number Suppression": low_number_suppression,
        "Rounding": rounding,
    }
    current = value
    for spec in filters:
        handler = handlers.get(spec["type"]["id"])
        if not handler:
            continue
        current = handler(current, **spec["parameters"])
        # Once the value has been reduced to 0, later filters are not applied.
        if current == 0:
            break
    return current


def apply_filters_v2(value: Union[int, float], filters: list) -> Union[int, float]:
def apply_filters(value: int | float, filters: list) -> int | float:
"""Iterate over a list of filters and apply them to the supplied value.

Args:
value (Union[int, float]): The value to be filtered.
value (int | float): The value to be filtered.
filters (list): The filters applied to the value.

Returns:
Union[int, float]: The filtered value.
int | float: The filtered value.
"""
actions = {"Low Number Suppression": low_number_suppression, "Rounding": rounding}
result = value
Expand Down
24 changes: 24 additions & 0 deletions src/hutch_bunny/core/results_modifiers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import json


def results_modifiers(
low_number_suppression_threshold: int,
rounding_target: int,
Expand All @@ -18,3 +21,24 @@ def results_modifiers(
}
)
return results_modifiers


def get_results_modifiers_from_str(params: str) -> list:
    """Decode a JSON-encoded list of results modifiers.

    Args:
        params (str):
            The JSON string containing a list of parameter objects for
            results modifiers.

    Raises:
        ValueError: The decoded top-level JSON value is not a list.

    Returns:
        list: The list of results-modifier parameter dicts.
    """
    modifiers = json.loads(params)
    is_list = isinstance(modifiers, list)
    if is_list:
        return modifiers
    # Anything other than a top-level JSON array is rejected.
    raise ValueError(
        f"{get_results_modifiers_from_str.__name__} requires a JSON list"
    )
Empty file added tests/__init__.py
Empty file.
66 changes: 66 additions & 0 deletions tests/test_obfuscation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from hutch_bunny.core.obfuscation import (
apply_filters,
low_number_suppression,
rounding,
)


def test_low_number_suppression():
    """Check suppression around the threshold and with zero/negative thresholds."""
    # Each case is (value, threshold, expected result).
    cases = [
        (99, 100, 0),  # below the threshold
        (100, 100, 0),  # exactly on the threshold
        (101, 100, 101),  # above the threshold is passed through
        (1, 0, 1),  # the threshold can be set to 0
        (1, -5, 1),  # a negative threshold leaves the value untouched
    ]
    for value, threshold, expected in cases:
        assert low_number_suppression(value, threshold=threshold) == expected


def test_rounding():
    """Check rounding with the default base, explicit bases and a midpoint."""
    # Default `nearest` is used when none is supplied.
    assert rounding(9) == 10

    # The same value rounded to several bases: {nearest: expected}.
    expected_by_base = {100: 100, 10: 120, 1: 123}
    for base, expected in expected_by_base.items():
        assert rounding(123, nearest=base) == expected

    # Rounding is applied at the boundary (midpoint between two multiples).
    assert rounding(150, nearest=100) == 200


def test_apply_filters_rounding():
    """A lone rounding filter with nearest=100 turns 123 into 100."""
    rounding_only = [{"id": "Rounding", "nearest": 100}]
    filtered = apply_filters(123, filters=rounding_only)
    assert filtered == 100


def test_apply_filters_low_number_suppression():
    """A lone suppression filter leaves a value above its threshold unchanged."""
    suppression_only = [{"id": "Low Number Suppression", "threshold": 100}]
    filtered = apply_filters(123, filters=suppression_only)
    assert filtered == 123


def test_apply_filters_combined():
    """Suppression followed by rounding: 123 survives suppression, then rounds to 100."""
    suppression = {"id": "Low Number Suppression", "threshold": 100}
    round_to_hundred = {"id": "Rounding", "nearest": 100}
    filtered = apply_filters(123, filters=[suppression, round_to_hundred])
    assert filtered == 100


def test_apply_filters_combined_leak():
    """Rounding before suppression can leak a value the suppression would have hidden.

    60 is below the suppression threshold of 70, but it is rounded up to 100
    first, so the later suppression filter lets it through.
    """
    round_to_hundred = {"id": "Rounding", "nearest": 100}
    suppression = {"id": "Low Number Suppression", "threshold": 70}
    filtered = apply_filters(60, filters=[round_to_hundred, suppression])
    assert filtered == 100


def test_apply_filters_combined_empty_filter():
    """With no filters to apply, the original value comes back unchanged."""
    original = 9
    filtered = apply_filters(9, [])
    assert filtered == original
Loading