Skip to content

Commit 3058aff

Browse files
authored
Add lock to swagger UI (#3)
* Add lock to swagger UI * Fix CI * Fix CI2 * fix lint * fix CI * fix SSL certs
1 parent f405e07 commit 3058aff

12 files changed

+105
-52
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- Bearer token insertion in FastAPI UI.
12+
1013
### Fixed
1114
- Limit query size.
1215
- Add OBI copyright.

src/scholarag/app/config.py

+19-20
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,31 @@
44
from typing import Literal
55

66
from dotenv import dotenv_values
7-
from fastapi.openapi.models import OAuthFlowPassword, OAuthFlows
8-
from pydantic import BaseModel, ConfigDict, SecretStr, model_validator
7+
from pydantic import BaseModel, ConfigDict, SecretStr
98
from pydantic_settings import BaseSettings, SettingsConfigDict
10-
from typing_extensions import Self
119

1210
from scholarag.generative_question_answering import MESSAGES
1311

1412

1513
class SettingsKeycloak(BaseModel):
1614
"""Class retrieving keycloak info for authorization."""
1715

18-
issuer: str | None = None
16+
issuer: str = "https://openbluebrain.com/auth/realms/SBO"
1917
validate_token: bool = False
18+
# Useful only for service account (dev)
19+
client_id: str | None = None
20+
username: str | None = None
21+
password: SecretStr | None = None
2022

2123
model_config = ConfigDict(frozen=True)
2224

23-
@model_validator(mode="after")
24-
def check_issuer(self) -> Self:
25-
"""Check if there is an issuer provided for authentication."""
26-
if self.validate_token and (self.issuer is None):
27-
raise ValueError("No issuer provided.")
28-
return self
29-
3025
@property
31-
def token_endpoint(self) -> str:
26+
def token_endpoint(self) -> str | None:
3227
"""Define the token endpoint."""
33-
return f"{self.issuer}/protocol/openid-connect/token"
28+
if self.validate_token:
29+
return f"{self.issuer}/protocol/openid-connect/token"
30+
else:
31+
return None
3432

3533
@property
3634
def user_info_endpoint(self) -> str | None:
@@ -41,13 +39,14 @@ def user_info_endpoint(self) -> str | None:
4139
return None
4240

4341
@property
44-
def flows(self) -> OAuthFlows:
45-
"""Define the flow to override Fastapi's one."""
46-
return OAuthFlows(
47-
password=OAuthFlowPassword(
48-
tokenUrl=self.token_endpoint,
49-
),
50-
)
42+
def server_url(self) -> str:
43+
"""Server url."""
44+
return self.issuer.split("/auth")[0] + "/auth/"
45+
46+
@property
47+
def realm(self) -> str:
48+
"""Realm."""
49+
return self.issuer.rpartition("/realms/")[-1]
5150

5251

5352
class SettingsDB(BaseModel):

src/scholarag/app/dependencies.py

+33-20
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from typing import Annotated, AsyncIterator
88

99
from elasticsearch_dsl import Q, Search
10-
from fastapi import Depends, HTTPException, Query
11-
from fastapi.security import OAuth2PasswordBearer
10+
from fastapi import Depends, HTTPException, Query, Request
11+
from fastapi.security import HTTPBearer
1212
from httpx import AsyncClient, HTTPStatusError
1313
from openai import AsyncOpenAI
1414
from pydantic import constr
@@ -27,10 +27,17 @@
2727

2828
journal_constraints = constr(pattern=r"^\d{4}-\d{3}[0-9X]$")
2929

30-
auth = OAuth2PasswordBearer(
31-
tokenUrl="/token", # Will be overriden
32-
auto_error=False,
33-
)
30+
31+
class HTTPBearerDirect(HTTPBearer):
32+
"""HTTPBearer class that returns directly the token in the call."""
33+
34+
async def __call__(self, request: Request) -> str | None: # type: ignore
35+
"""Intercept the bearer token in the headers."""
36+
auth_credentials = await super().__call__(request)
37+
return auth_credentials.credentials if auth_credentials else None
38+
39+
40+
auth = HTTPBearerDirect(auto_error=False)
3441

3542

3643
@cache
@@ -55,24 +62,30 @@ async def get_httpx_client(
5562

5663

5764
async def get_user_id(
58-
token: Annotated[str | None, Depends(auth)],
65+
request: Request,
66+
token: Annotated[str, Depends(auth)],
5967
settings: Annotated[Settings, Depends(get_settings)],
6068
httpx_client: Annotated[AsyncClient, Depends(get_httpx_client)],
6169
) -> str:
6270
"""Validate JWT token and returns user ID."""
63-
if settings.keycloak.validate_token and settings.keycloak.user_info_endpoint:
64-
try:
65-
response = await httpx_client.get(
66-
settings.keycloak.user_info_endpoint,
67-
headers={"Authorization": f"Bearer {token}"},
68-
)
69-
response.raise_for_status()
70-
user_info = response.json()
71-
return user_info["sub"]
72-
except HTTPStatusError:
73-
raise HTTPException(
74-
status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token."
75-
)
71+
if hasattr(request.state, "sub"):
72+
return request.state.sub
73+
if settings.keycloak.validate_token:
74+
if settings.keycloak.user_info_endpoint:
75+
try:
76+
response = await httpx_client.get(
77+
settings.keycloak.user_info_endpoint,
78+
headers={"Authorization": f"Bearer {token}"},
79+
)
80+
response.raise_for_status()
81+
user_info = response.json()
82+
return user_info["sub"]
83+
except HTTPStatusError:
84+
raise HTTPException(
85+
status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token."
86+
)
87+
else:
88+
raise HTTPException(status_code=404, detail="user info url not provided.")
7689
else:
7790
return "dev"
7891

src/scholarag/app/middleware.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,13 @@ async def get_and_set_cache(
250250
httpx_client = await anext(get_httpx_client(settings))
251251
# If raises HTTPException return error as json.
252252
try:
253-
await get_user_id(token=token, settings=settings, httpx_client=httpx_client)
253+
sub = await get_user_id(
254+
request=request,
255+
token=token, # type: ignore
256+
settings=settings,
257+
httpx_client=httpx_client,
258+
)
259+
request.state.sub = sub
254260
except HTTPException as e:
255261
return JSONResponse(status_code=e.status_code, content=e.detail)
256262

src/scholarag/app/routers/qa.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
get_reranker,
2323
get_rts,
2424
get_settings,
25+
get_user_id,
2526
)
2627
from scholarag.app.schemas import (
2728
GenerativeQARequest,
@@ -36,7 +37,9 @@
3637
from scholarag.retrieve_metadata import MetaDataRetriever
3738
from scholarag.services import CohereRerankingService, RetrievalService
3839

39-
router = APIRouter(prefix="/qa", tags=["Question answering"])
40+
router = APIRouter(
41+
prefix="/qa", tags=["Question answering"], dependencies=[Depends(get_user_id)]
42+
)
4043

4144
logger = logging.getLogger(__name__)
4245

src/scholarag/app/routers/retrieval.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
get_reranker,
1818
get_rts,
1919
get_settings,
20+
get_user_id,
2021
)
2122
from scholarag.app.schemas import (
2223
ArticleCountResponse,
@@ -28,7 +29,9 @@
2829
from scholarag.retrieve_metadata import MetaDataRetriever
2930
from scholarag.services import CohereRerankingService, RetrievalService
3031

31-
router = APIRouter(prefix="/retrieval", tags=["Retrieval"])
32+
router = APIRouter(
33+
prefix="/retrieval", tags=["Retrieval"], dependencies=[Depends(get_user_id)]
34+
)
3235

3336
logger = logging.getLogger(__name__)
3437

src/scholarag/app/routers/suggestions.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
from fastapi import APIRouter, Depends, HTTPException
99

1010
from scholarag.app.config import Settings
11-
from scholarag.app.dependencies import ErrorCode, get_ds_client, get_settings
11+
from scholarag.app.dependencies import (
12+
ErrorCode,
13+
get_ds_client,
14+
get_settings,
15+
get_user_id,
16+
)
1217
from scholarag.app.schemas import (
1318
ArticleTypeSuggestionResponse,
1419
AuthorSuggestionRequest,
@@ -19,7 +24,9 @@
1924
from scholarag.document_stores import AsyncBaseSearch
2025
from scholarag.utils import format_issn
2126

22-
router = APIRouter(prefix="/suggestions", tags=["Suggestions"])
27+
router = APIRouter(
28+
prefix="/suggestions", tags=["Suggestions"], dependencies=[Depends(get_user_id)]
29+
)
2330

2431
logger = logging.getLogger(__name__)
2532

src/scholarag/scripts/pmc_parse_and_upload.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ async def run(
181181
s3_iterator = s3_paginator.paginate(Bucket="pmc-oa-opendata", Prefix=prefix)
182182
logger.info("Filtering interesting articles.")
183183
filtered_iterator = s3_iterator.search(
184-
f"""Contents[?to_string(LastModified)>='\"{start_date.strftime('%Y-%m-%d %H:%M:%S%')}+00:00\"'
184+
f"""Contents[?to_string(LastModified)>='\"{start_date.strftime("%Y-%m-%d %H:%M:%S%")}+00:00\"'
185185
&& contains(Key, '.xml')]"""
186186
)
187187
finished = False

src/scholarag/scripts/pu_consumer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ async def run(
258258
for file, message in zip(files, batch):
259259
url = (
260260
parser_url
261-
+ f'/{message["MessageAttributes"]["Parser_Endpoint"]["StringValue"]}'
261+
+ f"/{message['MessageAttributes']['Parser_Endpoint']['StringValue']}"
262262
)
263263
result = await parser_service.arun(
264264
files=[file],
@@ -309,7 +309,7 @@ async def run(
309309
else:
310310
logger.info(
311311
f"[WORKER {worker_n}] Successfully deleted"
312-
f" {len(to_delete_from_q[i:min(i+10, len(to_delete_from_q))])} entries"
312+
f" {len(to_delete_from_q[i : min(i + 10, len(to_delete_from_q))])} entries"
313313
" from the queue."
314314
)
315315
except BaseException as e:

src/scholarag/scripts/pu_producer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ async def run(
142142
f" date: {start_date}."
143143
)
144144
filtered_iterator = s3_iterator.search(
145-
f"""Contents[?to_string(LastModified)>='\"{start_date.strftime('%Y-%m-%d %H:%M:%S%')}+00:00\"'
145+
f"""Contents[?to_string(LastModified)>='\"{start_date.strftime("%Y-%m-%d %H:%M:%S%")}+00:00\"'
146146
&& contains(Key, '.{file_extension if file_extension else ""}')]"""
147147
)
148148
finished = False

tests/app/test_dependencies.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest.mock import Mock
2+
13
import pytest
24
from fastapi.exceptions import HTTPException
35
from httpx import AsyncClient
@@ -66,11 +68,23 @@ async def test_get_user_id(httpx_mock, monkeypatch):
6668
json=fake_response,
6769
)
6870

71+
request = Mock()
6972
settings = Settings()
7073
client = AsyncClient()
7174
token = "eyJgreattoken"
72-
user = await get_user_id(token=token, settings=settings, httpx_client=client)
7375

76+
# Checks for the user sub in the request state.
77+
request.state.sub = "test_sub"
78+
user = await get_user_id(
79+
request=request, token=token, settings=settings, httpx_client=client
80+
)
81+
assert user == "test_sub"
82+
83+
# Check for the user sub from keycloack.
84+
delattr(request.state, "sub")
85+
user = await get_user_id(
86+
request=request, token=token, settings=settings, httpx_client=client
87+
)
7488
assert user == "12345"
7589

7690

@@ -92,8 +106,13 @@ async def test_get_user_id_error(httpx_mock, monkeypatch):
92106
client = AsyncClient()
93107
token = "eyJgreattoken"
94108

109+
request = Mock()
110+
delattr(request.state, "sub")
111+
95112
with pytest.raises(HTTPException) as err:
96-
await get_user_id(token=token, settings=settings, httpx_client=client)
113+
await get_user_id(
114+
request=request, token=token, settings=settings, httpx_client=client
115+
)
97116

98117
assert err.value.status_code == 401
99118
assert err.value.detail == "Invalid token."

tests/app/test_qa.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ async def streamed_response_no_answer(**kwargs):
355355
"don",
356356
"'t ",
357357
"know.",
358-
", " '"paragraphs": ',
358+
', "paragraphs": ',
359359
"[]}",
360360
]
361361
parsed = {}

0 commit comments

Comments
 (0)