diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index ff60217f32..86cb08ec93 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -687,25 +687,48 @@ def _get_unique_entities( entity_name_to_join_key_map, join_key_values, ) + # Validate that all expected join keys exist and have non-empty values. + expected_keys = set(entity_name_to_join_key_map.values()) + expected_keys.discard("__dummy_id") + missing_keys = sorted( + list(set([key for key in expected_keys if key not in table_entity_values])) + ) + empty_keys = sorted( + list(set([key for key in expected_keys if not table_entity_values.get(key)])) + ) - # Convert back to rowise. - keys = table_entity_values.keys() - # Sort the rowise data to allow for grouping but keep original index. This lambda is - # sufficient as Entity types cannot be complex (ie. lists). + if missing_keys or empty_keys: + if not any(table_entity_values.values()): + raise KeyError( + f"Missing join key values for keys: {missing_keys}. " + f"No values provided for keys: {empty_keys}. " + f"Provided join_key_values: {list(join_key_values.keys())}" + ) + + # Convert the column-oriented table_entity_values into row-wise data. + keys = list(table_entity_values.keys()) + # Each row is a tuple of ValueProto objects corresponding to the join keys. rowise = list(enumerate(zip(*table_entity_values.values()))) + + # If there are no rows, return empty tuples. + if not rowise: + return (), () + + # Sort rowise so that rows with the same join key values are adjacent. rowise.sort(key=lambda row: tuple(getattr(x, x.WhichOneof("val")) for x in row[1])) - # Identify unique entities and the indexes at which they occur. - unique_entities: Tuple[Dict[str, ValueProto], ...] - indexes: Tuple[List[int], ...] - unique_entities, indexes = tuple( - zip( - *[ - (dict(zip(keys, k)), [_[0] for _ in g]) - for k, g in itertools.groupby(rowise, key=lambda x: x[1]) - ] - ) - ) + # Group rows by their composite join key value. + groups = [ + (dict(zip(keys, key_tuple)), [idx for idx, _ in group]) + for key_tuple, group in itertools.groupby(rowise, key=lambda row: row[1]) + ] + + # If no groups were formed (should not happen for valid input), return empty tuples. + if not groups: + return (), () + + # Unpack the unique entities and their original row indexes. + unique_entities, indexes = tuple(zip(*groups)) return unique_entities, indexes diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index 6b0adb6263..9d6b7f3d17 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -137,6 +137,21 @@ def test_get_online_features() -> None: assert "trips" in result + with pytest.raises(KeyError) as excinfo: + _ = store.get_online_features( + features=["driver_locations:lon"], + entity_rows=[{"customer_id": 0}], + full_feature_names=False, + ).to_dict() + + error_message = str(excinfo.value) + assert "Missing join key values for keys:" in error_message + assert ( + "Missing join key values for keys: ['customer_id', 'driver_id', 'item_id']." + in error_message + ) + assert "Provided join_key_values: ['customer_id']" in error_message + result = store.get_online_features( features=["customer_profile_pandas_odfv:on_demand_age"], entity_rows=[{"driver_id": 1, "customer_id": "5"}], diff --git a/sdk/python/tests/unit/test_unit_feature_store.py b/sdk/python/tests/unit/test_unit_feature_store.py index 19a133564f..8d7b32760a 100644 --- a/sdk/python/tests/unit/test_unit_feature_store.py +++ b/sdk/python/tests/unit/test_unit_feature_store.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from typing import Dict, List +import pytest + from feast import utils from feast.protos.feast.types.Value_pb2 import Value @@ -17,7 +19,7 @@ class MockFeatureView: projection: MockFeatureViewProjection -def test_get_unique_entities(): +def test_get_unique_entities_success(): entity_values = { "entity_1": [Value(int64_val=1), Value(int64_val=2), Value(int64_val=1)], "entity_2": [ @@ -41,9 +43,87 @@ def test_get_unique_entities(): join_key_values=entity_values, entity_name_to_join_key_map=entity_name_to_join_key_map, ) - - assert unique_entities == ( + expected_entities = ( {"entity_1": Value(int64_val=1), "entity_2": Value(string_val="1")}, {"entity_1": Value(int64_val=2), "entity_2": Value(string_val="2")}, ) - assert indexes == ([0, 2], [1]) + expected_indexes = ([0, 2], [1]) + + assert unique_entities == expected_entities + assert indexes == expected_indexes + + +def test_get_unique_entities_missing_join_key_success(): + """ + Tests that _get_unique_entities raises a KeyError when a required join key is missing. + """ + # Here, we omit the required key for "entity_1" + entity_values = { + "entity_2": [ + Value(string_val="1"), + Value(string_val="2"), + Value(string_val="1"), + ], + } + + entity_name_to_join_key_map = {"entity_1": "entity_1", "entity_2": "entity_2"} + + fv = MockFeatureView( + name="fv_1", + entities=["entity_1", "entity_2"], + projection=MockFeatureViewProjection(join_key_map={}), + ) + + unique_entities, indexes = utils._get_unique_entities( + table=fv, + join_key_values=entity_values, + entity_name_to_join_key_map=entity_name_to_join_key_map, + ) + expected_entities = ( + {"entity_2": Value(string_val="1")}, + {"entity_2": Value(string_val="2")}, + ) + expected_indexes = ([0, 2], [1]) + + assert unique_entities == expected_entities + assert indexes == expected_indexes + # We're not say anything about the entity_1 missing from the unique_entities list + assert "entity_1" not in [entity.keys() for entity in unique_entities] + + +def test_get_unique_entities_missing_all_join_keys_error(): + """ + Tests that _get_unique_entities raises a KeyError when all required join keys are missing. + """ + entity_values_not_in_feature_view = { + "entity_3": [Value(string_val="3")], + } + entity_name_to_join_key_map = { + "entity_1": "entity_1", + "entity_2": "entity_2", + "entity_3": "entity_3", + } + + fv = MockFeatureView( + name="fv_1", + entities=["entity_1", "entity_2"], + projection=MockFeatureViewProjection(join_key_map={}), + ) + + with pytest.raises(KeyError) as excinfo: + utils._get_unique_entities( + table=fv, + join_key_values=entity_values_not_in_feature_view, + entity_name_to_join_key_map=entity_name_to_join_key_map, + ) + + error_message = str(excinfo.value) + assert ( + "Missing join key values for keys: ['entity_1', 'entity_2', 'entity_3']" + in error_message + ) + assert ( + "No values provided for keys: ['entity_1', 'entity_2', 'entity_3']" + in error_message + ) + assert "Provided join_key_values: ['entity_3']" in error_message