Skip to content

Commit cf30e14

Browse files
committed
Temp
1 parent 876aaba commit cf30e14

File tree

2 files changed

+84
-67
lines changed

2 files changed

+84
-67
lines changed

src/scholarag/app/dependencies.py

+14-44
Original file line numberDiff line numberDiff line change
@@ -176,20 +176,20 @@ async def get_reranker(
176176

177177

178178
def get_query_from_params(
179-
topics: Annotated[
180-
list[str] | None,
181-
Query(
182-
description="Keyword to be matched in text. AND matching (e.g. for TOPICS)."
183-
),
184-
] = None,
185-
regions: Annotated[
186-
list[str] | None,
187-
Query(
188-
description=(
189-
"Keyword to be matched in text. OR matching (e.g. for BRAIN_REGIONS)."
190-
)
191-
),
192-
] = None,
179+
# topics: Annotated[
180+
# list[str] | None,
181+
# Query(
182+
# description="Keyword to be matched in text. AND matching (e.g. for TOPICS)."
183+
# ),
184+
# ] = None,
185+
# regions: Annotated[
186+
# list[str] | None,
187+
# Query(
188+
# description=(
189+
# "Keyword to be matched in text. OR matching (e.g. for BRAIN_REGIONS)."
190+
# )
191+
# ),
192+
# ] = None,
193193
article_types: Annotated[
194194
list[str] | None, Query(description="Article types allowed. OR matching")
195195
] = None,
@@ -219,36 +219,6 @@ def get_query_from_params(
219219
) -> dict[str, str] | None:
220220
"""Get the query parameters and generate an ES query for filtering."""
221221
search = Search()
222-
search_elems = []
223-
if topics:
224-
linked_tokens = [
225-
(
226-
" AND ".join(("(" + keyword + ")").split(" "))
227-
if len(keyword.split(" ")) >= 2
228-
else keyword
229-
)
230-
for keyword in topics
231-
]
232-
topics_bool = f"({' AND '.join(linked_tokens)})"
233-
search_elems.append(topics_bool)
234-
235-
if regions:
236-
linked_tokens = [
237-
(
238-
" AND ".join(("(" + keyword + ")").split(" "))
239-
if len(keyword.split(" ")) >= 2
240-
else keyword
241-
)
242-
for keyword in regions
243-
]
244-
regions_bool = f"({' OR '.join(linked_tokens)})"
245-
search_elems.append(regions_bool)
246-
247-
if topics or regions:
248-
q = Q(
249-
"query_string", default_field="text", query=f"{' AND '.join(search_elems)}"
250-
)
251-
search = search.query(q)
252222
if article_types:
253223
search = search.query(Q("terms", article_type=article_types))
254224
if authors:

src/scholarag/app/routers/retrieval.py

+70-23
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
from typing import Annotated, Any
66

7-
from fastapi import APIRouter, Depends, HTTPException, Query, Request
7+
from fastapi import APIRouter, Depends, HTTPException, Query
88
from fastapi_pagination import Page, paginate
99
from httpx import AsyncClient
1010

@@ -177,9 +177,23 @@ async def retrieval(
177177

178178
@router.get("/article_count")
179179
async def article_count(
180-
request: Request,
181180
ds_client: Annotated[AsyncBaseSearch, Depends(get_ds_client)],
181+
filter_query: Annotated[dict[str, Any], Depends(get_query_from_params)],
182182
settings: Annotated[Settings, Depends(get_settings)],
183+
topics: Annotated[
184+
list[str] | None,
185+
Query(
186+
description="Keyword to be matched in text. AND matching (e.g. for TOPICS)."
187+
),
188+
] = None,
189+
regions: Annotated[
190+
list[str] | None,
191+
Query(
192+
description=(
193+
"Keyword to be matched in text. OR matching (e.g. for BRAIN_REGIONS)."
194+
)
195+
),
196+
] = None,
183197
) -> ArticleCountResponse:
184198
"""Article count based on keyword matching.
185199
\f
@@ -200,28 +214,41 @@ async def article_count(
200214
start = time.time()
201215
logger.info("Finding unique articles matching the query ...")
202216

203-
params = request.query_params
204-
topics = params.getlist("topics")
205-
regions = params.getlist("regions")
206-
207-
if topics or regions:
217+
if not topics and not regions:
218+
raise HTTPException(
219+
status_code=422, detail="Please provide at least one region or topic."
220+
)
221+
else:
222+
# Match the keywords on abstract + title.
223+
keywords = ([topic.split(" ") for topic in topics] or []) + (
224+
[region.split(" ") for region in regions] or []
225+
)
208226
query = {
209227
"query": {
210228
"bool": {
211229
"must": [
212230
{
213231
"multi_match": {
214-
"query": " ".join(topics) + " " + " ".join(regions),
232+
"query": wo,
215233
"fields": ["title", "text"],
216234
}
217-
},
218-
{"term": {"section": "Abstract"}},
235+
}
236+
# {
237+
# "multi_match": {
238+
# "query": " ".join(regions),
239+
# "fields": ["title", "text"],
240+
# }
241+
# },
242+
for words in keywords
243+
for wo in words
219244
]
220245
}
221246
}
222247
}
223-
else:
224-
query = {"query": {"match_all": {}}}
248+
query["query"]["bool"]["must"].append({"term": {"section": "Abstract"}})
249+
# If further filters, append them
250+
if filter_query:
251+
query["query"]["bool"]["must"].append(filter_query["bool"]["must"])
225252

226253
# Aggregation query.
227254
aggs = {
@@ -268,11 +295,24 @@ async def article_count(
268295
},
269296
)
270297
async def article_listing(
271-
request: Request,
272298
ds_client: Annotated[AsyncBaseSearch, Depends(get_ds_client)],
273299
httpx_client: Annotated[AsyncClient, Depends(get_httpx_client)],
274300
filter_query: Annotated[dict[str, Any], Depends(get_query_from_params)],
275301
settings: Annotated[Settings, Depends(get_settings)],
302+
topics: Annotated[
303+
list[str] | None,
304+
Query(
305+
description="Keyword to be matched in text. AND matching (e.g. for TOPICS)."
306+
),
307+
] = None,
308+
regions: Annotated[
309+
list[str] | None,
310+
Query(
311+
description=(
312+
"Keyword to be matched in text. OR matching (e.g. for BRAIN_REGIONS)."
313+
)
314+
),
315+
] = None,
276316
number_results: Annotated[
277317
int | None,
278318
Query(description="Number of results to return. Max 10 000.", ge=1, le=10_000),
@@ -313,28 +353,35 @@ async def article_listing(
313353
start = time.time()
314354
logger.info("Finding unique articles matching the query ...")
315355

316-
params = request.query_params
317-
topics = params.getlist("topics")
318-
regions = params.getlist("regions")
319-
320-
if topics or regions:
356+
if not topics and not regions:
357+
raise HTTPException(
358+
status_code=422, detail="Please provide at least one region or topic."
359+
)
360+
else:
361+
keywords = ([topic.split(" ") for topic in topics] or []) + (
362+
[region.split(" ") for region in regions] or []
363+
)
321364
query = {
322365
"query": {
323366
"bool": {
324367
"must": [
325368
{
326369
"multi_match": {
327-
"query": " ".join(topics) + " " + " ".join(regions),
370+
"query": wo,
328371
"fields": ["title", "text"],
329372
}
330-
},
331-
{"term": {"section": "Abstract"}},
373+
}
374+
for words in keywords
375+
for wo in words
332376
]
333377
}
334378
}
335379
}
336-
else:
337-
query = {"query": {"match_all": {}}}
380+
query["query"]["bool"]["must"].append({"term": {"section": "Abstract"}})
381+
382+
# If further filters, append them
383+
if filter_query:
384+
query["query"]["bool"]["must"].append(filter_query["bool"]["must"])
338385

339386
aggs: dict[str, Any] = {
340387
"relevant_ids": {

0 commit comments

Comments
 (0)