From 6c870d48e1a130a1638dea898c7db4ce12ffbde2 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 2 May 2024 10:19:34 +0200 Subject: [PATCH] fix: ensure NCMAPSS scaling range is tuple (#61) * fix: ensure scaling range is a tuple * fix: scaling range type hint --- rul_datasets/reader/ncmapss.py | 4 ++-- tests/reader/test_ncmapss.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/rul_datasets/reader/ncmapss.py b/rul_datasets/reader/ncmapss.py index 43db771..67a6915 100644 --- a/rul_datasets/reader/ncmapss.py +++ b/rul_datasets/reader/ncmapss.py @@ -123,7 +123,7 @@ def __init__( truncate_degraded_only: bool = False, resolution_seconds: int = 1, padding_value: float = 0.0, - scaling_range: Optional[Tuple[int, int]] = (0, 1), + scaling_range: Tuple[int, int] = (0, 1), ) -> None: """ Create a new reader for the New C-MAPSS dataset. The maximum RUL value is set @@ -173,7 +173,7 @@ def __init__( self.run_split_dist = run_split_dist or self._get_default_split(self.fd) self.resolution_seconds = resolution_seconds self.padding_value = padding_value - self.scaling_range = scaling_range + self.scaling_range = tuple(scaling_range) if self.resolution_seconds > 1 and window_size is None: warnings.warn( diff --git a/tests/reader/test_ncmapss.py b/tests/reader/test_ncmapss.py index 57d9307..fc4477c 100644 --- a/tests/reader/test_ncmapss.py +++ b/tests/reader/test_ncmapss.py @@ -158,3 +158,11 @@ def test_feature_select(prepared_ncmapss): features, _ = reader.load_complete_split("dev", "dev") assert features[0].shape[2] == 10 + + +@pytest.mark.parametrize("scaling_range", [(0, 1), [0, 1]]) +def test_scaling_range_is_tuple(scaling_range): + reader = NCmapssReader(1, scaling_range=scaling_range) + + assert isinstance(reader.scaling_range, tuple) + assert reader.scaling_range == (0, 1)