Skip to content

Commit cf1dba1

Browse files
committed
fix last tests
1 parent 2960116 commit cf1dba1

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

src/scholarag/app/routers/retrieval.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ async def article_count(
250250

251251
# If further filters, append them
252252
if filter_query:
253-
query["query"]["bool"]["must"].append(filter_query["bool"]["must"])
253+
query["query"]["bool"]["must"].append(filter_query)
254254

255255
# Aggregation query.
256256
aggs = {
@@ -389,8 +389,8 @@ async def article_listing(
389389
)
390390

391391
# If further filters, append them
392-
if filter_query and filter_query.get("bool"):
393-
query["query"]["bool"]["must"].append(filter_query["bool"]["must"])
392+
if filter_query:
393+
query["query"]["bool"]["must"].append(filter_query)
394394

395395
aggs: dict[str, Any] = {
396396
"relevant_ids": {

tests/app/test_retrieval.py

+13-14
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ def test_retrieval_no_answer_code_1(app_client):
112112
(None, ["3", "4"], None, None, 11),
113113
(["3", "4"], None, None, None, 11),
114114
(None, ["3 4"], None, None, 11),
115-
# (None, None, None, None, 19),
116-
# (None, None, "2022-12-01", None, 5),
117-
# (None, None, None, "2022-01-01", 6),
118-
# (None, None, "2022-03-01", "2022-06-01", 17),
115+
(["0 1 2 3 4 5 6 7 8 9"], None, None, None, 19),
116+
(["0 1 2 3 4 5 6 7 8 9"], None, "2022-12-01", None, 5),
117+
(["0 1 2 3 4 5 6 7 8 9"], None, None, "2022-01-01", 6),
118+
(["0 1 2 3 4 5 6 7 8 9"], None, "2022-03-01", "2022-06-01", 17),
119119
],
120120
)
121121
async def test_article_count(
@@ -317,13 +317,12 @@ async def test_article_listing(get_testing_async_ds_client, mock_http_calls):
317317
response = response.json()
318318

319319
assert sorted([resp["article_id"] for resp in response["items"]]) == [
320+
"10",
320321
"11",
321-
"12",
322-
"13",
323-
"17",
322+
"18",
324323
"19",
325-
"27",
326-
"37",
324+
"21",
325+
"31",
327326
"47",
328327
"57",
329328
"7",
@@ -347,16 +346,16 @@ async def test_article_listing(get_testing_async_ds_client, mock_http_calls):
347346
response = response.json()
348347

349348
assert sorted([resp["article_id"] for resp in response["items"]]) == [
350-
"11",
349+
"1",
351350
"12",
352351
"13",
352+
"14",
353+
"15",
354+
"16",
353355
"17",
354-
"19",
355356
"27",
356357
"37",
357-
"47",
358-
"57",
359-
"7",
358+
"51",
360359
]
361360
expected_keys = set(ArticleMetadata.model_json_schema()["properties"].keys())
362361
for d in response["items"]:

0 commit comments

Comments
 (0)