From 6ce08d31863b12a7a92bf5207172a05f8da077d1 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Tue, 28 Jan 2025 16:39:11 -0500 Subject: [PATCH] feat: Adding support to return additional features from vector retrieval for Milvus db (#4971) * checking in progress but this Pr still is not ready yet Signed-off-by: Francisco Javier Arceo * feat: Adding new method to FeatureStore to allow more flexible retrieval of features from vector similarity search Signed-off-by: Francisco Javier Arceo * Adding requested_features back into online_store Signed-off-by: Francisco Javier Arceo * linter Signed-off-by: Francisco Javier Arceo * removed type adjustment Signed-off-by: Francisco Javier Arceo --------- Signed-off-by: Francisco Javier Arceo --- sdk/python/feast/feature_store.py | 145 +++++++++++++++++- .../milvus_online_store/milvus.py | 142 +++++++++++------ .../feast/infra/online_stores/online_store.py | 35 +++++ .../feast/infra/passthrough_provider.py | 21 +++ sdk/python/feast/infra/provider.py | 35 ++++- sdk/python/feast/utils.py | 93 +++++++++++ .../example_repos/example_rag_feature_repo.py | 15 +- sdk/python/tests/foo_provider.py | 17 ++ .../online_store/test_online_retrieval.py | 133 ++++++++++++++-- 9 files changed, 573 insertions(+), 63 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index d0e6f1124c0..0f092538cf0 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -62,6 +62,7 @@ from feast.feast_object import FeastObject from feast.feature_service import FeatureService from feast.feature_view import DUMMY_ENTITY, DUMMY_ENTITY_NAME, FeatureView +from feast.field import Field from feast.inference import ( update_data_sources_with_inferred_event_timestamp_col, update_feature_views_with_inferred_features_and_entities, @@ -1833,7 +1834,6 @@ def retrieve_online_documents( top_k, distance_metric, ) - # TODO currently not return the vector value since it is same as feature value, if embedding is supported, # the feature value can be raw text before embedded entity_key_vals = [feature[1] for feature in document_features] @@ -1861,6 +1861,66 @@ def retrieve_online_documents( ) return OnlineResponse(online_features_response) + def retrieve_online_documents_v2( + self, + query: Union[str, List[float]], + top_k: int, + features: List[str], + distance_metric: Optional[str] = "L2", + ) -> OnlineResponse: + """ + Retrieves the top k closest document features. Note, embeddings are a subset of features. + + Args: + features: The list of features that should be retrieved from the online document store. These features can be + specified either as a list of string document feature references or as a feature service. String feature + references must have format "feature_view:feature", e.g, "document_fv:document_embeddings". + query: The query to retrieve the closest document features for. + top_k: The number of closest document features to retrieve. + distance_metric: The distance metric to use for retrieval. + """ + if isinstance(query, str): + raise ValueError( + "Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents." + ) + + ( + available_feature_views, + _, + ) = utils._get_feature_views_to_use( + registry=self._registry, + project=self.project, + features=features, + allow_cache=True, + hide_dummy_entity=False, + ) + feature_view_set = set() + for feature in features: + feature_view_name = feature.split(":")[0] + feature_view = self.get_feature_view(feature_view_name) + feature_view_set.add(feature_view.name) + if len(feature_view_set) > 1: + raise ValueError("Document retrieval only supports a single feature view.") + requested_features = [ + f.split(":")[1] for f in features if isinstance(f, str) and ":" in f + ] + + requested_feature_view = available_feature_views[0] + if not requested_feature_view: + raise ValueError( + f"Feature view {requested_feature_view} not found in the registry." + ) + + provider = self._get_provider() + return self._retrieve_from_online_store_v2( + provider, + requested_feature_view, + requested_features, + query, + top_k, + distance_metric, + ) + def _retrieve_from_online_store( self, provider: Provider, @@ -1878,6 +1938,10 @@ def _retrieve_from_online_store( """ Search and return document features from the online document store. """ + vector_field_metadata = _get_feature_view_vector_field_metadata(table) + if vector_field_metadata: + distance_metric = vector_field_metadata.vector_search_metric + documents = provider.retrieve_online_documents( config=self.config, table=table, @@ -1891,7 +1955,7 @@ def _retrieve_from_online_store( read_row_protos = [] row_ts_proto = Timestamp() - for row_ts, entity_key, feature_val, vector_value, distance_val in documents: + for row_ts, entity_key, feature_val, vector_value, distance_val in documents: # type: ignore[misc] # Reset timestamp to default or update if row_ts is not None if row_ts is not None: row_ts_proto.FromDatetime(row_ts) @@ -1916,6 +1980,70 @@ def _retrieve_from_online_store( ) return read_row_protos + def _retrieve_from_online_store_v2( + self, + provider: Provider, + table: FeatureView, + requested_features: List[str], + query: List[float], + top_k: int, + distance_metric: Optional[str], + ) -> OnlineResponse: + """ + Search and return document features from the online document store. + """ + vector_field_metadata = _get_feature_view_vector_field_metadata(table) + if vector_field_metadata: + distance_metric = vector_field_metadata.vector_search_metric + + documents = provider.retrieve_online_documents_v2( + config=self.config, + table=table, + requested_features=requested_features, + query=query, + top_k=top_k, + distance_metric=distance_metric, + ) + + entity_key_dict: Dict[str, List[ValueProto]] = {} + datevals, entityvals, list_of_feature_dicts = [], [], [] + for row_ts, entity_key, feature_dict in documents: # type: ignore[misc] + datevals.append(row_ts) + entityvals.append(entity_key) + list_of_feature_dicts.append(feature_dict) + if entity_key: + for key, value in zip(entity_key.join_keys, entity_key.entity_values): + python_value = value + if key not in entity_key_dict: + entity_key_dict[key] = [] + entity_key_dict[key].append(python_value) + + table_entity_values, idxs = utils._get_unique_entities_from_values( + entity_key_dict, + ) + + features_to_request: List[str] = [] + if requested_features: + features_to_request = requested_features + ["distance"] + else: + features_to_request = ["distance"] + feature_data = utils._convert_rows_to_protobuf( + requested_features=features_to_request, + read_rows=list(zip(datevals, list_of_feature_dicts)), + ) + + online_features_response = GetOnlineFeaturesResponse(results=[]) + utils._populate_response_from_feature_data( + feature_data=feature_data, + indexes=idxs, + online_features_response=online_features_response, + full_feature_names=False, + requested_features=features_to_request, + table=table, + ) + + return OnlineResponse(online_features_response) + def serve( self, host: str, @@ -2265,3 +2393,16 @@ def _validate_data_sources(data_sources: List[DataSource]): raise DataSourceRepeatNamesException(case_insensitive_ds_name) else: ds_names.add(case_insensitive_ds_name) + + +def _get_feature_view_vector_field_metadata( + feature_view: FeatureView, +) -> Optional[Field]: + vector_fields = [field for field in feature_view.schema if field.vector_index] + if len(vector_fields) > 1: + raise ValueError( + f"Feature view {feature_view.name} has multiple vector fields. Only one vector field per feature view is supported." + ) + if not vector_fields: + return None + return vector_fields[0] diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py index ff3cf62b3ad..e39db6d3a34 100644 --- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -15,6 +15,7 @@ from feast.feature_view import FeatureView from feast.infra.infra_object import InfraObject from feast.infra.key_encoding_utils import ( + deserialize_entity_key, serialize_entity_key, ) from feast.infra.online_stores.online_store import OnlineStore @@ -24,7 +25,10 @@ from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel, RepoConfig -from feast.type_map import PROTO_VALUE_TO_VALUE_TYPE_MAP +from feast.type_map import ( + PROTO_VALUE_TO_VALUE_TYPE_MAP, + feast_value_type_to_python_type, +) from feast.types import ( VALUE_TYPES_TO_FEAST_TYPES, Array, @@ -33,7 +37,6 @@ ValueType, ) from feast.utils import ( - _build_retrieve_online_document_record, _serialize_vector_to_float_list, to_naive_utc, ) @@ -89,7 +92,7 @@ class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): host: Optional[StrictStr] = "localhost" port: Optional[int] = 19530 index_type: Optional[str] = "FLAT" - metric_type: Optional[str] = "L2" + metric_type: Optional[str] = "COSINE" embedding_dim: Optional[int] = 128 vector_enabled: Optional[bool] = True nlist: Optional[int] = 128 @@ -170,16 +173,14 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Dict[str, A dim=config.online_store.embedding_dim, ) ) - elif dtype == DataType.VARCHAR: + else: fields.append( FieldSchema( name=field.name, - dtype=dtype, + dtype=DataType.VARCHAR, max_length=512, ) ) - else: - fields.append(FieldSchema(name=field.name, dtype=dtype)) schema = CollectionSchema( fields=fields, description="Feast feature view data" @@ -234,6 +235,7 @@ def online_write_batch( ) -> None: self.client = self._connect(config) collection = self._get_collection(config, table) + vector_cols = [f.name for f in table.features if f.vector_index] entity_batch_to_insert = [] for entity_key, values_dict, timestamp, created_ts in data: # need to construct the composite primary key also need to handle the fact that entities are a list @@ -241,6 +243,8 @@ def online_write_batch( entity_key, entity_key_serialization_version=config.entity_key_serialization_version, ).hex() + # to recover the entity key just run: + # deserialize_entity_key(bytes.fromhex(entity_key_str), entity_key_serialization_version=3) composite_key_name = ( "_".join([str(value) for value in entity_key.join_keys]) + "_pk" ) @@ -248,11 +252,18 @@ def online_write_batch( created_ts_int = ( int(to_naive_utc(created_ts).timestamp() * 1e6) if created_ts else 0 ) - values_dict = _extract_proto_values_to_dict(values_dict) - entity_dict = _extract_proto_values_to_dict( - dict(zip(entity_key.join_keys, entity_key.entity_values)) - ) + entity_dict = { + join_key: feast_value_type_to_python_type(value) + for join_key, value in zip( + entity_key.join_keys, entity_key.entity_values + ) + } values_dict.update(entity_dict) + values_dict = _extract_proto_values_to_dict( + values_dict, + vector_cols=vector_cols, + serialize_to_string=True, + ) single_entity_record = { composite_key_name: entity_key_str, @@ -316,12 +327,11 @@ def teardown( self.client.drop_collection(collection_name) self._collections.pop(collection_name, None) - def retrieve_online_documents( + def retrieve_online_documents_v2( self, config: RepoConfig, table: FeatureView, - requested_feature: Optional[str], - requested_features: Optional[List[str]], + requested_features: List[str], embedding: List[float], top_k: int, distance_metric: Optional[str] = None, @@ -329,11 +339,12 @@ def retrieve_online_documents( Tuple[ Optional[datetime], Optional[EntityKeyProto], - Optional[ValueProto], - Optional[ValueProto], - Optional[ValueProto], + Optional[Dict[str, ValueProto]], ] ]: + entity_name_feast_primitive_type_map = { + k.name: k.dtype for k in table.entity_columns + } self.client = self._connect(config) collection_name = _table_id(config.project, table) collection = self._get_collection(config, table) @@ -344,14 +355,12 @@ def retrieve_online_documents( "metric_type": distance_metric or config.online_store.metric_type, "params": {"nprobe": 10}, } - expr = f"feature_name == '{requested_feature}'" composite_key_name = ( "_".join([str(field.name) for field in table.entity_columns]) + "_pk" ) - if requested_features: - features_str = ", ".join([f"'{f}'" for f in requested_features]) - expr += f" && feature_name in [{features_str}]" + # features_str = ", ".join([f"'{f}'" for f in requested_features]) + # expr = f" && feature_name in [{features_str}]" output_fields = ( [composite_key_name] @@ -387,29 +396,51 @@ def retrieve_online_documents( result_list = [] for hits in results: for hit in hits: - single_record = {} - for field in output_fields: - single_record[field] = hit.get("entity", {}).get(field, None) - + res = {} + res_ts = None entity_key_bytes = bytes.fromhex( hit.get("entity", {}).get(composite_key_name, None) ) - embedding = hit.get("entity", {}).get(ann_search_field) - serialized_embedding = _serialize_vector_to_float_list(embedding) - distance = hit.get("distance", None) - event_ts = datetime.fromtimestamp( - hit.get("entity", {}).get("event_ts") / 1e6 + entity_key_proto = ( + deserialize_entity_key(entity_key_bytes) + if entity_key_bytes + else None ) - prepared_result = _build_retrieve_online_document_record( - entity_key_bytes, - # This may have a bug - serialized_embedding.SerializeToString(), - embedding, - distance, - event_ts, - config.entity_key_serialization_version, + for field in output_fields: + val = ValueProto() + # entity_key_proto = None + if field in ["created_ts", "event_ts"]: + res_ts = datetime.fromtimestamp( + hit.get("entity", {}).get(field) / 1e6 + ) + elif field == ann_search_field: + serialized_embedding = _serialize_vector_to_float_list( + embedding + ) + res[ann_search_field] = serialized_embedding + elif entity_name_feast_primitive_type_map.get( + field, PrimitiveFeastType.INVALID + ) in [ + PrimitiveFeastType.STRING, + PrimitiveFeastType.INT64, + PrimitiveFeastType.INT32, + PrimitiveFeastType.BYTES, + ]: + res[field] = ValueProto( + string_val=hit.get("entity", {}).get(field, "") + ) + elif field == composite_key_name: + pass + else: + val.ParseFromString( + bytes(hit.get("entity", {}).get(field, b"").encode()) + ) + res[field] = val + distance = hit.get("distance", None) + res["distance"] = ( + ValueProto(float_val=distance) if distance else ValueProto() ) - result_list.append(prepared_result) + result_list.append((res_ts, entity_key_proto, res if res else None)) return result_list @@ -417,7 +448,11 @@ def _table_id(project: str, table: FeatureView) -> str: return f"{project}_{table.name}" -def _extract_proto_values_to_dict(input_dict: Dict[str, Any]) -> Dict[str, Any]: +def _extract_proto_values_to_dict( + input_dict: Dict[str, Any], + vector_cols: List[str], + serialize_to_string=False, +) -> Dict[str, Any]: numeric_vector_list_types = [ k for k in PROTO_VALUE_TO_VALUE_TYPE_MAP.keys() @@ -426,12 +461,27 @@ def _extract_proto_values_to_dict(input_dict: Dict[str, Any]) -> Dict[str, Any]: output_dict = {} for feature_name, feature_values in input_dict.items(): for proto_val_type in PROTO_VALUE_TO_VALUE_TYPE_MAP: - if feature_values.HasField(proto_val_type): - if proto_val_type in numeric_vector_list_types: - vector_values = getattr(feature_values, proto_val_type).val - else: - vector_values = getattr(feature_values, proto_val_type) - output_dict[feature_name] = vector_values + if not isinstance(feature_values, (int, float, str)): + if feature_values.HasField(proto_val_type): + if proto_val_type in numeric_vector_list_types: + if serialize_to_string and feature_name not in vector_cols: + vector_values = getattr( + feature_values, proto_val_type + ).SerializeToString() + else: + vector_values = getattr(feature_values, proto_val_type).val + else: + if serialize_to_string: + vector_values = feature_values.SerializeToString().decode() + else: + vector_values = getattr(feature_values, proto_val_type) + output_dict[feature_name] = vector_values + else: + if serialize_to_string: + if not isinstance(feature_values, str): + feature_values = str(feature_values) + output_dict[feature_name] = feature_values + return output_dict diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index be3128562dc..a86fdba4017 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -429,6 +429,41 @@ def retrieve_online_documents( f"Online store {self.__class__.__name__} does not support online retrieval" ) + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embedding: List[float], + top_k: int, + distance_metric: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Retrieves online feature values for the specified embeddings. + + Args: + distance_metric: distance metric to use for retrieval. + config: The config for the current feature store. + table: The feature view whose feature values should be read. + requested_features: The list of features whose embeddings should be used for retrieval. + embedding: The embeddings to use for retrieval. + top_k: The number of documents to retrieve. + + Returns: + object: A list of top k closest documents to the specified embedding. Each item in the list is a tuple + where the first item is the event timestamp for the row, and the second item is a dict of feature + name to embeddings. + """ + raise NotImplementedError( + f"Online store {self.__class__.__name__} does not support online retrieval" + ) + async def initialize(self, config: RepoConfig) -> None: pass diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index 6fe4b6e3a0f..74b05113282 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -313,6 +313,27 @@ def retrieve_online_documents( ) return result + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: Optional[List[str]], + query: List[float], + top_k: int, + distance_metric: Optional[str] = None, + ) -> List: + result = [] + if self.online_store: + result = self.online_store.retrieve_online_documents_v2( + config, + table, + requested_features, + query, + top_k, + distance_metric, + ) + return result + @staticmethod def _prep_rows_to_write_for_ingestion( feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index efc806ba2f0..f765e754436 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -431,7 +431,7 @@ def retrieve_online_documents( Optional[ValueProto], Optional[ValueProto], Optional[ValueProto], - ] + ], ]: """ Searches for the top-k most similar documents in the online document store. @@ -450,6 +450,39 @@ def retrieve_online_documents( """ pass + @abstractmethod + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + query: List[float], + top_k: int, + distance_metric: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Searches for the top-k most similar documents in the online document store. + + Args: + distance_metric: distance metric to use for the search. + config: The config for the current feature store. + table: The feature view whose embeddings should be searched. + requested_features: the requested document feature names. + query: The query embedding to search for. + top_k: The number of documents to return. + + Returns: + A list of dictionaries, where each dictionary contains the datetime, entitykey, and a dictionary + of feature key value pairs + """ + pass + @abstractmethod def validate_data_source( self, diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index cfc19e37ca4..ff60217f32d 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -709,6 +709,35 @@ def _get_unique_entities( return unique_entities, indexes +def _get_unique_entities_from_values( + table_entity_values: Dict[str, List[ValueProto]], +) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...]]: + """Return the set of unique composite Entities for a Feature View and the indexes at which they appear. + + This method allows us to query the OnlineStore for data we need only once + rather than requesting and processing data for the same combination of + Entities multiple times. + """ + keys = table_entity_values.keys() + # Sort the rowise data to allow for grouping but keep original index. This lambda is + # sufficient as Entity types cannot be complex (ie. lists). + rowise = list(enumerate(zip(*table_entity_values.values()))) + rowise.sort(key=lambda row: tuple(getattr(x, x.WhichOneof("val")) for x in row[1])) + + # Identify unique entities and the indexes at which they occur. + unique_entities: Tuple[Dict[str, ValueProto], ...] + indexes: Tuple[List[int], ...] + unique_entities, indexes = tuple( + zip( + *[ + (dict(zip(keys, k)), [_[0] for _ in g]) + for k, g in itertools.groupby(rowise, key=lambda x: x[1]) + ] + ) + ) + return unique_entities, indexes + + def _drop_unneeded_columns( online_features_response: GetOnlineFeaturesResponse, requested_result_row_names: Set[str], @@ -830,6 +859,70 @@ def _populate_response_from_feature_data( ) +def _populate_response_from_feature_data_v2( + feature_data: Iterable[ + Tuple[ + Iterable[Timestamp], Iterable["FieldStatus.ValueType"], Iterable[ValueProto] + ] + ], + indexes: Iterable[List[int]], + online_features_response: GetOnlineFeaturesResponse, + requested_features: Iterable[str], +): + """Populate the GetOnlineFeaturesResponse with feature data. + + This method assumes that `_read_from_online_store` returns data for each + combination of Entities in `entity_rows` in the same order as they + are provided. + + Args: + feature_data: A list of data in Protobuf form which was retrieved from the OnlineStore. + indexes: A list of indexes which should be the same length as `feature_data`. Each list + of indexes corresponds to a set of result rows in `online_features_response`. + online_features_response: The object to populate. + full_feature_names: A boolean that provides the option to add the feature view prefixes to the feature names, + changing them from the format "feature" to "feature_view__feature" (e.g., "daily_transactions" changes to + "customer_fv__daily_transactions"). + requested_features: The names of the features in `feature_data`. This should be ordered in the same way as the + data in `feature_data`. + """ + # Add the feature names to the response. + requested_feature_refs = [(feature_name) for feature_name in requested_features] + online_features_response.metadata.feature_names.val.extend(requested_feature_refs) + + timestamps, statuses, values = zip(*feature_data) + + # Populate the result with data fetched from the OnlineStore + # which is guaranteed to be aligned with `requested_features`. + for ( + feature_idx, + (timestamp_vector, statuses_vector, values_vector), + ) in enumerate(zip(zip(*timestamps), zip(*statuses), zip(*values))): + online_features_response.results.append( + GetOnlineFeaturesResponse.FeatureVector( + values=apply_list_mapping(values_vector, indexes), + statuses=apply_list_mapping(statuses_vector, indexes), + event_timestamps=apply_list_mapping(timestamp_vector, indexes), + ) + ) + + +def _convert_entity_key_to_proto_to_dict( + entity_key_vals: List[EntityKeyProto], +) -> Dict[str, List[ValueProto]]: + entity_dict: Dict[str, List[ValueProto]] = {} + for entity_key_val in entity_key_vals: + if entity_key_val is not None: + for join_key, entity_value in zip( + entity_key_val.join_keys, entity_key_val.entity_values + ): + if join_key not in entity_dict: + entity_dict[join_key] = [] + # python_entity_value = _proto_value_to_value_type(entity_value) + entity_dict[join_key].append(entity_value) + return entity_dict + + def _get_features( registry, project, diff --git a/sdk/python/tests/example_repos/example_rag_feature_repo.py b/sdk/python/tests/example_repos/example_rag_feature_repo.py index 2f55095bc69..d87a2a34df1 100644 --- a/sdk/python/tests/example_repos/example_rag_feature_repo.py +++ b/sdk/python/tests/example_repos/example_rag_feature_repo.py @@ -1,7 +1,7 @@ from datetime import timedelta from feast import Entity, FeatureView, Field, FileSource -from feast.types import Array, Float32, Int64, UnixTimestamp +from feast.types import Array, Float32, Int64, String, UnixTimestamp, ValueType # This is for Milvus # Note that file source paths are not validated, so there doesn't actually need to be any data @@ -17,20 +17,29 @@ item = Entity( name="item_id", # The name is derived from this argument, not object name. join_keys=["item_id"], + value_type=ValueType.INT64, +) + +author = Entity( + name="author_id", + join_keys=["author_id"], + value_type=ValueType.STRING, ) document_embeddings = FeatureView( name="embedded_documents", - entities=[item], + entities=[item, author], schema=[ Field( name="vector", dtype=Array(Float32), vector_index=True, - vector_search_metric="L2", + vector_search_metric="COSINE", ), Field(name="item_id", dtype=Int64), + Field(name="author_id", dtype=String), Field(name="created_timestamp", dtype=UnixTimestamp), + Field(name="sentence_chunks", dtype=String), Field(name="event_timestamp", dtype=UnixTimestamp), ], source=rag_documents_source, diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 3d1f9219991..ca6a02c4bd0 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -164,6 +164,23 @@ def retrieve_online_documents( ]: return [] + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + query: List[float], + top_k: int, + distance_metric: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + return [] + def validate_data_source( self, config: RepoConfig, diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index 5f0796f4eed..20ff2989ebc 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -17,6 +17,7 @@ from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import RegistryConfig +from feast.types import ValueType from feast.utils import _utc_now from tests.integration.feature_repos.universal.feature_views import TAGS from tests.utils.cli_repo_creator import CliRunner, get_example_repo @@ -638,7 +639,7 @@ def test_milvus_lite_get_online_documents() -> None: from datetime import timedelta from feast import Entity, FeatureView, Field, FileSource - from feast.types import Array, Float32, Int64, UnixTimestamp + from feast.types import Array, Float32, Int64, String, UnixTimestamp # This is for Milvus # Note that file source paths are not validated, so there doesn't actually need to be any data @@ -654,20 +655,28 @@ def test_milvus_lite_get_online_documents() -> None: item = Entity( name="item_id", # The name is derived from this argument, not object name. join_keys=["item_id"], + value_type=ValueType.INT64, + ) + author = Entity( + name="author_id", + join_keys=["author_id"], + value_type=ValueType.STRING, ) document_embeddings = FeatureView( name="embedded_documents", - entities=[item], + entities=[item, author], schema=[ Field( name="vector", dtype=Array(Float32), vector_index=True, - vector_search_metric="L2", + vector_search_metric="COSINE", ), Field(name="item_id", dtype=Int64), + Field(name="author_id", dtype=String), Field(name="created_timestamp", dtype=UnixTimestamp), + Field(name="sentence_chunks", dtype=String), Field(name="event_timestamp", dtype=UnixTimestamp), ], source=rag_documents_source, @@ -683,12 +692,16 @@ def test_milvus_lite_get_online_documents() -> None: item_keys = [ EntityKeyProto( - join_keys=["item_id"], entity_values=[ValueProto(int64_val=i)] + join_keys=["item_id", "author_id"], + entity_values=[ + ValueProto(int64_val=i), + ValueProto(string_val=f"author_{i}"), + ], ) for i in range(n) ] data = [] - for item_key in item_keys: + for i, item_key in enumerate(item_keys): data.append( ( item_key, @@ -698,8 +711,10 @@ def test_milvus_lite_get_online_documents() -> None: val=np.random.random( vector_length, ) + + i ) - ) + ), + "sentence_chunks": ValueProto(string_val=f"sentence chunk {i}"), }, _utc_now(), _utc_now(), @@ -715,12 +730,15 @@ def test_milvus_lite_get_online_documents() -> None: documents_df = pd.DataFrame( { "item_id": [str(i) for i in range(n)], + "author_id": [f"author_{i}" for i in range(n)], "vector": [ np.random.random( vector_length, ) + + i for i in range(n) ], + "sentence_chunks": [f"sentence chunk {i}" for i in range(n)], "event_timestamp": [_utc_now() for _ in range(n)], "created_timestamp": [_utc_now() for _ in range(n)], } @@ -734,10 +752,103 @@ def test_milvus_lite_get_online_documents() -> None: query_embedding = np.random.random( vector_length, ) - result = store.retrieve_online_documents( - feature="embedded_documents:vector", query=query_embedding, top_k=3 + + client = store._provider._online_store.client + collection_name = client.list_collections()[0] + search_params = { + "metric_type": "COSINE", + "params": {"nprobe": 10}, + } + + results = client.search( + collection_name=collection_name, + data=[query_embedding], + anns_field="vector", + search_params=search_params, + limit=3, + output_fields=[ + "item_id", + "author_id", + "sentence_chunks", + "created_ts", + "event_ts", + ], + ) + result = store.retrieve_online_documents_v2( + features=[ + "embedded_documents:vector", + "embedded_documents:item_id", + "embedded_documents:author_id", + "embedded_documents:sentence_chunks", + ], + query=query_embedding, + top_k=3, ).to_dict() - assert "vector" in result - assert "distance" in result - assert len(result["distance"]) == 3 + for k in ["vector", "item_id", "author_id", "sentence_chunks", "distance"]: + assert k in result, f"Missing {k} in retrieve_online_documents response" + assert len(result["distance"]) == len(results[0]) + + +def test_milvus_native_from_feast_data() -> None: + import random + from datetime import datetime + + import numpy as np + from pymilvus import MilvusClient + + random.seed(42) + VECTOR_LENGTH = 10 # Matches vector_length from the Feast example + COLLECTION_NAME = "embedded_documents" + + # Initialize Milvus client with local setup + client = MilvusClient("./milvus_demo.db") + + # Clear and recreate collection + for collection in client.list_collections(): + client.drop_collection(collection_name=collection) + client.create_collection( + collection_name=COLLECTION_NAME, + dimension=VECTOR_LENGTH, + metric_type="COSINE", # Matches Feast's vector_search_metric + ) + assert client.list_collections() == [COLLECTION_NAME] + + # Prepare data for insertion, similar to the Feast example + n = 10 # Number of items + data = [] + for i in range(n): + vector = (np.random.random(VECTOR_LENGTH) + i).tolist() + data.append( + { + "id": i, + "vector": vector, + "item_id": i, + "author_id": f"author_{i}", + "sentence_chunks": f"sentence chunk {i}", + "event_timestamp": datetime.utcnow().isoformat(), + "created_timestamp": datetime.utcnow().isoformat(), + } + ) + + print("Data has", len(data), "entities, each with fields:", data[0].keys()) + + # Insert data into Milvus + insert_res = client.insert(collection_name=COLLECTION_NAME, data=data) + assert insert_res == {"insert_count": n, "ids": list(range(n)), "cost": 0} + + # Perform a vector search using a random query embedding + query_embedding = (np.random.random(VECTOR_LENGTH)).tolist() + search_res = client.search( + collection_name=COLLECTION_NAME, + data=[query_embedding], + limit=3, # Top 3 results + output_fields=["item_id", "author_id", "sentence_chunks"], + ) + + # Validate the search results + assert len(search_res[0]) == 3 + print("Search Results:", search_res[0]) + + # Clean up the collection + client.drop_collection(collection_name=COLLECTION_NAME)