Skip to content

Commit 183f20f

Browse files
committed
fix author matching and handling of filtering
1 parent: cf1dba1 · commit: 183f20f

File tree

5 files changed: 16 additions (+16) and 19 deletions (-19).

5 files changed

+16
-19
lines changed

pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ dependencies = [
1616
"boto3",
1717
"cohere",
1818
"elasticsearch >= 8.5",
19-
"elasticsearch-dsl",
2019
"fastapi <= 0.112.0",
2120
"fastapi-pagination",
2221
"httpx",

src/scholarag/app/dependencies.py

+10-13
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
import logging
55
from enum import Enum
66
from functools import cache
7-
from typing import Annotated, AsyncIterator
7+
from typing import Annotated, Any, AsyncIterator
88

9-
from elasticsearch_dsl import Q, Search
109
from fastapi import Depends, HTTPException, Query, Request
1110
from fastapi.security import HTTPBearer
1211
from httpx import AsyncClient, HTTPStatusError
@@ -202,24 +201,22 @@ def get_query_from_params(
202201
pattern=r"^\d{4}-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$",
203202
),
204203
] = None,
205-
) -> dict[str, str] | None:
204+
) -> dict[str, dict[str, list[dict[str, Any]]]] | None:
206205
"""Get the query parameters and generate an ES query for filtering."""
207-
search = Search()
206+
query: dict[str, dict[str, list[dict[str, Any]]]] = {"bool": {"must": []}}
208207
if article_types:
209-
search = search.query(Q("terms", article_type=article_types))
208+
query["bool"]["must"].append({"terms": {"article_type": article_types}})
210209
if authors:
211-
search = search.query(Q("terms", authors=authors))
210+
query["bool"]["must"].append({"terms": {"authors.keyword": authors}})
212211
if journals:
213-
search = search.query(Q("terms", journal=journals))
212+
query["bool"]["must"].append({"terms": {"journal": journals}})
214213
if date_from:
215-
search = search.query(Q("range", date={"gte": date_from}))
214+
query["bool"]["must"].append({"range": {"date": {"gte": date_from}}})
216215
if date_to:
217-
search = search.query(Q("range", date={"lte": date_to}))
216+
query["bool"]["must"].append({"range": {"date": {"lte": date_to}}})
218217

219-
logger.info(
220-
f"Searching the database with the query {json.dumps(search.to_dict())}."
221-
)
222-
return None if not search.to_dict() else search.to_dict()["query"]
218+
logger.info(f"Searching the database with the query {json.dumps(query)}.")
219+
return None if not query["bool"]["must"] else query
223220

224221

225222
class ErrorCode(Enum):

src/scholarag/app/routers/retrieval.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ async def article_count(
250250

251251
# If further filters, append them
252252
if filter_query:
253-
query["query"]["bool"]["must"].append(filter_query)
253+
query["query"]["bool"]["must"].extend(filter_query["bool"]["must"])
254254

255255
# Aggregation query.
256256
aggs = {
@@ -390,7 +390,7 @@ async def article_listing(
390390

391391
# If further filters, append them
392392
if filter_query:
393-
query["query"]["bool"]["must"].append(filter_query)
393+
query["query"]["bool"]["must"].extend(filter_query["bool"]["must"])
394394

395395
aggs: dict[str, Any] = {
396396
"relevant_ids": {
@@ -410,8 +410,9 @@ async def article_listing(
410410
aggs["relevant_ids"]["aggs"]["score"]["max"] = {"field": "date"}
411411

412412
results = await ds_client.search(
413-
index=settings.db.index_paragraphs, query=query, size=78, aggs=aggs
413+
index=settings.db.index_paragraphs, query=query, size=0, aggs=aggs
414414
)
415+
415416
logger.info(f"unique article retrieval took: {time.time() - start}s")
416417

417418
docs = [

tests/app/test_dependencies.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def test_get_query_from_params():
184184
"bool": {
185185
"must": [
186186
{"terms": {"article_type": ["publication", "review"]}},
187-
{"terms": {"authors": ["Guy Manderson", "Joe Guy"]}},
187+
{"terms": {"authors.keyword": ["Guy Manderson", "Joe Guy"]}},
188188
{"terms": {"journal": ["1111-1111"]}},
189189
{"range": {"date": {"gte": "2020-01-01"}}},
190190
{"range": {"date": {"lte": "2020-01-02"}}},

tests/app/test_qa.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_generative_qa(app_client, mock_http_calls):
4141
"bool": {
4242
"must": [
4343
{"terms": {"article_type": params["article_types"]}},
44-
{"terms": {"authors": params["authors"]}},
44+
{"terms": {"authors.keyword": params["authors"]}},
4545
{"range": {"date": {"gte": params["date_from"]}}},
4646
{"range": {"date": {"lte": params["date_to"]}}},
4747
]

0 commit comments

Comments
 (0)