From 597d2e45df19ff77806989095d70b709142eddb2 Mon Sep 17 00:00:00 2001 From: Todd Gaugler Date: Wed, 12 Feb 2025 12:55:06 -0500 Subject: [PATCH] Adding SignalsDateRangeContext ... --- ccflow/context.py | 7 ++++++- ccflow/tests/test_context.py | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/ccflow/context.py b/ccflow/context.py index 0bafd59..efb0827 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -1,7 +1,7 @@ """This module defines re-usable contexts for the "Callable Model" framework defined in flow.callable.py.""" from datetime import date, datetime -from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar +from typing import Generic, Hashable, List, Optional, Sequence, Set, TypeVar from pydantic import field_validator, model_validator @@ -25,6 +25,7 @@ "FreqHorizonDateContext", "FreqHorizonDateRangeContext", "SeededDateRangeContext", + "SignalsDateRangeContext", "SourceContext", "UniverseContext", "UniverseDateContext", @@ -111,6 +112,10 @@ class SeededDateRangeContext(DateRangeContext): seed: int = 1234 +class SignalsDateRangeContext(DateRangeContext): + signals: List[str] + + class VersionedDateContext(DateContext, EntryTimeContext): pass diff --git a/ccflow/tests/test_context.py b/ccflow/tests/test_context.py index 3cac9a7..7507e97 100644 --- a/ccflow/tests/test_context.py +++ b/ccflow/tests/test_context.py @@ -17,6 +17,7 @@ ModelDateRangeSourceContext, ModelFreqDateRangeContext, NullContext, + SignalsDateRangeContext, UniverseContext, UniverseDateContext, UniverseDateRangeContext, @@ -75,6 +76,14 @@ def test_date_range(self): self.assertEqual(MyRangeModel(context=("-1d", "0d")).context, c) self.assertEqual(MyRangeModel(context=["-1d", "0d"]).context, c) + def test_signal_date_range(self): + d0 = date.today() - timedelta(1) + d1 = date.today() + c = SignalsDateRangeContext(start_date=d0, end_date=d1, signals=["a", "b"]) + self.assertEqual(SignalsDateRangeContext(start_date=str(d0), end_date=pd.Timestamp(date.today()), signals=["a", "b"]), c) + self.assertEqual(SignalsDateRangeContext(start_date="-1d", end_date="0d", signals=["a", "b"]), c) + self.assertRaises(ValueError, SignalsDateRangeContext, start_date=d0, end_date=d1, signals="foobar") + def test_freq(self): self.assertEqual( FreqDateContext.model_validate("5min,2022-01-01"),