diff --git a/.env.example b/.env.example index b9505de9..03baec27 100644 --- a/.env.example +++ b/.env.example @@ -37,4 +37,5 @@ NEON_API_TOKEN=y NEON_API_BASE_URL=https://data.neonscience.org/api/v0 NERSC_USERNAME=replaceme -ORCID_CLIENT_ID=replaceme \ No newline at end of file +ORCID_CLIENT_ID=replaceme +ORCID_CLIENT_SECRET=replaceme \ No newline at end of file diff --git a/nmdc_runtime/api/core/auth.py b/nmdc_runtime/api/core/auth.py index 5e4d7c1c..85c5d5a5 100644 --- a/nmdc_runtime/api/core/auth.py +++ b/nmdc_runtime/api/core/auth.py @@ -6,17 +6,25 @@ from fastapi.exceptions import HTTPException from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel from fastapi.param_functions import Form -from fastapi.security import OAuth2, HTTPBasic, HTTPBasicCredentials +from fastapi.security import ( + OAuth2, + HTTPBasic, + HTTPBasicCredentials, + HTTPBearer, + HTTPAuthorizationCredentials, +) from fastapi.security.utils import get_authorization_scheme_param from jose import JWTError, jwt from passlib.context import CryptContext from pydantic import BaseModel +from starlette import status from starlette.requests import Request from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED SECRET_KEY = os.getenv("JWT_SECRET_KEY") ALGORITHM = "HS256" ORCID_CLIENT_ID = os.getenv("ORCID_CLIENT_ID") +ORCID_CLIENT_SECRET = os.getenv("ORCID_CLIENT_SECRET") # https://orcid.org/.well-known/openid-configuration # XXX do we want to live-load this? @@ -129,15 +137,24 @@ async def __call__(self, request: Request) -> Optional[str]: tokenUrl="token", auto_error=False ) +bearer_scheme = HTTPBearer(scheme_name="bearerAuth", auto_error=False) + async def basic_credentials(req: Request): return await HTTPBasic(auto_error=False)(req) +async def bearer_credentials(req: Request): + return await HTTPBearer(scheme_name="bearerAuth", auto_error=False)(req) + + class OAuth2PasswordOrClientCredentialsRequestForm: def __init__( self, basic_creds: Optional[HTTPBasicCredentials] = Depends(basic_credentials), + bearer_creds: Optional[HTTPAuthorizationCredentials] = Depends( + bearer_credentials + ), grant_type: str = Form(None, regex="^password$|^client_credentials$"), username: Optional[str] = Form(None), password: Optional[str] = Form(None), @@ -145,14 +162,18 @@ def __init__( client_id: Optional[str] = Form(None), client_secret: Optional[str] = Form(None), ): - if grant_type == "password" and (username is None or password is None): + if bearer_creds: + self.grant_type = "client_credentials" + self.username, self.password = None, None + self.scopes = scope.split() + self.client_id = bearer_creds.credentials + self.client_secret = None + elif grant_type == "password" and (username is None or password is None): raise HTTPException( status_code=HTTP_400_BAD_REQUEST, detail="grant_type password requires username and password", ) - if grant_type == "client_credentials" and ( - client_id is None or client_secret is None - ): + elif grant_type == "client_credentials" and (client_id is None): if basic_creds: client_id = basic_creds.username client_secret = basic_creds.password diff --git a/nmdc_runtime/api/endpoints/users.py b/nmdc_runtime/api/endpoints/users.py index 4f79e752..c174092c 100644 --- a/nmdc_runtime/api/endpoints/users.py +++ b/nmdc_runtime/api/endpoints/users.py @@ -2,7 +2,9 @@ from datetime import timedelta import pymongo.database +import requests from fastapi import Depends, APIRouter, HTTPException, status +from fastapi.openapi.docs import get_swagger_ui_html from jose import jws, JWTError from starlette.requests import Request from starlette.responses import HTMLResponse, RedirectResponse @@ -16,6 +18,7 @@ ORCID_JWK, ORCID_JWS_VERITY_ALGORITHM, credentials_exception, + ORCID_CLIENT_SECRET, ) from nmdc_runtime.api.core.auth import get_password_hash from nmdc_runtime.api.core.util import generate_secret @@ -32,43 +35,28 @@ router = APIRouter() -@router.get("/orcid_authorize") -async def orcid_authorize(): - """NOTE: You want to load /orcid_authorize directly in your web browser to initiate the login redirect flow.""" - return RedirectResponse( - f"https://orcid.org/oauth/authorize?client_id={ORCID_CLIENT_ID}" - "&response_type=token&scope=openid&" - f"redirect_uri={BASE_URL_EXTERNAL}/orcid_token" - ) - - -@router.get("/orcid_token") -async def redirect_uri_for_orcid_token(req: Request): - """ - Returns a web page that will display a user's orcid jwt token for copy/paste. - - This route is loaded by orcid.org after a successful orcid user login. - """ - return HTMLResponse( - """ -
- - - - - - - """ +@router.get("/orcid_code", response_class=RedirectResponse) +async def receive_orcid_code(request: Request, code: str, state: str | None = None): + rv = requests.post( + "https://orcid.org/oauth/token", + data=( + f"client_id={ORCID_CLIENT_ID}&client_secret={ORCID_CLIENT_SECRET}&" + f"grant_type=authorization_code&code={code}&redirect_uri={BASE_URL_EXTERNAL}/orcid_code" + ), + headers={ + "Content-type": "application/x-www-form-urlencoded", + "Accept": "application/json", + }, ) + token_response = rv.json() + response = RedirectResponse(state or request.url_for("custom_swagger_ui_html")) + for key in ["user_orcid", "user_name", "user_id_token"]: + response.set_cookie( + key=key, + value=token_response[key.replace("user_", "")], + max_age=2592000, + ) + return response @router.post("/token", response_model=Token) diff --git a/nmdc_runtime/api/main.py b/nmdc_runtime/api/main.py index e8fe4015..e2107595 100644 --- a/nmdc_runtime/api/main.py +++ b/nmdc_runtime/api/main.py @@ -1,19 +1,27 @@ import os +import re from contextlib import asynccontextmanager from importlib import import_module from importlib.metadata import version +from typing import Annotated import fastapi +import requests import uvicorn -from fastapi import APIRouter, FastAPI +from fastapi import APIRouter, FastAPI, Cookie from fastapi.middleware.cors import CORSMiddleware +from fastapi.openapi.docs import get_swagger_ui_html +from fastapi.staticfiles import StaticFiles from setuptools_scm import get_version from starlette import status -from starlette.responses import RedirectResponse +from starlette.responses import RedirectResponse, HTMLResponse from nmdc_runtime.api.analytics import Analytics -from nmdc_runtime.util import all_docs_have_unique_id, ensure_unique_id_indexes -from nmdc_runtime.api.core.auth import get_password_hash +from nmdc_runtime.util import ( + ensure_unique_id_indexes, + REPO_ROOT_DIR, +) +from nmdc_runtime.api.core.auth import get_password_hash, ORCID_CLIENT_ID from nmdc_runtime.api.db.mongo import ( get_mongo_db, ) @@ -356,10 +364,15 @@ async def get_versions(): "\n\n" "Dependency versions:\n\n" f'nmdc-schema={version("nmdc_schema")}\n\n' - "Documentation" + "Documentation\n\n" + '