Skip to content

Commit 0bb3c07

Browse files
committed
fix most tests (some left)
1 parent cf30e14 commit 0bb3c07

File tree

5 files changed

+100
-84
lines changed

5 files changed

+100
-84
lines changed

CHANGELOG.md

+3
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212
### Added
1313
- Bearer token insertion in FastAPI UI.
1414

15+
### Changed
16+
- Better search algorithm for article count / listing.
17+
1518
### Fixed
1619
- Limit query size.
1720
- Add OBI copyright.

src/scholarag/app/dependencies.py

-14
Original file line number | Diff line number | Diff line change
@@ -176,20 +176,6 @@ async def get_reranker(
176176

177177

178178
def get_query_from_params(
179-
# topics: Annotated[
180-
# list[str] | None,
181-
# Query(
182-
# description="Keyword to be matched in text. AND matching (e.g. for TOPICS)."
183-
# ),
184-
# ] = None,
185-
# regions: Annotated[
186-
# list[str] | None,
187-
# Query(
188-
# description=(
189-
# "Keyword to be matched in text. OR matching (e.g. for BRAIN_REGIONS)."
190-
# )
191-
# ),
192-
# ] = None,
193179
article_types: Annotated[
194180
list[str] | None, Query(description="Article types allowed. OR matching")
195181
] = None,

src/scholarag/app/routers/retrieval.py

+43 -34
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
import time
5-
from typing import Annotated, Any
5+
from typing import Annotated, Any, Dict
66

77
from fastapi import APIRouter, Depends, HTTPException, Query
88
from fastapi_pagination import Page, paginate
@@ -220,32 +220,34 @@ async def article_count(
220220
)
221221
else:
222222
# Match the keywords on abstract + title.
223-
keywords = ([topic.split(" ") for topic in topics] or []) + (
224-
[region.split(" ") for region in regions] or []
225-
)
226-
query = {
223+
query: Dict[str, Any] = {
227224
"query": {
228225
"bool": {
229226
"must": [
230-
{
231-
"multi_match": {
232-
"query": wo,
233-
"fields": ["title", "text"],
234-
}
235-
}
236-
# {
237-
# "multi_match": {
238-
# "query": " ".join(regions),
239-
# "fields": ["title", "text"],
240-
# }
241-
# },
242-
for words in keywords
243-
for wo in words
227+
{"term": {"section": "Abstract"}},
244228
]
245229
}
246230
}
247231
}
248-
query["query"]["bool"]["must"].append({"term": {"section": "Abstract"}})
232+
if topics:
233+
query["query"]["bool"]["must"].append(
234+
{
235+
"multi_match": {
236+
"query": " ".join(topics),
237+
"fields": ["title", "text"],
238+
}
239+
}
240+
)
241+
if regions:
242+
query["query"]["bool"]["must"].append(
243+
{
244+
"multi_match": {
245+
"query": " ".join(regions),
246+
"fields": ["title", "text"],
247+
}
248+
}
249+
)
250+
249251
# If further filters, append them
250252
if filter_query:
251253
query["query"]["bool"]["must"].append(filter_query["bool"]["must"])
@@ -358,29 +360,36 @@ async def article_listing(
358360
status_code=422, detail="Please provide at least one region or topic."
359361
)
360362
else:
361-
keywords = ([topic.split(" ") for topic in topics] or []) + (
362-
[region.split(" ") for region in regions] or []
363-
)
364-
query = {
363+
query: Dict[str, Any] = {
365364
"query": {
366365
"bool": {
367366
"must": [
368-
{
369-
"multi_match": {
370-
"query": wo,
371-
"fields": ["title", "text"],
372-
}
373-
}
374-
for words in keywords
375-
for wo in words
367+
{"term": {"section": "Abstract"}},
376368
]
377369
}
378370
}
379371
}
380-
query["query"]["bool"]["must"].append({"term": {"section": "Abstract"}})
372+
if topics:
373+
query["query"]["bool"]["must"].append(
374+
{
375+
"multi_match": {
376+
"query": " ".join(topics),
377+
"fields": ["title", "text"],
378+
}
379+
}
380+
)
381+
if regions:
382+
query["query"]["bool"]["must"].append(
383+
{
384+
"multi_match": {
385+
"query": " ".join(regions),
386+
"fields": ["title", "text"],
387+
}
388+
}
389+
)
381390

382391
# If further filters, append them
383-
if filter_query:
392+
if filter_query and filter_query.get("bool"):
384393
query["query"]["bool"]["must"].append(filter_query["bool"]["must"])
385394

386395
aggs: dict[str, Any] = {

tests/app/test_dependencies.py

-13
Original file line number | Diff line number | Diff line change
@@ -174,8 +174,6 @@ async def test_get_reranker_cohere(monkeypatch):
174174

175175

176176
def test_get_query_from_params():
177-
topics = ["pyramidal cells", "retina"]
178-
regions = ["brain region", "thalamus"]
179177
article_types = ["publication", "review"]
180178
authors = ["Guy Manderson", "Joe Guy"]
181179
journals = ["1111-1111"]
@@ -185,15 +183,6 @@ def test_get_query_from_params():
185183
expected = {
186184
"bool": {
187185
"must": [
188-
{
189-
"query_string": {
190-
"default_field": "text",
191-
"query": (
192-
"((pyramidal AND cells) AND retina) AND ((brain AND region)"
193-
" OR thalamus)"
194-
),
195-
}
196-
},
197186
{"terms": {"article_type": ["publication", "review"]}},
198187
{"terms": {"authors": ["Guy Manderson", "Joe Guy"]}},
199188
{"terms": {"journal": ["1111-1111"]}},
@@ -204,8 +193,6 @@ def test_get_query_from_params():
204193
}
205194

206195
query = get_query_from_params(
207-
topics=topics,
208-
regions=regions,
209196
article_types=article_types,
210197
authors=authors,
211198
journals=journals,

tests/app/test_retrieval.py

+54 -23
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@ def test_retrieval(app_client, retriever_k, mock_http_calls):
2222
fake_rts, _ = override_rts(has_context=True)
2323

2424
params = {
25-
"regions": ["thalamus", "Giant Hippopotamidae"],
2625
"journals": ["1234-5678"],
2726
"date_to": "2022-12-31",
2827
"query": "aaa",
@@ -31,12 +30,6 @@ def test_retrieval(app_client, retriever_k, mock_http_calls):
3130
expected_query = {
3231
"bool": {
3332
"must": [
34-
{
35-
"query_string": {
36-
"default_field": "text",
37-
"query": "(thalamus OR (Giant AND Hippopotamidae))",
38-
}
39-
},
4033
{"terms": {"journal": params["journals"]}},
4134
{"range": {"date": {"lte": params["date_to"]}}},
4235
]
@@ -113,16 +106,16 @@ def test_retrieval_no_answer_code_1(app_client):
113106
[
114107
(["1"], ["2"], None, None, 1),
115108
(["1"], ["1"], None, None, 10),
116-
(["1", "2"], ["3"], None, None, 0),
109+
(["1", "2"], ["3"], None, None, 2),
117110
(["1"], ["3", "4"], None, None, 2),
118-
(["1", "2"], ["3", "4"], None, None, 0),
111+
(["1", "2"], ["3", "4"], None, None, 3),
119112
(None, ["3", "4"], None, None, 11),
120-
(["3", "4"], None, None, None, 1),
121-
(None, ["3 4"], None, None, 1),
122-
(None, None, None, None, 19),
123-
(None, None, "2022-12-01", None, 5),
124-
(None, None, None, "2022-01-01", 6),
125-
(None, None, "2022-03-01", "2022-06-01", 17),
113+
(["3", "4"], None, None, None, 11),
114+
(None, ["3 4"], None, None, 11),
115+
# (None, None, None, None, 19),
116+
# (None, None, "2022-12-01", None, 5),
117+
# (None, None, None, "2022-01-01", 6),
118+
# (None, None, "2022-03-01", "2022-06-01", 17),
126119
],
127120
)
128121
async def test_article_count(
@@ -158,6 +151,7 @@ async def test_article_count(
158151
"paragraph_id": str(i),
159152
"article_id": n1 + n2, # 19 unique articles.
160153
"journal": "8765-4321",
154+
"section": "Abstract",
161155
"date": datetime(2022, i % 12 + 1, 1).strftime("%Y-%m-%d"),
162156
},
163157
}
@@ -229,7 +223,7 @@ async def test_article_listing(get_testing_async_ds_client, mock_http_calls):
229223
"pubmed_id": "PM1234",
230224
"authors": ["Nikemicsjanba"],
231225
"article_type": "code",
232-
"section": "abstract",
226+
"section": "Abstract",
233227
"date": datetime(2022, i % 12 + 1, 1).strftime("%Y-%m-%d"),
234228
},
235229
}
@@ -254,7 +248,18 @@ async def test_article_listing(get_testing_async_ds_client, mock_http_calls):
254248
assert response.status_code == 200
255249
response = response.json()
256250

257-
assert sorted([resp["article_id"] for resp in response["items"]]) == ["1", "16"]
251+
assert sorted([resp["article_id"] for resp in response["items"]]) == [
252+
"1",
253+
"11",
254+
"12",
255+
"13",
256+
"16",
257+
"2",
258+
"4",
259+
"6",
260+
"7",
261+
"9",
262+
]
258263
expected_keys = set(ArticleMetadata.model_json_schema()["properties"].keys())
259264
for d in response["items"]:
260265
assert set(d.keys()) == expected_keys
@@ -311,7 +316,18 @@ async def test_article_listing(get_testing_async_ds_client, mock_http_calls):
311316
assert response.status_code == 200
312317
response = response.json()
313318

314-
assert sorted([resp["article_id"] for resp in response["items"]]) == ["11"]
319+
assert sorted([resp["article_id"] for resp in response["items"]]) == [
320+
"11",
321+
"12",
322+
"13",
323+
"17",
324+
"19",
325+
"27",
326+
"37",
327+
"47",
328+
"57",
329+
"7",
330+
]
315331
expected_keys = set(ArticleMetadata.model_json_schema()["properties"].keys())
316332
for d in response["items"]:
317333
assert set(d.keys()) == expected_keys
@@ -330,17 +346,27 @@ async def test_article_listing(get_testing_async_ds_client, mock_http_calls):
330346
assert response.status_code == 200
331347
response = response.json()
332348

333-
assert sorted([resp["article_id"] for resp in response["items"]]) == ["17"]
349+
assert sorted([resp["article_id"] for resp in response["items"]]) == [
350+
"11",
351+
"12",
352+
"13",
353+
"17",
354+
"19",
355+
"27",
356+
"37",
357+
"47",
358+
"57",
359+
"7",
360+
]
334361
expected_keys = set(ArticleMetadata.model_json_schema()["properties"].keys())
335362
for d in response["items"]:
336363
assert set(d.keys()) == expected_keys
337364

338365

339366
@pytest.mark.asyncio
340-
async def test_article_listing_by_date(get_testing_async_ds_client, request):
367+
async def test_article_listing_by_date(get_testing_async_ds_client):
341368
ds_client, parameters = get_testing_async_ds_client
342369

343-
request.getfixturevalue("mock_http_calls")
344370
test_settings = Settings(
345371
db={
346372
"db_type": (
@@ -376,7 +402,7 @@ async def test_article_listing_by_date(get_testing_async_ds_client, request):
376402
"pubmed_id": "PM1234",
377403
"authors": ["Nikemicsjanba"],
378404
"article_type": "code",
379-
"section": "abstract",
405+
"section": "Abstract",
380406
"date": datetime(2022, i % 12 + 1, 1).strftime("%Y-%m-%d"),
381407
},
382408
}
@@ -387,7 +413,12 @@ async def test_article_listing_by_date(get_testing_async_ds_client, request):
387413
await ds_client.client.indices.refresh()
388414
app.dependency_overrides[get_ds_client] = lambda: ds_client
389415

390-
params = {"number_results": 50, "sort_by_date": True}
416+
params = {
417+
"number_results": 50,
418+
"topics": "paragraph",
419+
"region": "paragraph",
420+
"sort_by_date": True,
421+
}
391422

392423
async with AsyncClient(
393424
transport=ASGITransport(app=app), base_url="http://test"

0 commit comments

Comments
 (0)