feat(azure-ai-search): Allow full metadata field customization #1676
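This PR extends `metadata_fields` on `AzureAISearchDocumentStore` so that values may be full `SearchField` objects, not just plain Python types. Plain types (`str`, `bool`, `int`, `float`, `datetime`) are still accepted and are normalized to filterable `SimpleField`s at construction time, and serialization now round-trips the complete field definitions via `SearchField.as_dict()` / `SearchField.from_dict()` instead of bare type names.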

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into main
@@ -91,8 +91,8 @@ def __init__(
         azure_endpoint: Secret = Secret.from_env_var("AZURE_AI_SEARCH_ENDPOINT", strict=True),  # noqa: B008
         index_name: str = "default",
         embedding_dimension: int = 768,
-        metadata_fields: Optional[Dict[str, type]] = None,
-        vector_search_configuration: VectorSearch = None,
+        metadata_fields: Optional[Dict[str, SearchField | type]] = None,
+        vector_search_configuration: Optional[VectorSearch] = None,
         **index_creation_kwargs,
     ):
         """
@@ -103,10 +103,10 @@ def __init__(
         :param api_key: The API key to use for authentication.
         :param index_name: Name of index in Azure AI Search, if it doesn't exist it will be created.
         :param embedding_dimension: Dimension of the embeddings.
-        :param metadata_fields: A dictionary of metadata keys and their types to create
-            additional fields in index schema. As fields in Azure SearchIndex cannot be dynamic,
-            it is necessary to specify the metadata fields in advance.
-            (e.g. metadata_fields = {"author": str, "date": datetime})
+        :param metadata_fields: A dictionary mapping metadata keys to the search index fields that
+            should store them. These fields are added automatically during index creation. For
+            convenience, a mapping may also be a Python type (`str`, `bool`, `int`, `float`, or
+            `datetime`) instead of a `SearchField`, in which case a simple filterable field is created.
         :param vector_search_configuration: Configuration option related to vector search.
             Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches.

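For illustration, a minimal sketch of the new parameter in use. The import path is assumed from the integration's package layout, and the index and field names are made up:

```python
from azure.search.documents.indexes.models import SearchField

from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

document_store = AzureAISearchDocumentStore(
    index_name="my_index",
    metadata_fields={
        # Full control over the index field (searchable, sortable, analyzers, ...)
        "Title": SearchField(name="Title", type="Edm.String", searchable=True, filterable=True),
        # Shorthand: normalized to SimpleField(name="Pages", type="Edm.Int32", filterable=True)
        "Pages": int,
    },
)
```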
@@ -136,13 +136,12 @@ def __init__(
         self._index_name = index_name
         self._embedding_dimension = embedding_dimension
         self._dummy_vector = [-10.0] * self._embedding_dimension
-        self._metadata_fields = metadata_fields
+        self._metadata_fields = self._normalize_metadata_index_fields(metadata_fields)
         self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH
         self._index_creation_kwargs = index_creation_kwargs

     @property
     def client(self) -> SearchClient:
-
         # resolve secrets for authentication
         resolved_endpoint = (
             self._azure_endpoint.resolve_value() if isinstance(self._azure_endpoint, Secret) else self._azure_endpoint
@@ -178,6 +177,45 @@ def client(self) -> SearchClient:

         return self._client

+    def _normalize_metadata_index_fields(
+        self, metadata_fields: Optional[Dict[str, SearchField | type]]
+    ) -> Dict[str, SearchField]:
+        """Create a dictionary of index fields for storing metadata values."""
+
+        if not metadata_fields:
+            return {}
+
+        normalized_fields = {}
+
+        for key, value in metadata_fields.items():
+            if isinstance(value, SearchField):
+                if value.name == key:
+                    normalized_fields[key] = value
+                else:
+                    msg = f"Name of SearchField ('{value.name}') must match metadata field name ('{key}')"
+                    raise ValueError(msg)
+            else:
+                if not key[0].isalpha():
+                    msg = (
+                        f"Azure Search index only allows field names starting with letters. "
+                        f"Invalid key: {key} will be dropped."
+                    )
+                    logger.warning(msg)
+                    continue
+
+                field_type = type_mapping.get(value)
+                if not field_type:
+                    error_message = f"Unsupported field type for key '{key}': {value}"
+                    raise ValueError(error_message)
+
+                normalized_fields[key] = SimpleField(
+                    name=key,
+                    type=field_type,
+                    filterable=True,
+                )
+
+        return normalized_fields
+
     def _create_index(self) -> None:
         """
         Internally creates a new search index.
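As a sketch of what the normalization produces (the `int` to `Edm.Int32` mapping comes from the module's `type_mapping` and is confirmed by the tests below; `store` is a hypothetical instance):

```python
store._normalize_metadata_index_fields({"Pages": int})
# -> {"Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True)}

# A SearchField whose name does not match its metadata key is rejected:
store._normalize_metadata_index_fields({"Pages": SearchField(name="pages", type="Edm.Int32")})
# ValueError: Name of SearchField ('pages') must match metadata field name ('Pages')
```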
@@ -198,13 +236,15 @@ def _create_index(self) -> None:
         ]

         if self._metadata_fields:
-            default_fields.extend(self._create_metadata_index_fields(self._metadata_fields))
+            default_fields.extend(self._metadata_fields.values())

         index = SearchIndex(
             name=self._index_name,
             fields=default_fields,
             vector_search=self._vector_search_configuration,
             **self._index_creation_kwargs,
         )
+
         if self._index_client:
             self._index_client.create_index(index)
@@ -258,28 +298,19 @@ def _deserialize_index_creation_kwargs(cls, data: Dict[str, Any]) -> Any:
         return result[key]

     def to_dict(self) -> Dict[str, Any]:
-        # This is not the best solution to serialise this class but is the fastest to implement.
-        # Not all kwargs types can be serialised to text so this can fail. We must serialise each
-        # type explicitly to handle this properly.
         """
         Serializes the component to a dictionary.

         :returns:
             Dictionary with serialized data.
         """
-
-        if self._metadata_fields:
-            serialized_metadata = {key: value.__name__ for key, value in self._metadata_fields.items()}
-        else:
-            serialized_metadata = None
-
         return default_to_dict(
             self,
             azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint else None,
             api_key=self._api_key.to_dict() if self._api_key else None,
             index_name=self._index_name,
             embedding_dimension=self._embedding_dimension,
-            metadata_fields=serialized_metadata,
+            metadata_fields={key: value.as_dict() for key, value in self._metadata_fields.items()},
             vector_search_configuration=self._vector_search_configuration.as_dict(),
             **self._serialize_index_creation_kwargs(self._index_creation_kwargs),
         )
@@ -296,7 +327,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore":
             Deserialized component.
         """
         if (fields := data["init_parameters"]["metadata_fields"]) is not None:
-            data["init_parameters"]["metadata_fields"] = cls._deserialize_metadata_fields(fields)
+            data["init_parameters"]["metadata_fields"] = {
+                key: SearchField.from_dict(field) for key, field in fields.items()
+            }
+        else:
+            data["init_parameters"]["metadata_fields"] = {}

         for key, _value in AZURE_CLASS_MAPPING.items():
             if key in data["init_parameters"]:
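Metadata fields now survive a `to_dict()` / `from_dict()` round trip as complete field definitions. A sketch, assuming `store` is the document store from the earlier example and that field equality behaves as in the tests below:

```python
data = store.to_dict()
# data["init_parameters"]["metadata_fields"]["Pages"] holds SearchField.as_dict() output,
# e.g. {"name": "Pages", "type": "Edm.Int32", "filterable": True, ...}

restored = AzureAISearchDocumentStore.from_dict(data)
assert restored._metadata_fields == store._metadata_fields
```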
@@ -454,46 +489,12 @@ def _convert_haystack_document_to_azure(self, document: Document) -> Dict[str, Any]:

         return index_document

-    def _create_metadata_index_fields(self, metadata: Dict[str, Any]) -> List[SimpleField]:
-        """Create a list of index fields for storing metadata values."""
-
-        index_fields = []
-        metadata_field_mapping = self._map_metadata_field_types(metadata)
-
-        for key, field_type in metadata_field_mapping.items():
-            index_fields.append(SimpleField(name=key, type=field_type, filterable=True))
-
-        return index_fields
-
-    def _map_metadata_field_types(self, metadata: Dict[str, type]) -> Dict[str, str]:
-        """Map metadata field types to Azure Search field types."""
-
-        metadata_field_mapping = {}
-
-        for key, value_type in metadata.items():
-
-            if not key[0].isalpha():
-                msg = (
-                    f"Azure Search index only allows field names starting with letters. "
-                    f"Invalid key: {key} will be dropped."
-                )
-                logger.warning(msg)
-                continue
-
-            field_type = type_mapping.get(value_type)
-            if not field_type:
-                error_message = f"Unsupported field type for key '{key}': {value_type}"
-                raise ValueError(error_message)
-            metadata_field_mapping[key] = field_type
-
-        return metadata_field_mapping
-
     def _embedding_retrieval(
         self,
         query_embedding: List[float],
         *,
         top_k: int = 10,
-        filters: Optional[Dict[str, Any]] = None,
+        filters: Optional[str] = None,
         **kwargs,
     ) -> List[Document]:
         """
@@ -527,7 +528,7 @@ def _bm25_retrieval(
         self,
         query: str,
         top_k: int = 10,
-        filters: Optional[Dict[str, Any]] = None,
+        filters: Optional[str] = None,
         **kwargs,
     ) -> List[Document]:
         """
@@ -560,7 +561,7 @@ def _hybrid_retrieval(
         query: str,
         query_embedding: List[float],
         top_k: int = 10,
-        filters: Optional[Dict[str, Any]] = None,
+        filters: Optional[str] = None,
         **kwargs,
     ) -> List[Document]:
         """
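The `filters` arguments of the retrieval methods above take a raw OData `$filter` string, which the type hints now reflect. Fields declared through `metadata_fields` are filterable, so they can appear in such expressions. A sketch against the private API, with made-up values:

```python
docs = document_store._embedding_retrieval(
    query_embedding=[0.1] * 768,  # must match the index's embedding_dimension
    top_k=5,
    filters="Pages gt 100 and Title eq 'Report'",  # OData $filter syntax
)
```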
31 changes: 19 additions & 12 deletions integrations/azure_ai_search/tests/test_document_store.py
@@ -8,7 +8,7 @@
 from unittest.mock import patch

 import pytest
-from azure.search.documents.indexes.models import CustomAnalyzer, SearchResourceEncryptionKey
+from azure.search.documents.indexes.models import CustomAnalyzer, SearchField, SearchResourceEncryptionKey, SimpleField
 from haystack.dataclasses.document import Document
 from haystack.errors import FilterError
 from haystack.testing.document_store import (
@@ -37,7 +37,7 @@ def test_to_dict(monkeypatch):
         "api_key": {"env_vars": ["AZURE_AI_SEARCH_API_KEY"], "strict": False, "type": "env_var"},
         "index_name": "default",
         "embedding_dimension": 768,
-        "metadata_fields": None,
+        "metadata_fields": {},
         "vector_search_configuration": {
             "profiles": [
                 {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
@@ -70,7 +70,10 @@ def test_to_dict_with_params(monkeypatch):
     document_store = AzureAISearchDocumentStore(
         index_name="my_index",
         embedding_dimension=15,
-        metadata_fields={"Title": str, "Pages": int},
+        metadata_fields={
+            "Title": SearchField(name="Title", type="Edm.String", searchable=True, filterable=True),
+            "Pages": int,
+        },
         encryption_key=encryption_key,
         analyzers=[analyzer],
     )
@@ -84,8 +87,8 @@
         "index_name": "my_index",
         "embedding_dimension": 15,
         "metadata_fields": {
-            "Title": "str",
-            "Pages": "int",
+            "Title": SimpleField(name="Title", type="Edm.String", searchable=True, filterable=True).as_dict(),
+            "Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True).as_dict(),
         },
         "encryption_key": {
             "key_name": "my-key",
@@ -136,7 +139,7 @@ def test_from_dict(monkeypatch):
     assert isinstance(document_store._azure_endpoint, EnvVarSecret)
     assert document_store._index_name == "default"
     assert document_store._embedding_dimension == 768
-    assert document_store._metadata_fields is None
+    assert document_store._metadata_fields == {}
     assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH


@@ -157,8 +160,8 @@ def test_from_dict_with_params(monkeypatch):
         "index_name": "my_index",
         "embedding_dimension": 15,
         "metadata_fields": {
-            "Title": "str",
-            "Pages": "int",
+            "Title": SimpleField(name="Title", type="Edm.String", filterable=True).as_dict(),
+            "Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True).as_dict(),
         },
         "encryption_key": {
             "key_name": "my-key",
@@ -192,7 +195,10 @@ def test_from_dict_with_params(monkeypatch):
     assert isinstance(document_store._azure_endpoint, EnvVarSecret)
     assert document_store._index_name == "my_index"
     assert document_store._embedding_dimension == 15
-    assert document_store._metadata_fields == {"Title": str, "Pages": int}
+    assert document_store._metadata_fields == {
+        "Title": SimpleField(name="Title", type="Edm.String", filterable=True),
+        "Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True),
+    }
     assert document_store._index_creation_kwargs["encryption_key"] == encryption_key
     assert document_store._index_creation_kwargs["analyzers"][0].name == "url-analyze"
     assert document_store._index_creation_kwargs["analyzers"][0].token_filters == ["lowercase"]
@@ -218,7 +224,10 @@ def test_init(_mock_azure_search_client):

     assert document_store._index_name == "my_index"
     assert document_store._embedding_dimension == 15
-    assert document_store._metadata_fields == {"Title": str, "Pages": int}
+    assert document_store._metadata_fields == {
+        "Title": SimpleField(name="Title", type="Edm.String", filterable=True),
+        "Pages": SimpleField(name="Pages", type="Edm.Int32", filterable=True),
+    }
     assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH


@@ -228,7 +237,6 @@
     reason="Missing AZURE_AI_SEARCH_ENDPOINT or AZURE_AI_SEARCH_API_KEY.",
 )
 class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest):
-
     def test_write_documents(self, document_store: AzureAISearchDocumentStore):
         docs = [Document(id="1")]
         assert document_store.write_documents(docs) == 1
@@ -281,7 +289,6 @@ def _random_embeddings(n):
     indirect=True,
 )
 class TestFilters(FilterDocumentsTest):
-
     # Overriding to change "date" to compatible ISO 8601 format
     @pytest.fixture
     def filterable_docs(self) -> List[Document]: