28
28
from scholarag .document_stores import AsyncBaseSearch
29
29
from scholarag .retrieve_metadata import MetaDataRetriever
30
30
from scholarag .services import CohereRerankingService , RetrievalService
31
+ from scholarag .utils import build_search_query
31
32
32
33
router = APIRouter (
33
34
prefix = "/retrieval" , tags = ["Retrieval" ], dependencies = [Depends (get_user_id )]
@@ -194,6 +195,7 @@ async def article_count(
194
195
)
195
196
),
196
197
] = None ,
198
+ resolve_hierarchy : bool = False ,
197
199
) -> ArticleCountResponse :
198
200
"""Article count based on keyword matching.
199
201
\f
@@ -218,59 +220,12 @@ async def article_count(
218
220
start = time .time ()
219
221
logger .info ("Finding unique articles matching the query ..." )
220
222
221
- if not topics and not regions :
222
- raise HTTPException (
223
- status_code = 422 , detail = "Please provide at least one region or topic."
224
- )
225
-
226
- # Match the keywords on abstract + title.
227
- topic_query = (
228
- [
229
- {
230
- "multi_match" : {
231
- "query" : topic ,
232
- "type" : "phrase" ,
233
- "fields" : ["title" , "text" ],
234
- }
235
- }
236
- for topic in topics
237
- ]
238
- if topics is not None
239
- else []
240
- )
241
- regions_query = (
242
- [
243
- {
244
- "bool" : {
245
- "should" : [
246
- {
247
- "multi_match" : {
248
- "query" : region ,
249
- "type" : "phrase" ,
250
- "fields" : ["title" , "text" ],
251
- }
252
- }
253
- for region in regions
254
- ]
255
- }
256
- }
257
- ]
258
- if regions is not None
259
- else []
223
+ query = build_search_query (
224
+ topics = topics ,
225
+ regions = regions ,
226
+ filter_query = filter_query ,
227
+ resolve_hierarchy = resolve_hierarchy ,
260
228
)
261
- filter_query_list = filter_query ["bool" ]["must" ] if filter_query else []
262
-
263
- query : dict [str , Any ] = {
264
- "query" : {
265
- "bool" : {
266
- "must" : [
267
- * topic_query ,
268
- * regions_query ,
269
- * filter_query_list ,
270
- ]
271
- }
272
- }
273
- }
274
229
275
230
# Aggregation query.
276
231
aggs = {
@@ -348,6 +303,7 @@ async def article_listing(
348
303
)
349
304
),
350
305
] = False ,
306
+ resolve_hierarchy : bool = False ,
351
307
) -> Page [ArticleMetadata ]:
352
308
"""Article id listing based on keyword matching.
353
309
\f
@@ -380,59 +336,12 @@ async def article_listing(
380
336
start = time .time ()
381
337
logger .info ("Finding unique articles matching the query ..." )
382
338
383
- if not topics and not regions :
384
- raise HTTPException (
385
- status_code = 422 , detail = "Please provide at least one region or topic."
386
- )
387
-
388
- # Match the keywords on abstract + title.
389
- topic_query = (
390
- [
391
- {
392
- "multi_match" : {
393
- "query" : topic ,
394
- "type" : "phrase" ,
395
- "fields" : ["title" , "text" ],
396
- }
397
- }
398
- for topic in topics
399
- ]
400
- if topics is not None
401
- else []
402
- )
403
- regions_query = (
404
- [
405
- {
406
- "bool" : {
407
- "should" : [
408
- {
409
- "multi_match" : {
410
- "query" : region ,
411
- "type" : "phrase" ,
412
- "fields" : ["title" , "text" ],
413
- }
414
- }
415
- for region in regions
416
- ]
417
- }
418
- }
419
- ]
420
- if regions is not None
421
- else []
339
+ query = build_search_query (
340
+ topics = topics ,
341
+ regions = regions ,
342
+ filter_query = filter_query ,
343
+ resolve_hierarchy = resolve_hierarchy ,
422
344
)
423
- filter_query_list = filter_query ["bool" ]["must" ] if filter_query else []
424
-
425
- query : dict [str , Any ] = {
426
- "query" : {
427
- "bool" : {
428
- "must" : [
429
- * topic_query ,
430
- * regions_query ,
431
- * filter_query_list ,
432
- ]
433
- }
434
- }
435
- }
436
345
437
346
aggs : dict [str , Any ] = {
438
347
"relevant_ids" : {
0 commit comments