-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add daily_to_weekly function #478
Merged
dylanhmorris
merged 7 commits into
main
from
477-aggregate-daily-predicted-incident-hospitalization-to-epiweekly
Dec 30, 2024
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
4d5cb0a
add daily_to_epiweekly
sbidari 3fb2256
Apply suggestions from code review
sbidari 6f3f87f
make dow configurable for input data and weekly aggregated data
sbidari 00ed69a
pre-commit
sbidari 73b3128
Apply suggestions from code review
sbidari 2f8da43
code review suggestions
sbidari 57a60c9
Apply suggestions from code review
dylanhmorris File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# numpydoc ignore=GL08 | ||
|
||
import jax.numpy as jnp | ||
import pytest | ||
|
||
from pyrenew.convolve import daily_to_mmwr_epiweekly, daily_to_weekly | ||
|
||
|
||
def test_daily_to_weekly_no_offset(): | ||
""" | ||
Tests that the function correctly aggregates | ||
daily values into weekly totals when there | ||
is no offset both input and output start dow on Monday. | ||
""" | ||
daily_values = jnp.arange(1, 15) | ||
result = daily_to_weekly(daily_values) | ||
expected = jnp.array([28, 77]) | ||
assert jnp.array_equal(result, expected) | ||
|
||
|
||
def test_daily_to_weekly_with_input_data_offset(): | ||
""" | ||
Tests that the function correctly aggregates | ||
daily values into weekly totals with dow | ||
offset in the input data. | ||
""" | ||
daily_values = jnp.arange(1, 15) | ||
result = daily_to_weekly(daily_values, input_data_first_dow=2) | ||
expected = jnp.array([63]) | ||
assert jnp.array_equal(result, expected) | ||
|
||
|
||
def test_daily_to_weekly_with_different_week_start(): | ||
""" | ||
Tests aggregation when the desired week start | ||
differs from the input data start. | ||
""" | ||
daily_values = jnp.arange(1, 15) | ||
result = daily_to_weekly( | ||
daily_values, input_data_first_dow=2, week_start_dow=5 | ||
) | ||
expected = jnp.array([49]) | ||
assert jnp.array_equal(result, expected) | ||
|
||
|
||
def test_daily_to_weekly_incomplete_week(): | ||
""" | ||
Tests that the function raises a | ||
ValueError when there are | ||
insufficient daily values to | ||
form a complete week. | ||
""" | ||
daily_values = jnp.arange(1, 5) | ||
with pytest.raises( | ||
ValueError, match="No complete weekly values available" | ||
): | ||
daily_to_weekly(daily_values, input_data_first_dow=0) | ||
|
||
|
||
def test_daily_to_weekly_missing_daily_values(): | ||
""" | ||
Tests that the function correctly | ||
aggregates the available daily values | ||
into weekly values when there are | ||
fewer daily values than required for | ||
complete weekly totals in the final week. | ||
""" | ||
daily_values = jnp.arange(1, 10) | ||
result = daily_to_weekly(daily_values, input_data_first_dow=0) | ||
expected = jnp.array([28]) | ||
assert jnp.array_equal(result, expected) | ||
|
||
|
||
def test_daily_to_weekly_invalid_offset(): | ||
""" | ||
Tests that the function raises a | ||
ValueError when the offset is | ||
outside the valid range (0-6). | ||
""" | ||
daily_values = jnp.arange(1, 15) | ||
with pytest.raises( | ||
ValueError, | ||
match="First day of the week for input timeseries must be between 0 and 6.", | ||
): | ||
daily_to_weekly(daily_values, input_data_first_dow=-1) | ||
|
||
with pytest.raises( | ||
ValueError, | ||
match="Week start date for output aggregated values must be between 0 and 6.", | ||
): | ||
daily_to_weekly(daily_values, week_start_dow=7) | ||
|
||
|
||
def test_daily_to_mmwr_epiweekly(): | ||
""" | ||
Tests aggregation for MMWR epidemiological week. | ||
""" | ||
daily_values = jnp.arange(1, 15) | ||
result = daily_to_mmwr_epiweekly(daily_values) | ||
expected = jnp.array([70]) | ||
assert jnp.array_equal(result, expected) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the future we may want to enforce integer-like values more strongly or conversely expand to support arrays of different potentially, but I think this is fine for now.