diff --git a/src/millipds/auth_bearer.py b/src/millipds/auth_bearer.py index 0923c0e..d561379 100644 --- a/src/millipds/auth_bearer.py +++ b/src/millipds/auth_bearer.py @@ -46,7 +46,7 @@ async def authentication_handler(request: web.Request, *args, **kwargs): algorithms=["HS256"], audience=db.config["pds_did"], options={ - "require": ["exp", "iat", "scope"], # consider iat? + "require": ["exp", "iat", "scope", "jti", "sub"], "verify_exp": True, "verify_iat": True, "strict_aud": True, # may be unnecessary @@ -55,6 +55,14 @@ async def authentication_handler(request: web.Request, *args, **kwargs): except jwt.exceptions.PyJWTError: raise web.HTTPUnauthorized(text="invalid jwt") + revoked = db.con.execute( + "SELECT COUNT(*) FROM revoked_token WHERE did=? AND jti=?", + (payload["sub"], payload["jti"]), + ).fetchone()[0] + + if revoked: + raise web.HTTPUnauthorized(text="revoked token") + # if we reached this far, the payload must've been signed by us if payload.get("scope") != "com.atproto.access": raise web.HTTPUnauthorized(text="invalid jwt scope") @@ -81,7 +89,7 @@ async def authentication_handler(request: web.Request, *args, **kwargs): algorithms=[alg], audience=db.config["pds_did"], options={ - "require": ["exp", "iat", "lxm"], + "require": ["exp", "iat", "lxm", "jti", "iss"], "verify_exp": True, "verify_iat": True, "strict_aud": True, # may be unnecessary @@ -90,6 +98,14 @@ async def authentication_handler(request: web.Request, *args, **kwargs): except jwt.exceptions.PyJWTError: raise web.HTTPUnauthorized(text="invalid jwt") + revoked = db.con.execute( + "SELECT COUNT(*) FROM revoked_token WHERE did=? AND jti=?", + (payload["iss"], payload["jti"]), + ).fetchone()[0] + + if revoked: + raise web.HTTPUnauthorized(text="revoked token") + request_lxm = request.path.rpartition("/")[2].partition("?")[0] if request_lxm != payload.get("lxm"): raise web.HTTPUnauthorized(text="invalid jwt: bad lxm") diff --git a/src/millipds/database.py b/src/millipds/database.py index 05c9b91..0b742a2 100644 --- a/src/millipds/database.py +++ b/src/millipds/database.py @@ -245,6 +245,19 @@ def _init_tables(self): """ ) + # this is only for the tokens *we* issue, dpop jti will be tracked separately + # there's no point remembering that an expired token was revoked, and we'll garbage-collect these periodically + self.con.execute( + """ + CREATE TABLE revoked_token( + did TEXT NOT NULL, + jti TEXT NOT NULL, + expires_at INTEGER NOT NULL, + PRIMARY KEY (did, jti) + ) STRICT, WITHOUT ROWID + """ + ) + def update_config( self, pds_pfx: Optional[str] = None, diff --git a/src/millipds/service.py b/src/millipds/service.py index e2592de..3160fe4 100644 --- a/src/millipds/service.py +++ b/src/millipds/service.py @@ -302,7 +302,7 @@ async def server_get_service_auth(request: web.Request): "lxm": lxm, "exp": exp, "iat": now, - "jti": str(uuid.uuid4()) + "jti": str(uuid.uuid4()), }, signing_key, algorithm=crypto.jwt_signature_alg_for_pem(signing_key),