Skip to content

Commit

Permalink
feat: Adding support to return additional features from vector retrie…
Browse files Browse the repository at this point in the history
…val for Milvus db (#4971)

* checking in progress but this Pr still is not ready yet

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* feat: Adding new method to FeatureStore to allow more flexible retrieval of features from vector similarity search

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* Adding requested_features back into online_store

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* linter

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* removed type adjustment

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
  • Loading branch information
franciscojavierarceo authored Jan 28, 2025
1 parent 6a1c102 commit 6ce08d3
Show file tree
Hide file tree
Showing 9 changed files with 573 additions and 63 deletions.
145 changes: 143 additions & 2 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from feast.feast_object import FeastObject
from feast.feature_service import FeatureService
from feast.feature_view import DUMMY_ENTITY, DUMMY_ENTITY_NAME, FeatureView
from feast.field import Field
from feast.inference import (
update_data_sources_with_inferred_event_timestamp_col,
update_feature_views_with_inferred_features_and_entities,
Expand Down Expand Up @@ -1833,7 +1834,6 @@ def retrieve_online_documents(
top_k,
distance_metric,
)

# TODO: the vector value is currently not returned because it is the same as the feature value. Once embedding is
# supported, the feature value could instead be the raw text from before it was embedded.
entity_key_vals = [feature[1] for feature in document_features]
Expand Down Expand Up @@ -1861,6 +1861,66 @@ def retrieve_online_documents(
)
return OnlineResponse(online_features_response)

def retrieve_online_documents_v2(
    self,
    query: Union[str, List[float]],
    top_k: int,
    features: List[str],
    distance_metric: Optional[str] = "L2",
) -> OnlineResponse:
    """
    Retrieves the top k closest document features. Note, embeddings are a subset of features.

    Args:
        features: The list of features that should be retrieved from the online document store, given as
            string feature references of the form "feature_view:feature", e.g, "document_fv:document_embeddings".
            All references must point at the same feature view.
        query: The (already embedded) query vector to retrieve the closest document features for.
        top_k: The number of closest document features to retrieve.
        distance_metric: The distance metric to use for retrieval.

    Returns:
        The OnlineResponse produced by the online store lookup for the requested features.

    Raises:
        ValueError: If ``query`` is a string (embedding the query is not supported here), if the features
            reference more than one feature view, or if no feature view could be resolved for them.
    """
    if isinstance(query, str):
        raise ValueError(
            "Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents_v2."
        )

    (
        available_feature_views,
        _,
    ) = utils._get_feature_views_to_use(
        registry=self._registry,
        project=self.project,
        features=features,
        allow_cache=True,
        hide_dummy_entity=False,
    )

    # Look each distinct feature view up once (in first-seen order) instead of once per feature.
    referenced_view_names = list(
        dict.fromkeys(feature.split(":")[0] for feature in features)
    )
    feature_view_set = set()
    for feature_view_name in referenced_view_names:
        feature_view = self.get_feature_view(feature_view_name)
        feature_view_set.add(feature_view.name)
    if len(feature_view_set) > 1:
        raise ValueError("Document retrieval only supports a single feature view.")

    # References without a ":" carry no feature name and are skipped.
    requested_features = [f.split(":")[1] for f in features if ":" in f]

    # Check for an empty result *before* indexing so callers get a ValueError rather than an IndexError.
    if not available_feature_views or not available_feature_views[0]:
        if referenced_view_names:
            raise ValueError(
                f"Feature view {referenced_view_names[0]} not found in the registry."
            )
        raise ValueError(
            "No feature view could be resolved because no features were requested."
        )
    requested_feature_view = available_feature_views[0]

    provider = self._get_provider()
    return self._retrieve_from_online_store_v2(
        provider,
        requested_feature_view,
        requested_features,
        query,
        top_k,
        distance_metric,
    )

def _retrieve_from_online_store(
self,
provider: Provider,
Expand All @@ -1878,6 +1938,10 @@ def _retrieve_from_online_store(
"""
Search and return document features from the online document store.
"""
vector_field_metadata = _get_feature_view_vector_field_metadata(table)
if vector_field_metadata:
distance_metric = vector_field_metadata.vector_search_metric

documents = provider.retrieve_online_documents(
config=self.config,
table=table,
Expand All @@ -1891,7 +1955,7 @@ def _retrieve_from_online_store(
read_row_protos = []
row_ts_proto = Timestamp()

for row_ts, entity_key, feature_val, vector_value, distance_val in documents:
for row_ts, entity_key, feature_val, vector_value, distance_val in documents: # type: ignore[misc]
# Reset timestamp to default or update if row_ts is not None
if row_ts is not None:
row_ts_proto.FromDatetime(row_ts)
Expand All @@ -1916,6 +1980,70 @@ def _retrieve_from_online_store(
)
return read_row_protos

def _retrieve_from_online_store_v2(
    self,
    provider: Provider,
    table: FeatureView,
    requested_features: List[str],
    query: List[float],
    top_k: int,
    distance_metric: Optional[str],
) -> OnlineResponse:
    """
    Search and return document features from the online document store.

    Args:
        provider: Provider whose ``retrieve_online_documents_v2`` performs the search.
        table: Feature view to search.
        requested_features: Feature names to return; "distance" is always appended to them.
        query: Query vector.
        top_k: Number of documents to retrieve.
        distance_metric: Metric passed to the provider. It is replaced by the vector field's
            ``vector_search_metric`` whenever the feature view has a vector-indexed field.

    Returns:
        An OnlineResponse populated from the retrieved rows, using short (non-prefixed) feature names.
    """
    vector_field_metadata = _get_feature_view_vector_field_metadata(table)
    if vector_field_metadata:
        # The metric declared on the vector field takes precedence over the caller's value.
        distance_metric = vector_field_metadata.vector_search_metric

    documents = provider.retrieve_online_documents_v2(
        config=self.config,
        table=table,
        requested_features=requested_features,
        query=query,
        top_k=top_k,
        distance_metric=distance_metric,
    )

    # Group entity values by join key while collecting the per-row timestamps and feature dicts.
    entity_key_dict: Dict[str, List[ValueProto]] = {}
    datevals, list_of_feature_dicts = [], []
    for row_ts, entity_key, feature_dict in documents:  # type: ignore[misc]
        datevals.append(row_ts)
        list_of_feature_dicts.append(feature_dict)
        # NOTE(review): rows with a falsy entity key contribute no entity values, so the per-key
        # lists can end up shorter than the row list -- confirm providers always return a key.
        if entity_key:
            for key, value in zip(entity_key.join_keys, entity_key.entity_values):
                entity_key_dict.setdefault(key, []).append(value)

    # Only the index mapping is needed; the unique entity values themselves are not used.
    _, idxs = utils._get_unique_entities_from_values(
        entity_key_dict,
    )

    # The distance is always returned alongside whatever features were asked for.
    features_to_request: List[str] = [*(requested_features or []), "distance"]
    feature_data = utils._convert_rows_to_protobuf(
        requested_features=features_to_request,
        read_rows=list(zip(datevals, list_of_feature_dicts)),
    )

    online_features_response = GetOnlineFeaturesResponse(results=[])
    utils._populate_response_from_feature_data(
        feature_data=feature_data,
        indexes=idxs,
        online_features_response=online_features_response,
        full_feature_names=False,
        requested_features=features_to_request,
        table=table,
    )

    return OnlineResponse(online_features_response)

def serve(
self,
host: str,
Expand Down Expand Up @@ -2265,3 +2393,16 @@ def _validate_data_sources(data_sources: List[DataSource]):
raise DataSourceRepeatNamesException(case_insensitive_ds_name)
else:
ds_names.add(case_insensitive_ds_name)


def _get_feature_view_vector_field_metadata(
feature_view: FeatureView,
) -> Optional[Field]:
vector_fields = [field for field in feature_view.schema if field.vector_index]
if len(vector_fields) > 1:
raise ValueError(
f"Feature view {feature_view.name} has multiple vector fields. Only one vector field per feature view is supported."
)
if not vector_fields:
return None
return vector_fields[0]
Loading

0 comments on commit 6ce08d3

Please sign in to comment.