Skip to content

Commit

Permalink
Merge pull request #38 from DavidBuchanan314/token-revocation
Browse files Browse the repository at this point in the history
refreshSession, deleteSession
  • Loading branch information
DavidBuchanan314 authored Jan 2, 2025
2 parents 9d467bd + 5863b8c commit 0b061d0
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 74 deletions.
10 changes: 8 additions & 2 deletions migration_scripts/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from millipds import static_config

with apsw.Connection(static_config.MAIN_DB_PATH) as con:

def migrate(con):
version_now, *_ = con.execute("SELECT db_version FROM config").fetchone()

assert version_now == 1
Expand Down Expand Up @@ -36,4 +37,9 @@

con.execute("UPDATE config SET db_version=2")

print("v1 -> v2 Migration successful")

if __name__ == "__main__":
with apsw.Connection(static_config.MAIN_DB_PATH) as con:
migrate(con)

print("v1 -> v2 Migration successful")
34 changes: 34 additions & 0 deletions migration_scripts/v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# TODO: some smarter way of handling migrations

import apsw
import apsw.bestpractice

apsw.bestpractice.apply(apsw.bestpractice.recommended)

from millipds import static_config


def migrate(con: apsw.Connection):
version_now, *_ = con.execute("SELECT db_version FROM config").fetchone()

assert version_now == 2

con.execute(
"""
CREATE TABLE revoked_token(
did TEXT NOT NULL,
jti TEXT NOT NULL,
expires_at INTEGER NOT NULL,
PRIMARY KEY (did, jti)
) STRICT, WITHOUT ROWID
"""
)

con.execute("UPDATE config SET db_version=3")


if __name__ == "__main__":
with apsw.Connection(static_config.MAIN_DB_PATH) as con:
migrate(con)

print("v2 -> v3 Migration successful")
12 changes: 7 additions & 5 deletions src/millipds/appview_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@ async def service_proxy(request: web.Request, service: Optional[str] = None):
)
if did_doc is None:
return web.HTTPInternalServerError(
f"unable to resolve service {service!r}"
text=f"unable to resolve service {service!r}"
)
for service in did_doc.get("service", []):
if service.get("id") == fragment:
service_route = service["serviceEndpoint"]
for service_info in did_doc.get("service", []):
if service_info.get("id") == fragment:
service_route = service_info["serviceEndpoint"]
break
else:
return web.HTTPBadRequest(f"unable to resolve service {service!r}")
return web.HTTPBadRequest(
text=f"unable to resolve service {service!r}"
)
else: # fall thru to assuming bsky appview
service_did = db.config["bsky_appview_did"]
service_route = db.config["bsky_appview_pfx"]
Expand Down
75 changes: 50 additions & 25 deletions src/millipds/auth_bearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,44 @@
routes = web.RouteTableDef()


def verify_symmetric_token(
request: web.Request, token: str, expected_scope: str
) -> dict:
db = get_db(request)
try:
payload: dict = jwt.decode(
jwt=token,
key=db.config["jwt_access_secret"],
algorithms=["HS256"],
audience=db.config["pds_did"],
options={
"require": ["exp", "iat", "scope", "jti", "sub"],
"verify_exp": True,
"verify_iat": True,
"strict_aud": True, # may be unnecessary
},
)
except jwt.exceptions.PyJWTError:
raise web.HTTPUnauthorized(text="invalid jwt")

revoked = db.con.execute(
"SELECT COUNT(*) FROM revoked_token WHERE did=? AND jti=?",
(payload["sub"], payload["jti"]),
).fetchone()[0]

if revoked:
raise web.HTTPUnauthorized(text="revoked token")

# if we reached this far, the payload must've been signed by us
if payload.get("scope") != expected_scope:
raise web.HTTPUnauthorized(text="invalid jwt scope")

if not payload.get("sub", "").startswith("did:"):
raise web.HTTPUnauthorized(text="invalid jwt: invalid subject")

return payload


def authenticated(handler):
"""
There are three types of auth:
Expand Down Expand Up @@ -39,30 +77,9 @@ async def authentication_handler(request: web.Request, *args, **kwargs):
)
# logger.info(unverified)
if unverified["header"]["alg"] == "HS256": # symmetric secret
try:
payload: dict = jwt.decode(
jwt=token,
key=db.config["jwt_access_secret"],
algorithms=["HS256"],
audience=db.config["pds_did"],
options={
"require": ["exp", "iat", "scope"], # consider iat?
"verify_exp": True,
"verify_iat": True,
"strict_aud": True, # may be unnecessary
},
)
except jwt.exceptions.PyJWTError:
raise web.HTTPUnauthorized(text="invalid jwt")

# if we reached this far, the payload must've been signed by us
if payload.get("scope") != "com.atproto.access":
raise web.HTTPUnauthorized(text="invalid jwt scope")

subject: str = payload.get("sub", "")
if not subject.startswith("did:"):
raise web.HTTPUnauthorized(text="invalid jwt: invalid subject")
request["authed_did"] = subject
request["authed_did"] = verify_symmetric_token(
request, token, "com.atproto.access"
)["sub"]
else: # asymmetric service auth (scoped to a specific lxm)
did: str = unverified["payload"]["iss"]
if not did.startswith("did:"):
Expand All @@ -81,7 +98,7 @@ async def authentication_handler(request: web.Request, *args, **kwargs):
algorithms=[alg],
audience=db.config["pds_did"],
options={
"require": ["exp", "iat", "lxm"],
"require": ["exp", "iat", "lxm", "jti", "iss"],
"verify_exp": True,
"verify_iat": True,
"strict_aud": True, # may be unnecessary
Expand All @@ -90,6 +107,14 @@ async def authentication_handler(request: web.Request, *args, **kwargs):
except jwt.exceptions.PyJWTError:
raise web.HTTPUnauthorized(text="invalid jwt")

revoked = db.con.execute(
"SELECT COUNT(*) FROM revoked_token WHERE did=? AND jti=?",
(payload["iss"], payload["jti"]),
).fetchone()[0]

if revoked:
raise web.HTTPUnauthorized(text="revoked token")

request_lxm = request.path.rpartition("/")[2].partition("?")[0]
if request_lxm != payload.get("lxm"):
raise web.HTTPUnauthorized(text="invalid jwt: bad lxm")
Expand Down
13 changes: 13 additions & 0 deletions src/millipds/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,19 @@ def _init_tables(self):
"""
)

# this is only for the tokens *we* issue, dpop jti will be tracked separately
# there's no point remembering that an expired token was revoked, and we'll garbage-collect these periodically
self.con.execute(
"""
CREATE TABLE revoked_token(
did TEXT NOT NULL,
jti TEXT NOT NULL,
expires_at INTEGER NOT NULL,
PRIMARY KEY (did, jti)
) STRICT, WITHOUT ROWID
"""
)

def update_config(
self,
pds_pfx: Optional[str] = None,
Expand Down
130 changes: 89 additions & 41 deletions src/millipds/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from . import crypto
from . import util
from .appview_proxy import service_proxy
from .auth_bearer import authenticated
from .auth_bearer import authenticated, verify_symmetric_token
from .app_util import *
from .did import DIDResolver

Expand Down Expand Up @@ -203,6 +203,53 @@ async def server_describe_server(request: web.Request):
)


def session_info(request: web.Request) -> dict:
return {
"handle": get_db(request).handle_by_did(request["authed_did"]),
"did": request["authed_did"],
"email": "tfw_no@email.invalid", # this and below are just here for testing lol
"emailConfirmed": True,
# "didDoc": {}, # iiuc this is only used for entryway usecase?
}


def generate_session_tokens(request: web.Request) -> dict:
db = get_db(request)
unix_seconds_now = int(time.time())
# use the same jti for both tokens, so revoking one revokes both
jti = str(uuid.uuid4())
access_jwt = jwt.encode(
{
"scope": "com.atproto.access",
"aud": db.config["pds_did"],
"sub": request["authed_did"],
"iat": unix_seconds_now,
"exp": unix_seconds_now + static_config.ACCESS_EXP,
"jti": jti,
},
db.config["jwt_access_secret"],
"HS256",
)

refresh_jwt = jwt.encode(
{
"scope": "com.atproto.refresh",
"aud": db.config["pds_did"],
"sub": request["authed_did"],
"iat": unix_seconds_now,
"exp": unix_seconds_now + static_config.REFRESH_EXP,
"jti": jti,
},
db.config["jwt_access_secret"],
"HS256",
)

return {
"accessJwt": access_jwt,
"refreshJwt": refresh_jwt,
}


# TODO: ratelimit this!!!
@routes.post("/xrpc/com.atproto.server.createSession")
async def server_create_session(request: web.Request):
Expand All @@ -228,44 +275,53 @@ async def server_create_session(request: web.Request):
except ValueError:
raise web.HTTPUnauthorized(text="incorrect identifier or password")

# prepare access tokens
unix_seconds_now = int(time.time())
access_jwt = jwt.encode(
{
"scope": "com.atproto.access",
"aud": db.config["pds_did"],
"sub": did,
"iat": unix_seconds_now,
"exp": unix_seconds_now + 60 * 60 * 24, # 24h
"jti": str(uuid.uuid4()),
},
db.config["jwt_access_secret"],
"HS256",
# both generate_session_tokens and session_info need this
request["authed_did"] = did

return web.json_response(
session_info(request) | generate_session_tokens(request)
)

refresh_jwt = jwt.encode(
{
"scope": "com.atproto.refresh",
"aud": db.config["pds_did"],
"sub": did,
"iat": unix_seconds_now,
"exp": unix_seconds_now + 60 * 60 * 24 * 90, # 90 days!
"jti": str(uuid.uuid4()),
},
db.config["jwt_access_secret"],
"HS256",

@routes.post("/xrpc/com.atproto.server.refreshSession")
async def server_refresh_session(request: web.Request):
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "):
raise web.HTTPUnauthorized(text="invalid auth type")
token = auth.removeprefix("Bearer ")
token_payload = verify_symmetric_token(
request, token, "com.atproto.refresh"
)
request["authed_did"] = token_payload["sub"]

get_db(request).con.execute(
"INSERT INTO revoked_token (did, jti, expires_at) VALUES (?, ?, ?)",
(token_payload["sub"], token_payload["jti"], token_payload["exp"]),
)
return web.json_response(
{
"did": did,
"handle": handle,
"accessJwt": access_jwt,
"refreshJwt": refresh_jwt,
}
session_info(request) | generate_session_tokens(request)
)


# NOTE: deleteSession requires refresh token as auth, not access token
@routes.post("/xrpc/com.atproto.server.deleteSession")
async def server_delete_session(request: web.Request):
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "):
raise web.HTTPUnauthorized(text="invalid auth type")
token = auth.removeprefix("Bearer ")
token_payload = verify_symmetric_token(
request, token, "com.atproto.refresh"
)

get_db(request).con.execute(
"INSERT INTO revoked_token (did, jti, expires_at) VALUES (?, ?, ?)",
(token_payload["sub"], token_payload["jti"], token_payload["exp"]),
)

return web.Response()


@routes.get("/xrpc/com.atproto.server.getServiceAuth")
@authenticated
async def server_get_service_auth(request: web.Request):
Expand Down Expand Up @@ -302,7 +358,7 @@ async def server_get_service_auth(request: web.Request):
"lxm": lxm,
"exp": exp,
"iat": now,
"jti": str(uuid.uuid4())
"jti": str(uuid.uuid4()),
},
signing_key,
algorithm=crypto.jwt_signature_alg_for_pem(signing_key),
Expand Down Expand Up @@ -381,15 +437,7 @@ async def identity_update_handle(request: web.Request):
@routes.get("/xrpc/com.atproto.server.getSession")
@authenticated
async def server_get_session(request: web.Request):
return web.json_response(
{
"handle": get_db(request).handle_by_did(request["authed_did"]),
"did": request["authed_did"],
"email": "tfw_no@email.invalid", # this and below are just here for testing lol
"emailConfirmed": True,
# "didDoc": {}, # iiuc this is only used for entryway usecase?
}
)
return web.json_response(session_info(request))


def construct_app(
Expand Down
5 changes: 4 additions & 1 deletion src/millipds/static_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
GROUPNAME = "millipds-sock"

# this gets bumped if we make breaking changes to the db schema
MILLIPDS_DB_VERSION = 2
MILLIPDS_DB_VERSION = 3

ATPROTO_REPO_VERSION_3 = 3 # might get bumped if the atproto spec changes
CAR_VERSION_1 = 1
Expand All @@ -29,3 +29,6 @@
DID_CACHE_ERROR_TTL = 60 * 5 # 5 mins

PLC_DIRECTORY_HOST = "https://plc.directory"

ACCESS_EXP = 60 * 60 * 2 # 2 h
REFRESH_EXP = 60 * 60 * 24 * 90 # 90 days
Loading

0 comments on commit 0b061d0

Please sign in to comment.