Skip to content

Commit ac07820

Browse files
Krismix1ryshoooo
andauthored
fix: Feature parity for a_decode_token and decode_token (#616)
* Consistency for token decoding * Mark as staticmethod * Helper function to convert key * Refactor key handling * Add tests for validate=False * Change test name * Fix failing test * Remove special case for str * Some docstring * docs: missing docstrings --------- Co-authored-by: Richard Nemeth <ryshoooo@gmail.com>
1 parent 3b946c3 commit ac07820

File tree

2 files changed

+112
-22
lines changed

2 files changed

+112
-22
lines changed

src/keycloak/keycloak_openid.py

+38-21
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class to handle authentication and token manipulation.
2828
"""
2929

3030
import json
31-
from typing import Optional
31+
from typing import Optional, Union
3232

3333
from jwcrypto import jwk, jwt
3434

@@ -581,6 +581,33 @@ def introspect(self, token, rpt=None, token_type_hint=None):
581581
)
582582
return raise_error_from_response(data_raw, KeycloakPostError)
583583

584+
@staticmethod
585+
def _verify_token(token, key: Union[jwk.JWK, jwk.JWKSet, None], **kwargs):
586+
"""Decode and optionally validate a token.
587+
588+
:param token: The token to verify
589+
:type token: str
590+
:param key: Which key should be used for validation.
591+
If not provided, the validation is not performed and the token is implicitly valid.
592+
:type key: Union[jwk.JWK, jwk.JWKSet, None]
593+
:param kwargs: Additional keyword arguments for jwcrypto's JWT object
594+
:type kwargs: dict
595+
:returns: Decoded token
596+
"""
597+
# keep the function free of IO
598+
# this way it can be used by `decode_token` and `a_decode_token`
599+
600+
if key is not None:
601+
leeway = kwargs.pop("leeway", 60)
602+
full_jwt = jwt.JWT(jwt=token, **kwargs)
603+
full_jwt.leeway = leeway
604+
full_jwt.validate(key)
605+
return jwt.json_decode(full_jwt.claims)
606+
else:
607+
full_jwt = jwt.JWT(jwt=token, **kwargs)
608+
full_jwt.token.objects["valid"] = True
609+
return json.loads(full_jwt.token.payload.decode("utf-8"))
610+
584611
def decode_token(self, token, validate: bool = True, **kwargs):
585612
"""Decode user token.
586613
@@ -603,26 +630,19 @@ def decode_token(self, token, validate: bool = True, **kwargs):
603630
:returns: Decoded token
604631
:rtype: dict
605632
"""
633+
key = kwargs.pop("key", None)
606634
if validate:
607-
if "key" not in kwargs:
635+
if key is None:
608636
key = (
609637
"-----BEGIN PUBLIC KEY-----\n"
610638
+ self.public_key()
611639
+ "\n-----END PUBLIC KEY-----"
612640
)
613641
key = jwk.JWK.from_pem(key.encode("utf-8"))
614-
kwargs["key"] = key
615-
616-
key = kwargs.pop("key")
617-
leeway = kwargs.pop("leeway", 60)
618-
full_jwt = jwt.JWT(jwt=token, **kwargs)
619-
full_jwt.leeway = leeway
620-
full_jwt.validate(key)
621-
return jwt.json_decode(full_jwt.claims)
622642
else:
623-
full_jwt = jwt.JWT(jwt=token, **kwargs)
624-
full_jwt.token.objects["valid"] = True
625-
return json.loads(full_jwt.token.payload.decode("utf-8"))
643+
key = None
644+
645+
return self._verify_token(token, key, **kwargs)
626646

627647
def load_authorization_config(self, path):
628648
"""Load Keycloak settings (authorization).
@@ -1273,22 +1293,19 @@ async def a_decode_token(self, token, validate: bool = True, **kwargs):
12731293
:returns: Decoded token
12741294
:rtype: dict
12751295
"""
1296+
key = kwargs.pop("key", None)
12761297
if validate:
1277-
if "key" not in kwargs:
1298+
if key is None:
12781299
key = (
12791300
"-----BEGIN PUBLIC KEY-----\n"
12801301
+ await self.a_public_key()
12811302
+ "\n-----END PUBLIC KEY-----"
12821303
)
12831304
key = jwk.JWK.from_pem(key.encode("utf-8"))
1284-
kwargs["key"] = key
1285-
1286-
full_jwt = jwt.JWT(jwt=token, **kwargs)
1287-
return jwt.json_decode(full_jwt.claims)
12881305
else:
1289-
full_jwt = jwt.JWT(jwt=token, **kwargs)
1290-
full_jwt.token.objects["valid"] = True
1291-
return json.loads(full_jwt.token.payload.decode("utf-8"))
1306+
key = None
1307+
1308+
return self._verify_token(token, key, **kwargs)
12921309

12931310
async def a_load_authorization_config(self, path):
12941311
"""Load Keycloak settings (authorization) asynchronously.

tests/test_keycloak_openid.py

+74-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from typing import Tuple
55
from unittest import mock
66

7+
import jwcrypto.jwk
8+
import jwcrypto.jws
79
import pytest
810

911
from keycloak import KeycloakAdmin, KeycloakOpenID
@@ -317,6 +319,39 @@ def test_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
317319
assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token
318320

319321

322+
def test_decode_token_invalid_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
323+
"""Test decode token with an invalid token.
324+
325+
:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
326+
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
327+
"""
328+
oid, username, password = oid_with_credentials
329+
token = oid.token(username=username, password=password)
330+
access_token = token["access_token"]
331+
decoded_access_token = oid.decode_token(token=access_token)
332+
333+
key = oid.public_key()
334+
key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----"
335+
key = jwcrypto.jwk.JWK.from_pem(key.encode("utf-8"))
336+
337+
invalid_access_token = access_token + "a"
338+
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
339+
decoded_invalid_access_token = oid.decode_token(token=invalid_access_token, validate=True)
340+
341+
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
342+
decoded_invalid_access_token = oid.decode_token(
343+
token=invalid_access_token, validate=True, key=key
344+
)
345+
346+
decoded_invalid_access_token = oid.decode_token(token=invalid_access_token, validate=False)
347+
assert decoded_access_token == decoded_invalid_access_token
348+
349+
decoded_invalid_access_token = oid.decode_token(
350+
token=invalid_access_token, validate=False, key=key
351+
)
352+
assert decoded_access_token == decoded_invalid_access_token
353+
354+
320355
def test_load_authorization_config(oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]):
321356
"""Test load authorization config.
322357
@@ -765,7 +800,7 @@ async def test_a_introspect(oid_with_credentials: Tuple[KeycloakOpenID, str, str
765800

766801
@pytest.mark.asyncio
767802
async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
768-
"""Test decode token.
803+
"""Test decode token asynchronously.
769804
770805
:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
771806
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
@@ -781,6 +816,44 @@ async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, s
781816
assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token
782817

783818

819+
@pytest.mark.asyncio
820+
async def test_a_decode_token_invalid_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
821+
"""Test decode token asynchronously an invalid token.
822+
823+
:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
824+
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
825+
"""
826+
oid, username, password = oid_with_credentials
827+
token = await oid.a_token(username=username, password=password)
828+
access_token = token["access_token"]
829+
decoded_access_token = await oid.a_decode_token(token=access_token)
830+
831+
key = await oid.a_public_key()
832+
key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----"
833+
key = jwcrypto.jwk.JWK.from_pem(key.encode("utf-8"))
834+
835+
invalid_access_token = access_token + "a"
836+
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
837+
decoded_invalid_access_token = await oid.a_decode_token(
838+
token=invalid_access_token, validate=True
839+
)
840+
841+
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
842+
decoded_invalid_access_token = await oid.a_decode_token(
843+
token=invalid_access_token, validate=True, key=key
844+
)
845+
846+
decoded_invalid_access_token = await oid.a_decode_token(
847+
token=invalid_access_token, validate=False
848+
)
849+
assert decoded_access_token == decoded_invalid_access_token
850+
851+
decoded_invalid_access_token = await oid.a_decode_token(
852+
token=invalid_access_token, validate=False, key=key
853+
)
854+
assert decoded_access_token == decoded_invalid_access_token
855+
856+
784857
@pytest.mark.asyncio
785858
async def test_a_load_authorization_config(
786859
oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]

0 commit comments

Comments
 (0)