diff --git a/rectools/dataset/dataset.py b/rectools/dataset/dataset.py index 9d50e66a..afdb8a67 100644 --- a/rectools/dataset/dataset.py +++ b/rectools/dataset/dataset.py @@ -35,20 +35,22 @@ def _serialize_feature_name(spec: tp.Any) -> Hashable: - error_msg = f""" - Serialization for feature name '{spec}' is not supported. - Please convert your feature names and category feature values to strings, numbers, booleans - or their tuples. - """ + type_error = TypeError( + f""" + Serialization for feature name '{spec}' is not supported. + Please convert your feature names and category feature values to strings, numbers, booleans + or their tuples. + """ + ) if isinstance(spec, (list, np.ndarray)): - raise ValueError(error_msg) + raise type_error if isinstance(spec, tuple): return tuple(_serialize_feature_name(item) for item in spec) if isinstance(spec, (int, float, str, bool)): return spec - if np.issubdtype(spec, np.number) or np.issubdtype(spec, np.bool_): + if np.issubdtype(spec, np.number) or np.issubdtype(spec, np.bool_): # str is handled by isinstance(spec, str) return spec.item() - raise ValueError(error_msg) + raise type_error FeatureName = tpe.Annotated[AnyFeatureName, PlainSerializer(_serialize_feature_name, when_used="json")] diff --git a/tests/dataset/test_dataset.py b/tests/dataset/test_dataset.py index 528435a1..d1b3b421 100644 --- a/tests/dataset/test_dataset.py +++ b/tests/dataset/test_dataset.py @@ -520,5 +520,5 @@ def test_basic(self, feature_name: AnyFeatureName, expected: Hashable) -> None: @pytest.mark.parametrize("feature_name", (np.array([1]), [1], np.array(["name"]), np.array([True]))) def test_raises_on_incorrect_input(self, feature_name: tp.Any) -> None: - with pytest.raises(ValueError): + with pytest.raises(TypeError): _serialize_feature_name(feature_name)