Skip to content

Commit 2ebe8e6

Browse files
committed
first working prototype
1 parent adb0ec2 commit 2ebe8e6

File tree

6 files changed

+425
-105
lines changed

6 files changed

+425
-105
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- resolve hierarchy for article listing / count.
12+
1013
## [0.0.10] - 26.02.2025
1114

1215
### Changed

brainregion_hierarchy.json

+1
Large diffs are not rendered by default.

src/scholarag/app/dependencies.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ async def get_ds_client(
118118
port=settings.db.port,
119119
user=settings.db.user,
120120
password=password,
121-
use_ssl_and_verify_certs=True,
121+
use_ssl_and_verify_certs=False,
122122
)
123123
yield ds_client
124124
finally:

src/scholarag/app/routers/retrieval.py

+13-104
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from scholarag.document_stores import AsyncBaseSearch
2929
from scholarag.retrieve_metadata import MetaDataRetriever
3030
from scholarag.services import CohereRerankingService, RetrievalService
31+
from scholarag.utils import build_search_query
3132

3233
router = APIRouter(
3334
prefix="/retrieval", tags=["Retrieval"], dependencies=[Depends(get_user_id)]
@@ -194,6 +195,7 @@ async def article_count(
194195
)
195196
),
196197
] = None,
198+
resolve_hierarchy: bool = False,
197199
) -> ArticleCountResponse:
198200
"""Article count based on keyword matching.
199201
\f
@@ -218,59 +220,12 @@ async def article_count(
218220
start = time.time()
219221
logger.info("Finding unique articles matching the query ...")
220222

221-
if not topics and not regions:
222-
raise HTTPException(
223-
status_code=422, detail="Please provide at least one region or topic."
224-
)
225-
226-
# Match the keywords on abstract + title.
227-
topic_query = (
228-
[
229-
{
230-
"multi_match": {
231-
"query": topic,
232-
"type": "phrase",
233-
"fields": ["title", "text"],
234-
}
235-
}
236-
for topic in topics
237-
]
238-
if topics is not None
239-
else []
240-
)
241-
regions_query = (
242-
[
243-
{
244-
"bool": {
245-
"should": [
246-
{
247-
"multi_match": {
248-
"query": region,
249-
"type": "phrase",
250-
"fields": ["title", "text"],
251-
}
252-
}
253-
for region in regions
254-
]
255-
}
256-
}
257-
]
258-
if regions is not None
259-
else []
223+
query = build_search_query(
224+
topics=topics,
225+
regions=regions,
226+
filter_query=filter_query,
227+
resolve_hierarchy=resolve_hierarchy,
260228
)
261-
filter_query_list = filter_query["bool"]["must"] if filter_query else []
262-
263-
query: dict[str, Any] = {
264-
"query": {
265-
"bool": {
266-
"must": [
267-
*topic_query,
268-
*regions_query,
269-
*filter_query_list,
270-
]
271-
}
272-
}
273-
}
274229

275230
# Aggregation query.
276231
aggs = {
@@ -348,6 +303,7 @@ async def article_listing(
348303
)
349304
),
350305
] = False,
306+
resolve_hierarchy: bool = False,
351307
) -> Page[ArticleMetadata]:
352308
"""Article id listing based on keyword matching.
353309
\f
@@ -380,59 +336,12 @@ async def article_listing(
380336
start = time.time()
381337
logger.info("Finding unique articles matching the query ...")
382338

383-
if not topics and not regions:
384-
raise HTTPException(
385-
status_code=422, detail="Please provide at least one region or topic."
386-
)
387-
388-
# Match the keywords on abstract + title.
389-
topic_query = (
390-
[
391-
{
392-
"multi_match": {
393-
"query": topic,
394-
"type": "phrase",
395-
"fields": ["title", "text"],
396-
}
397-
}
398-
for topic in topics
399-
]
400-
if topics is not None
401-
else []
402-
)
403-
regions_query = (
404-
[
405-
{
406-
"bool": {
407-
"should": [
408-
{
409-
"multi_match": {
410-
"query": region,
411-
"type": "phrase",
412-
"fields": ["title", "text"],
413-
}
414-
}
415-
for region in regions
416-
]
417-
}
418-
}
419-
]
420-
if regions is not None
421-
else []
339+
query = build_search_query(
340+
topics=topics,
341+
regions=regions,
342+
filter_query=filter_query,
343+
resolve_hierarchy=resolve_hierarchy,
422344
)
423-
filter_query_list = filter_query["bool"]["must"] if filter_query else []
424-
425-
query: dict[str, Any] = {
426-
"query": {
427-
"bool": {
428-
"must": [
429-
*topic_query,
430-
*regions_query,
431-
*filter_query_list,
432-
]
433-
}
434-
}
435-
}
436345

437346
aggs: dict[str, Any] = {
438347
"relevant_ids": {

0 commit comments

Comments
 (0)