Skip to content

Commit

Permalink
reject revoked auth tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBuchanan314 committed Jan 2, 2025
1 parent 9d467bd commit a2cbc49
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
20 changes: 18 additions & 2 deletions src/millipds/auth_bearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def authentication_handler(request: web.Request, *args, **kwargs):
algorithms=["HS256"],
audience=db.config["pds_did"],
options={
"require": ["exp", "iat", "scope"], # consider iat?
"require": ["exp", "iat", "scope", "jti", "sub"],
"verify_exp": True,
"verify_iat": True,
"strict_aud": True, # may be unnecessary
Expand All @@ -55,6 +55,14 @@ async def authentication_handler(request: web.Request, *args, **kwargs):
except jwt.exceptions.PyJWTError:
raise web.HTTPUnauthorized(text="invalid jwt")

revoked = db.con.execute(
"SELECT COUNT(*) FROM revoked_token WHERE did=? AND jti=?",
(payload["sub"], payload["jti"]),
).fetchone()[0]

if revoked:
raise web.HTTPUnauthorized(text="revoked token")

# if we reached this far, the payload must've been signed by us
if payload.get("scope") != "com.atproto.access":
raise web.HTTPUnauthorized(text="invalid jwt scope")
Expand All @@ -81,7 +89,7 @@ async def authentication_handler(request: web.Request, *args, **kwargs):
algorithms=[alg],
audience=db.config["pds_did"],
options={
"require": ["exp", "iat", "lxm"],
"require": ["exp", "iat", "lxm", "jti", "iss"],
"verify_exp": True,
"verify_iat": True,
"strict_aud": True, # may be unnecessary
Expand All @@ -90,6 +98,14 @@ async def authentication_handler(request: web.Request, *args, **kwargs):
except jwt.exceptions.PyJWTError:
raise web.HTTPUnauthorized(text="invalid jwt")

revoked = db.con.execute(
"SELECT COUNT(*) FROM revoked_token WHERE did=? AND jti=?",
(payload["iss"], payload["jti"]),
).fetchone()[0]

if revoked:
raise web.HTTPUnauthorized(text="revoked token")

request_lxm = request.path.rpartition("/")[2].partition("?")[0]
if request_lxm != payload.get("lxm"):
raise web.HTTPUnauthorized(text="invalid jwt: bad lxm")
Expand Down
13 changes: 13 additions & 0 deletions src/millipds/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,19 @@ def _init_tables(self):
"""
)

# this is only for the tokens *we* issue, dpop jti will be tracked separately
# there's no point remembering that an expired token was revoked, and we'll garbage-collect these periodically
self.con.execute(
"""
CREATE TABLE revoked_token(
did TEXT NOT NULL,
jti TEXT NOT NULL,
expires_at INTEGER NOT NULL,
PRIMARY KEY (did, jti)
) STRICT, WITHOUT ROWID
"""
)

def update_config(
self,
pds_pfx: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/millipds/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ async def server_get_service_auth(request: web.Request):
"lxm": lxm,
"exp": exp,
"iat": now,
"jti": str(uuid.uuid4())
"jti": str(uuid.uuid4()),
},
signing_key,
algorithm=crypto.jwt_signature_alg_for_pem(signing_key),
Expand Down

0 comments on commit a2cbc49

Please sign in to comment.