diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 814b7da14..06beb2633 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -91,8 +91,8 @@ def __init__( azure_endpoint: Secret = Secret.from_env_var("AZURE_AI_SEARCH_ENDPOINT", strict=True), # noqa: B008 index_name: str = "default", embedding_dimension: int = 768, - metadata_fields: Optional[Dict[str, type]] = None, - vector_search_configuration: VectorSearch = None, + metadata_fields: Optional[Dict[str, SearchField | type]] = None, + vector_search_configuration: Optional[VectorSearch] = None, **index_creation_kwargs, ): """ @@ -103,10 +103,10 @@ def __init__( :param api_key: The API key to use for authentication. :param index_name: Name of index in Azure AI Search, if it doesn't exist it will be created. :param embedding_dimension: Dimension of the embeddings. - :param metadata_fields: A dictionary of metadata keys and their types to create - additional fields in index schema. As fields in Azure SearchIndex cannot be dynamic, - it is necessary to specify the metadata fields in advance. - (e.g. metadata_fields = {"author": str, "date": datetime}) + :param metadata_fields: A dictionary of metadata keys and search index fields to map them + to. These fields will automatically be added during index creation. For convenience, + a mapping may also be a Python type (`str`, `bool`, `int`, `float`, or `datetime`) + instead of a `SearchField`, in which case a simple filterable field is created. :param vector_search_configuration: Configuration option related to vector search. Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches. @@ -136,13 +136,12 @@ def __init__( self._index_name = index_name self._embedding_dimension = embedding_dimension self._dummy_vector = [-10.0] * self._embedding_dimension - self._metadata_fields = metadata_fields + self._metadata_fields = self._normalize_metadata_index_fields(metadata_fields) self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH self._index_creation_kwargs = index_creation_kwargs @property def client(self) -> SearchClient: - # resolve secrets for authentication resolved_endpoint = ( self._azure_endpoint.resolve_value() if isinstance(self._azure_endpoint, Secret) else self._azure_endpoint @@ -178,6 +177,45 @@ def client(self) -> SearchClient: return self._client + def _normalize_metadata_index_fields( + self, metadata_fields: Optional[Dict[str, SearchField | type]] + ) -> Dict[str, SearchField]: + """Create a list of index fields for storing metadata values.""" + + if not metadata_fields: + return {} + + normalized_fields = {} + + for key, value in metadata_fields.items(): + if isinstance(value, SearchField): + if value.name == key: + normalized_fields[key] = value + else: + msg = f"Name of SearchField ('{value.name}') must match metadata field name ('{key}')" + raise ValueError(msg) + else: + if not key[0].isalpha(): + msg = ( + f"Azure Search index only allows field names starting with letters. " + f"Invalid key: {key} will be dropped." + ) + logger.warning(msg) + continue + + field_type = type_mapping.get(value) + if not field_type: + error_message = f"Unsupported field type for key '{key}': {value}" + raise ValueError(error_message) + + normalized_fields[key] = SimpleField( + name=key, + type=field_type, + filterable=True, + ) + + return normalized_fields + def _create_index(self) -> None: """ Internally creates a new search index. @@ -198,13 +236,15 @@ def _create_index(self) -> None: ] if self._metadata_fields: - default_fields.extend(self._create_metadata_index_fields(self._metadata_fields)) + default_fields.extend(self._metadata_fields.values()) + index = SearchIndex( name=self._index_name, fields=default_fields, vector_search=self._vector_search_configuration, **self._index_creation_kwargs, ) + if self._index_client: self._index_client.create_index(index) @@ -258,28 +298,19 @@ def _deserialize_index_creation_kwargs(cls, data: Dict[str, Any]) -> Any: return result[key] def to_dict(self) -> Dict[str, Any]: - # This is not the best solution to serialise this class but is the fastest to implement. - # Not all kwargs types can be serialised to text so this can fail. We must serialise each - # type explicitly to handle this properly. """ Serializes the component to a dictionary. :returns: Dictionary with serialized data. """ - - if self._metadata_fields: - serialized_metadata = {key: value.__name__ for key, value in self._metadata_fields.items()} - else: - serialized_metadata = None - return default_to_dict( self, azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint else None, api_key=self._api_key.to_dict() if self._api_key else None, index_name=self._index_name, embedding_dimension=self._embedding_dimension, - metadata_fields=serialized_metadata, + metadata_fields={key: value.as_dict() for key, value in self._metadata_fields.items()}, vector_search_configuration=self._vector_search_configuration.as_dict(), **self._serialize_index_creation_kwargs(self._index_creation_kwargs), ) @@ -296,7 +327,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore": Deserialized component. """ if (fields := data["init_parameters"]["metadata_fields"]) is not None: - data["init_parameters"]["metadata_fields"] = cls._deserialize_metadata_fields(fields) + data["init_parameters"]["metadata_fields"] = { + key: SearchField.from_dict(field) for key, field in fields.items() + } + else: + data["init_parameters"]["metadata_fields"] = {} for key, _value in AZURE_CLASS_MAPPING.items(): if key in data["init_parameters"]: @@ -454,46 +489,12 @@ def _convert_haystack_document_to_azure(self, document: Document) -> Dict[str, A return index_document - def _create_metadata_index_fields(self, metadata: Dict[str, Any]) -> List[SimpleField]: - """Create a list of index fields for storing metadata values.""" - - index_fields = [] - metadata_field_mapping = self._map_metadata_field_types(metadata) - - for key, field_type in metadata_field_mapping.items(): - index_fields.append(SimpleField(name=key, type=field_type, filterable=True)) - - return index_fields - - def _map_metadata_field_types(self, metadata: Dict[str, type]) -> Dict[str, str]: - """Map metadata field types to Azure Search field types.""" - - metadata_field_mapping = {} - - for key, value_type in metadata.items(): - - if not key[0].isalpha(): - msg = ( - f"Azure Search index only allows field names starting with letters. " - f"Invalid key: {key} will be dropped." - ) - logger.warning(msg) - continue - - field_type = type_mapping.get(value_type) - if not field_type: - error_message = f"Unsupported field type for key '{key}': {value_type}" - raise ValueError(error_message) - metadata_field_mapping[key] = field_type - - return metadata_field_mapping - def _embedding_retrieval( self, query_embedding: List[float], *, top_k: int = 10, - filters: Optional[Dict[str, Any]] = None, + filters: Optional[str] = None, **kwargs, ) -> List[Document]: """ @@ -527,7 +528,7 @@ def _bm25_retrieval( self, query: str, top_k: int = 10, - filters: Optional[Dict[str, Any]] = None, + filters: Optional[str] = None, **kwargs, ) -> List[Document]: """ @@ -560,7 +561,7 @@ def _hybrid_retrieval( query: str, query_embedding: List[float], top_k: int = 10, - filters: Optional[Dict[str, Any]] = None, + filters: Optional[str] = None, **kwargs, ) -> List[Document]: """ diff --git a/integrations/azure_ai_search/tests/test_document_store.py b/integrations/azure_ai_search/tests/test_document_store.py index dbd6e5628..384c8543d 100644 --- a/integrations/azure_ai_search/tests/test_document_store.py +++ b/integrations/azure_ai_search/tests/test_document_store.py @@ -8,7 +8,7 @@ from unittest.mock import patch import pytest -from azure.search.documents.indexes.models import CustomAnalyzer, SearchResourceEncryptionKey +from azure.search.documents.indexes.models import CustomAnalyzer, SearchField, SearchResourceEncryptionKey, SimpleField from haystack.dataclasses.document import Document from haystack.errors import FilterError from haystack.testing.document_store import ( @@ -37,7 +37,7 @@ def test_to_dict(monkeypatch): "api_key": {"env_vars": ["AZURE_AI_SEARCH_API_KEY"], "strict": False, "type": "env_var"}, "index_name": "default", "embedding_dimension": 768, - "metadata_fields": None, + "metadata_fields": {}, "vector_search_configuration": { "profiles": [ {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} @@ -70,7 +70,10 @@ def test_to_dict_with_params(monkeypatch): document_store = AzureAISearchDocumentStore( index_name="my_index", embedding_dimension=15, - metadata_fields={"Title": str, "Pages": int}, + metadata_fields={ + "Title": SearchField(name="Title", type="Edm.String", searchable=True, filterable=True), + "Pages": int, + }, encryption_key=encryption_key, analyzers=[analyzer], ) @@ -84,8 +87,8 @@ def test_to_dict_with_params(monkeypatch): "index_name": "my_index", "embedding_dimension": 15, "metadata_fields": { - "Title": "str", - "Pages": "int", + "Title": SimpleField(name="Title", type="Edm.String", searchable=True, filterable=True).as_dict(), + "Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True).as_dict(), }, "encryption_key": { "key_name": "my-key", @@ -136,7 +139,7 @@ def test_from_dict(monkeypatch): assert isinstance(document_store._azure_endpoint, EnvVarSecret) assert document_store._index_name == "default" assert document_store._embedding_dimension == 768 - assert document_store._metadata_fields is None + assert document_store._metadata_fields == {} assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH @@ -157,8 +160,8 @@ def test_from_dict_with_params(monkeypatch): "index_name": "my_index", "embedding_dimension": 15, "metadata_fields": { - "Title": "str", - "Pages": "int", + "Title": SimpleField(name="Title", type="Edm.String", filterable=True).as_dict(), + "Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True).as_dict(), }, "encryption_key": { "key_name": "my-key", @@ -192,7 +195,10 @@ def test_from_dict_with_params(monkeypatch): assert isinstance(document_store._azure_endpoint, EnvVarSecret) assert document_store._index_name == "my_index" assert document_store._embedding_dimension == 15 - assert document_store._metadata_fields == {"Title": str, "Pages": int} + assert document_store._metadata_fields == { + "Title": SimpleField(name="Title", type="Edm.String", filterable=True), + "Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True), + } assert document_store._index_creation_kwargs["encryption_key"] == encryption_key assert document_store._index_creation_kwargs["analyzers"][0].name == "url-analyze" assert document_store._index_creation_kwargs["analyzers"][0].token_filters == ["lowercase"] @@ -218,7 +224,10 @@ def test_init(_mock_azure_search_client): assert document_store._index_name == "my_index" assert document_store._embedding_dimension == 15 - assert document_store._metadata_fields == {"Title": str, "Pages": int} + assert document_store._metadata_fields == { + "Title": SimpleField(name="Title", type="Edm.String", filterable=True), + "Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True), + } assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH @@ -228,7 +237,6 @@ def test_init(_mock_azure_search_client): reason="Missing AZURE_AI_SEARCH_ENDPOINT or AZURE_AI_SEARCH_API_KEY.", ) class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): - def test_write_documents(self, document_store: AzureAISearchDocumentStore): docs = [Document(id="1")] assert document_store.write_documents(docs) == 1 @@ -281,7 +289,6 @@ def _random_embeddings(n): indirect=True, ) class TestFilters(FilterDocumentsTest): - # Overriding to change "date" to compatible ISO 8601 format @pytest.fixture def filterable_docs(self) -> List[Document]: