Skip to content

Commit

Permalink
cookie auth (#441)
Browse files Browse the repository at this point in the history
* feat: login-with-orcid link at top

for #423`

* new GH action to lint and reformat

* commit and push reformatting

closes #438

* fix

* quicken lint GH action

* test: for #439

* fix: author-ize

* fix: quote

* fix: autosetup remote

* fix: ensure HEAD ref for git push

* try .sha

* fix: arg sent to wrong step

* style: reformat

* inprogress: do not merge

* inprogress: do not merge

* inprogress: do not merge: add todo

* [do not merge] login w/o logout

* remove old orcid endpoints

* remove auth-action tags

* remove commented out orcid_cookie_test

* clean: abandon auth-action hack for now

---------

Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Jing Cao <jingcao.me@gmail.com>
  • Loading branch information
3 people authored Jan 22, 2024
1 parent 25a90c7 commit d989ebd
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 47 deletions.
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ NEON_API_TOKEN=y
NEON_API_BASE_URL=https://data.neonscience.org/api/v0

NERSC_USERNAME=replaceme
ORCID_CLIENT_ID=replaceme
ORCID_CLIENT_ID=replaceme
ORCID_CLIENT_SECRET=replaceme
31 changes: 26 additions & 5 deletions nmdc_runtime/api/core/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,25 @@
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form
from fastapi.security import OAuth2, HTTPBasic, HTTPBasicCredentials
from fastapi.security import (
OAuth2,
HTTPBasic,
HTTPBasicCredentials,
HTTPBearer,
HTTPAuthorizationCredentials,
)
from fastapi.security.utils import get_authorization_scheme_param
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from starlette import status
from starlette.requests import Request
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED

SECRET_KEY = os.getenv("JWT_SECRET_KEY")
ALGORITHM = "HS256"
ORCID_CLIENT_ID = os.getenv("ORCID_CLIENT_ID")
ORCID_CLIENT_SECRET = os.getenv("ORCID_CLIENT_SECRET")

# https://orcid.org/.well-known/openid-configuration
# XXX do we want to live-load this?
Expand Down Expand Up @@ -129,30 +137,43 @@ async def __call__(self, request: Request) -> Optional[str]:
tokenUrl="token", auto_error=False
)

bearer_scheme = HTTPBearer(scheme_name="bearerAuth", auto_error=False)


async def basic_credentials(req: Request):
return await HTTPBasic(auto_error=False)(req)


async def bearer_credentials(req: Request):
return await HTTPBearer(scheme_name="bearerAuth", auto_error=False)(req)


class OAuth2PasswordOrClientCredentialsRequestForm:
def __init__(
self,
basic_creds: Optional[HTTPBasicCredentials] = Depends(basic_credentials),
bearer_creds: Optional[HTTPAuthorizationCredentials] = Depends(
bearer_credentials
),
grant_type: str = Form(None, regex="^password$|^client_credentials$"),
username: Optional[str] = Form(None),
password: Optional[str] = Form(None),
scope: str = Form(""),
client_id: Optional[str] = Form(None),
client_secret: Optional[str] = Form(None),
):
if grant_type == "password" and (username is None or password is None):
if bearer_creds:
self.grant_type = "client_credentials"
self.username, self.password = None, None
self.scopes = scope.split()
self.client_id = bearer_creds.credentials
self.client_secret = None
elif grant_type == "password" and (username is None or password is None):
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="grant_type password requires username and password",
)
if grant_type == "client_credentials" and (
client_id is None or client_secret is None
):
elif grant_type == "client_credentials" and (client_id is None):
if basic_creds:
client_id = basic_creds.username
client_secret = basic_creds.password
Expand Down
60 changes: 24 additions & 36 deletions nmdc_runtime/api/endpoints/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from datetime import timedelta

import pymongo.database
import requests
from fastapi import Depends, APIRouter, HTTPException, status
from fastapi.openapi.docs import get_swagger_ui_html
from jose import jws, JWTError
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse
Expand All @@ -16,6 +18,7 @@
ORCID_JWK,
ORCID_JWS_VERITY_ALGORITHM,
credentials_exception,
ORCID_CLIENT_SECRET,
)
from nmdc_runtime.api.core.auth import get_password_hash
from nmdc_runtime.api.core.util import generate_secret
Expand All @@ -32,43 +35,28 @@
router = APIRouter()


@router.get("/orcid_authorize")
async def orcid_authorize():
"""NOTE: You want to load /orcid_authorize directly in your web browser to initiate the login redirect flow."""
return RedirectResponse(
f"https://orcid.org/oauth/authorize?client_id={ORCID_CLIENT_ID}"
"&response_type=token&scope=openid&"
f"redirect_uri={BASE_URL_EXTERNAL}/orcid_token"
)


@router.get("/orcid_token")
async def redirect_uri_for_orcid_token(req: Request):
"""
Returns a web page that will display a user's orcid jwt token for copy/paste.
This route is loaded by orcid.org after a successful orcid user login.
"""
return HTMLResponse(
"""
<head>
<script>
function getFragmentParameterByName(name) {
name = name.replace(/[\[]/, "\\[").replace(/[\]]/, "\\]");
var regex = new RegExp("[\\#&]" + name + "=([^&#]*)"),
results = regex.exec(window.location.hash);
return results === null ? "" : decodeURIComponent(results[1].replace(/\+/g, " "));
}
</script>
</head>
<body>
<main id="token"></main>
</body>
<script>
document.getElementById("token").innerHTML = getFragmentParameterByName("id_token")
</script>
"""
@router.get("/orcid_code", response_class=RedirectResponse)
async def receive_orcid_code(request: Request, code: str, state: str | None = None):
rv = requests.post(
"https://orcid.org/oauth/token",
data=(
f"client_id={ORCID_CLIENT_ID}&client_secret={ORCID_CLIENT_SECRET}&"
f"grant_type=authorization_code&code={code}&redirect_uri={BASE_URL_EXTERNAL}/orcid_code"
),
headers={
"Content-type": "application/x-www-form-urlencoded",
"Accept": "application/json",
},
)
token_response = rv.json()
response = RedirectResponse(state or request.url_for("custom_swagger_ui_html"))
for key in ["user_orcid", "user_name", "user_id_token"]:
response.set_cookie(
key=key,
value=token_response[key.replace("user_", "")],
max_age=2592000,
)
return response


@router.post("/token", response_model=Token)
Expand Down
77 changes: 72 additions & 5 deletions nmdc_runtime/api/main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import os
import re
from contextlib import asynccontextmanager
from importlib import import_module
from importlib.metadata import version
from typing import Annotated

import fastapi
import requests
import uvicorn
from fastapi import APIRouter, FastAPI
from fastapi import APIRouter, FastAPI, Cookie
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.staticfiles import StaticFiles
from setuptools_scm import get_version
from starlette import status
from starlette.responses import RedirectResponse
from starlette.responses import RedirectResponse, HTMLResponse

from nmdc_runtime.api.analytics import Analytics
from nmdc_runtime.util import all_docs_have_unique_id, ensure_unique_id_indexes
from nmdc_runtime.api.core.auth import get_password_hash
from nmdc_runtime.util import (
ensure_unique_id_indexes,
REPO_ROOT_DIR,
)
from nmdc_runtime.api.core.auth import get_password_hash, ORCID_CLIENT_ID
from nmdc_runtime.api.db.mongo import (
get_mongo_db,
)
Expand Down Expand Up @@ -356,10 +364,15 @@ async def get_versions():
"\n\n"
"Dependency versions:\n\n"
f'nmdc-schema={version("nmdc_schema")}\n\n'
"<a href='https://microbiomedata.github.io/nmdc-runtime/'>Documentation</a>"
"<a href='https://microbiomedata.github.io/nmdc-runtime/'>Documentation</a>\n\n"
'<img src="/static/ORCIDiD_icon128x128.png" height="18" width="18"/> '
f'<a href="https://orcid.org/oauth/authorize?client_id={ORCID_CLIENT_ID}'
"&response_type=code&scope=openid&"
f'redirect_uri={BASE_URL_EXTERNAL}/orcid_code">Login with ORCiD</a>'
),
openapi_tags=tags_metadata,
lifespan=lifespan,
docs_url=None,
)
app.include_router(api_router)

Expand All @@ -372,6 +385,60 @@ async def get_versions():
allow_headers=["*"],
)
app.add_middleware(Analytics)
app.mount(
"/static",
StaticFiles(directory=REPO_ROOT_DIR.joinpath("nmdc_runtime/static/")),
name="static",
)


@app.get("/docs", include_in_schema=False)
def custom_swagger_ui_html(
user_id_token: Annotated[str | None, Cookie()] = None,
):
access_token = None
if user_id_token:
# get bearer token
rv = requests.post(
url=f"{BASE_URL_EXTERNAL}/token",
data={
"client_id": user_id_token,
"client_secret": "",
"grant_type": "client_credentials",
},
headers={
"Content-type": "application/x-www-form-urlencoded",
"Accept": "application/json",
},
)
if rv.status_code != 200:
rv.reason = rv.text
rv.raise_for_status()
access_token = rv.json()["access_token"]

swagger_ui_parameters = {"withCredentials": True}
if access_token is not None:
swagger_ui_parameters.update(
{
"onComplete": f"""<unquote-safe>() => {{ ui.preauthorizeApiKey(<double-quote>bearerAuth</double-quote>, <double-quote>{access_token}</double-quote>) }}</unquote-safe>""",
}
)
response = get_swagger_ui_html(
openapi_url=app.openapi_url,
title=app.title,
oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
swagger_js_url="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui-bundle.js",
swagger_css_url="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui.css",
swagger_ui_parameters=swagger_ui_parameters,
)
content = (
response.body.decode()
.replace('"<unquote-safe>', "")
.replace('</unquote-safe>"', "")
.replace("<double-quote>", '"')
.replace("</double-quote>", '"')
)
return HTMLResponse(content=content)


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions nmdc_runtime/api/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
oauth2_scheme,
credentials_exception,
TokenData,
bearer_scheme,
)
from nmdc_runtime.api.db.mongo import get_mongo_db

Expand Down Expand Up @@ -49,6 +50,7 @@ def authenticate_user(mdb, username: str, password: str):

async def get_current_user(
token: str = Depends(oauth2_scheme),
bearer_credentials: str = Depends(bearer_scheme),
mdb: pymongo.database.Database = Depends(get_mongo_db),
) -> UserInDB:
if mdb.invalidated_tokens.find_one({"_id": token}):
Expand Down
Binary file added nmdc_runtime/static/ORCIDiD_icon128x128.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit d989ebd

Please sign in to comment.