From 781a5dcefaf225956c4896e02155cc0ea537eca0 Mon Sep 17 00:00:00 2001 From: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com> Date: Wed, 6 Mar 2024 13:41:45 +0300 Subject: [PATCH] feat(auth/jwt): improve token parsing and issuance Refactor code to enhance parsing of OAuth tokens, validate them with the provider and refresh if needed. Handle errors related to token issuance and update user's tokens. When issuing access token we don't need OAuth2.0 refresh token embedded inside it and vice versa. Signed-off-by: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com> --- auth/jwt/token_test.go | 12 +++++------- auth/jwt/tokenizer.go | 33 +++++++++++++++++---------------- auth/service.go | 10 +++++++++- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/auth/jwt/token_test.go b/auth/jwt/token_test.go index 9a15e0483d2..dd17de7e078 100644 --- a/auth/jwt/token_test.go +++ b/auth/jwt/token_test.go @@ -208,8 +208,7 @@ func TestParseOAuthToken(t *testing.T) { key: validKey, validateErr: svcerr.ErrAuthentication, refreshToken: oauth2.Token{ - AccessToken: strings.Repeat("a", 10), - RefreshToken: strings.Repeat("b", 10), + AccessToken: strings.Repeat("a", 10), }, refreshErr: nil, err: nil, @@ -226,7 +225,7 @@ func TestParseOAuthToken(t *testing.T) { }, { desc: "parse invalid key with different provider", - issuedToken: invalidOauthToken(t, invalidKey, "invalid", "a", "b"), + issuedToken: invalidOauthToken(t, invalidKey, "different", "a", nil), err: svcerr.ErrAuthentication, }, { @@ -241,7 +240,7 @@ func TestParseOAuthToken(t *testing.T) { }, { desc: "parse invalid key with invalid provider", - issuedToken: invalidOauthToken(t, invalidKey, "test", "a", "b"), + issuedToken: invalidOauthToken(t, invalidKey, "invalid", "a", nil), err: svcerr.ErrAuthentication, }, } @@ -289,9 +288,8 @@ func oauthKey(t *testing.T) auth.Key { IssuedAt: time.Now().UTC().Add(-10 * time.Second).Round(time.Second), ExpiresAt: time.Now().UTC().Add(10 * time.Minute).Round(time.Second), OAuth: auth.OAuthToken{ - Provider: "test", - AccessToken: strings.Repeat("a", 10), - RefreshToken: strings.Repeat("b", 10), + Provider: "test", + AccessToken: strings.Repeat("a", 10), }, } } diff --git a/auth/jwt/tokenizer.go b/auth/jwt/tokenizer.go index 5627d0609a0..66cd9b332ff 100644 --- a/auth/jwt/tokenizer.go +++ b/auth/jwt/tokenizer.go @@ -84,10 +84,18 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) { return "", errors.Wrap(svcerr.ErrAuthentication, errInvalidProvider) } builder.Claim(oauthProviderField, provider.Name()) - builder.Claim(provider.Name(), map[string]interface{}{ - oauthAccessTokenField: key.OAuth.AccessToken, - oauthRefreshTokenField: key.OAuth.RefreshToken, - }) + + // when issuing an access token, the refresh token is not present and vice versa + if key.OAuth.AccessToken != "" && key.OAuth.RefreshToken == "" { + builder.Claim(provider.Name(), map[string]interface{}{ + oauthAccessTokenField: key.OAuth.AccessToken, + }) + } + if key.OAuth.AccessToken == "" && key.OAuth.RefreshToken != "" { + builder.Claim(provider.Name(), map[string]interface{}{ + oauthRefreshTokenField: key.OAuth.RefreshToken, + }) + } } if key.ID != "" { @@ -184,18 +192,15 @@ func parseOAuthToken(ctx context.Context, provider oauth2.Provider, token jwt.To if !ok { return auth.Key{}, errors.Wrap(ErrParseToken, fmt.Errorf("invalid claims for %s token", provider.Name())) } - accessToken, ok := claims[oauthAccessTokenField].(string) - if !ok { - return auth.Key{}, errors.Wrap(ErrParseToken, fmt.Errorf("invalid access token claim for %s token", provider.Name())) - } - refreshToken, ok := claims[oauthRefreshTokenField].(string) - if !ok { - return auth.Key{}, errors.Wrap(ErrParseToken, fmt.Errorf("invalid refresh token claim for %s token", provider.Name())) - } + + // access token and refresh token are not mandatory either of them can be present + accessToken, _ := claims[oauthAccessTokenField].(string) + refreshToken, _ := claims[oauthRefreshTokenField].(string) switch provider.Validate(ctx, accessToken) { case nil: key.OAuth.AccessToken = accessToken + key.OAuth.RefreshToken = refreshToken default: token, err := provider.Refresh(ctx, refreshToken) if err != nil { @@ -203,12 +208,8 @@ func parseOAuthToken(ctx context.Context, provider oauth2.Provider, token jwt.To } key.OAuth.AccessToken = token.AccessToken key.OAuth.RefreshToken = token.RefreshToken - - return key, nil } - key.OAuth.RefreshToken = refreshToken - return key, nil } diff --git a/auth/service.go b/auth/service.go index cdead702276..9b81d1d709d 100644 --- a/auth/service.go +++ b/auth/service.go @@ -381,10 +381,15 @@ func (svc service) accessKey(ctx context.Context, key Key) (Token, error) { return Token{}, errors.Wrap(svcerr.ErrAuthorization, err) } + oauthRefresh := key.OAuth.RefreshToken + key.OAuth.RefreshToken = "" access, err := svc.tokenizer.Issue(key) if err != nil { return Token{}, errors.Wrap(errIssueTmp, err) } + + key.OAuth.AccessToken = "" + key.OAuth.RefreshToken = oauthRefresh key.ExpiresAt = time.Now().Add(svc.refreshDuration) key.Type = RefreshKey refresh, err := svc.tokenizer.Issue(key) @@ -430,7 +435,7 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token key.OAuth.Provider = k.OAuth.Provider key.OAuth.AccessToken = k.OAuth.AccessToken - key.OAuth.RefreshToken = k.OAuth.RefreshToken + key.OAuth.RefreshToken = "" key.Subject, err = svc.checkUserDomain(ctx, key) if err != nil { @@ -442,6 +447,9 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token if err != nil { return Token{}, errors.Wrap(errIssueTmp, err) } + + key.OAuth.AccessToken = "" + key.OAuth.RefreshToken = k.OAuth.RefreshToken key.ExpiresAt = time.Now().Add(svc.refreshDuration) key.Type = RefreshKey refresh, err := svc.tokenizer.Issue(key)