Skip to content

Commit

Permalink
fix: Validate entities when running get_online_features (#5031)
Browse files Browse the repository at this point in the history
  • Loading branch information
franciscojavierarceo authored Feb 8, 2025
1 parent ec6f1b7 commit 3bb0dca
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 19 deletions.
53 changes: 38 additions & 15 deletions sdk/python/feast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,25 +687,48 @@ def _get_unique_entities(
entity_name_to_join_key_map,
join_key_values,
)
# Validate that all expected join keys exist and have non-empty values.
expected_keys = set(entity_name_to_join_key_map.values())
expected_keys.discard("__dummy_id")
missing_keys = sorted(
list(set([key for key in expected_keys if key not in table_entity_values]))
)
empty_keys = sorted(
list(set([key for key in expected_keys if not table_entity_values.get(key)]))
)

# Convert back to rowise.
keys = table_entity_values.keys()
# Sort the rowise data to allow for grouping but keep original index. This lambda is
# sufficient as Entity types cannot be complex (ie. lists).
if missing_keys or empty_keys:
if not any(table_entity_values.values()):
raise KeyError(
f"Missing join key values for keys: {missing_keys}. "
f"No values provided for keys: {empty_keys}. "
f"Provided join_key_values: {list(join_key_values.keys())}"
)

# Convert the column-oriented table_entity_values into row-wise data.
keys = list(table_entity_values.keys())
# Each row is a tuple of ValueProto objects corresponding to the join keys.
rowise = list(enumerate(zip(*table_entity_values.values())))

# If there are no rows, return empty tuples.
if not rowise:
return (), ()

# Sort rowise so that rows with the same join key values are adjacent.
rowise.sort(key=lambda row: tuple(getattr(x, x.WhichOneof("val")) for x in row[1]))

# Identify unique entities and the indexes at which they occur.
unique_entities: Tuple[Dict[str, ValueProto], ...]
indexes: Tuple[List[int], ...]
unique_entities, indexes = tuple(
zip(
*[
(dict(zip(keys, k)), [_[0] for _ in g])
for k, g in itertools.groupby(rowise, key=lambda x: x[1])
]
)
)
# Group rows by their composite join key value.
groups = [
(dict(zip(keys, key_tuple)), [idx for idx, _ in group])
for key_tuple, group in itertools.groupby(rowise, key=lambda row: row[1])
]

# If no groups were formed (should not happen for valid input), return empty tuples.
if not groups:
return (), ()

# Unpack the unique entities and their original row indexes.
unique_entities, indexes = tuple(zip(*groups))
return unique_entities, indexes


Expand Down
15 changes: 15 additions & 0 deletions sdk/python/tests/unit/online_store/test_online_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,21 @@ def test_get_online_features() -> None:

assert "trips" in result

with pytest.raises(KeyError) as excinfo:
_ = store.get_online_features(
features=["driver_locations:lon"],
entity_rows=[{"customer_id": 0}],
full_feature_names=False,
).to_dict()

error_message = str(excinfo.value)
assert "Missing join key values for keys:" in error_message
assert (
"Missing join key values for keys: ['customer_id', 'driver_id', 'item_id']."
in error_message
)
assert "Provided join_key_values: ['customer_id']" in error_message

result = store.get_online_features(
features=["customer_profile_pandas_odfv:on_demand_age"],
entity_rows=[{"driver_id": 1, "customer_id": "5"}],
Expand Down
88 changes: 84 additions & 4 deletions sdk/python/tests/unit/test_unit_feature_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import Dict, List

import pytest

from feast import utils
from feast.protos.feast.types.Value_pb2 import Value

Expand All @@ -17,7 +19,7 @@ class MockFeatureView:
projection: MockFeatureViewProjection


def test_get_unique_entities():
def test_get_unique_entities_success():
entity_values = {
"entity_1": [Value(int64_val=1), Value(int64_val=2), Value(int64_val=1)],
"entity_2": [
Expand All @@ -41,9 +43,87 @@ def test_get_unique_entities():
join_key_values=entity_values,
entity_name_to_join_key_map=entity_name_to_join_key_map,
)

assert unique_entities == (
expected_entities = (
{"entity_1": Value(int64_val=1), "entity_2": Value(string_val="1")},
{"entity_1": Value(int64_val=2), "entity_2": Value(string_val="2")},
)
assert indexes == ([0, 2], [1])
expected_indexes = ([0, 2], [1])

assert unique_entities == expected_entities
assert indexes == expected_indexes


def test_get_unique_entities_missing_join_key_success():
    """
    Tests that _get_unique_entities still succeeds when only some of the
    required join keys are missing.

    The rows are grouped by the join key that was provided, and the absent
    key is left out of every returned entity.
    """
    # Here, we omit the required key for "entity_1"
    entity_values = {
        "entity_2": [
            Value(string_val="1"),
            Value(string_val="2"),
            Value(string_val="1"),
        ],
    }

    entity_name_to_join_key_map = {"entity_1": "entity_1", "entity_2": "entity_2"}

    fv = MockFeatureView(
        name="fv_1",
        entities=["entity_1", "entity_2"],
        projection=MockFeatureViewProjection(join_key_map={}),
    )

    unique_entities, indexes = utils._get_unique_entities(
        table=fv,
        join_key_values=entity_values,
        entity_name_to_join_key_map=entity_name_to_join_key_map,
    )
    expected_entities = (
        {"entity_2": Value(string_val="1")},
        {"entity_2": Value(string_val="2")},
    )
    expected_indexes = ([0, 2], [1])

    assert unique_entities == expected_entities
    assert indexes == expected_indexes
    # The missing "entity_1" key must not appear in any returned entity.
    # Membership is tested per entity dict: comparing the string against a
    # list of dict_keys views never matches, which would make the assertion
    # vacuously true.
    assert all("entity_1" not in entity for entity in unique_entities)


def test_get_unique_entities_missing_all_join_keys_error():
    """
    Checks that _get_unique_entities raises a KeyError if none of the join
    keys belonging to the feature view's entities are supplied.
    """
    # Only "entity_3" is supplied; the feature view below lists just
    # "entity_1" and "entity_2" as its entities.
    unrelated_join_key_values = {
        "entity_3": [Value(string_val="3")],
    }
    join_key_map = {
        "entity_1": "entity_1",
        "entity_2": "entity_2",
        "entity_3": "entity_3",
    }

    feature_view = MockFeatureView(
        name="fv_1",
        entities=["entity_1", "entity_2"],
        projection=MockFeatureViewProjection(join_key_map={}),
    )

    with pytest.raises(KeyError) as excinfo:
        utils._get_unique_entities(
            table=feature_view,
            join_key_values=unrelated_join_key_values,
            entity_name_to_join_key_map=join_key_map,
        )

    error_message = str(excinfo.value)
    # Every fragment of the error text must be present in the raised message.
    expected_fragments = (
        "Missing join key values for keys: ['entity_1', 'entity_2', 'entity_3']",
        "No values provided for keys: ['entity_1', 'entity_2', 'entity_3']",
        "Provided join_key_values: ['entity_3']",
    )
    for fragment in expected_fragments:
        assert fragment in error_message

0 comments on commit 3bb0dca

Please sign in to comment.