Skip to content

Commit

Permalink
feat(auth/jwt): improve token parsing and issuance
Browse files Browse the repository at this point in the history
Refactor code to enhance parsing of OAuth tokens, validate them with the provider and refresh if needed. Handle errors related to token issuance and update user's tokens.
When issuing access token we don't need OAuth2.0 refresh token embedded inside it and vice versa.

Signed-off-by: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com>
  • Loading branch information
rodneyosodo committed Mar 6, 2024
1 parent 42d433a commit 20e15f0
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 24 deletions.
12 changes: 5 additions & 7 deletions auth/jwt/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,7 @@ func TestParseOAuthToken(t *testing.T) {
key: validKey,
validateErr: svcerr.ErrAuthentication,
refreshToken: oauth2.Token{
AccessToken: strings.Repeat("a", 10),
RefreshToken: strings.Repeat("b", 10),
AccessToken: strings.Repeat("a", 10),
},
refreshErr: nil,
err: nil,
Expand All @@ -226,7 +225,7 @@ func TestParseOAuthToken(t *testing.T) {
},
{
desc: "parse invalid key with different provider",
issuedToken: invalidOauthToken(t, invalidKey, "invalid", "a", "b"),
issuedToken: invalidOauthToken(t, invalidKey, "different", "a", nil),
err: svcerr.ErrAuthentication,
},
{
Expand All @@ -241,7 +240,7 @@ func TestParseOAuthToken(t *testing.T) {
},
{
desc: "parse invalid key with invalid provider",
issuedToken: invalidOauthToken(t, invalidKey, "test", "a", "b"),
issuedToken: invalidOauthToken(t, invalidKey, "invalid", "a", nil),
err: svcerr.ErrAuthentication,
},
}
Expand Down Expand Up @@ -289,9 +288,8 @@ func oauthKey(t *testing.T) auth.Key {
IssuedAt: time.Now().UTC().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute).Round(time.Second),
OAuth: auth.OAuthToken{
Provider: "test",
AccessToken: strings.Repeat("a", 10),
RefreshToken: strings.Repeat("b", 10),
Provider: "test",
AccessToken: strings.Repeat("a", 10),
},
}
}
Expand Down
33 changes: 17 additions & 16 deletions auth/jwt/tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,18 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) {
return "", errors.Wrap(svcerr.ErrAuthentication, errInvalidProvider)
}
builder.Claim(oauthProviderField, provider.Name())
builder.Claim(provider.Name(), map[string]interface{}{
oauthAccessTokenField: key.OAuth.AccessToken,
oauthRefreshTokenField: key.OAuth.RefreshToken,
})

// when issuing an access token, the refresh token is not present and vice versa
if key.OAuth.AccessToken != "" && key.OAuth.RefreshToken == "" {
builder.Claim(provider.Name(), map[string]interface{}{
oauthAccessTokenField: key.OAuth.AccessToken,
})
}
if key.OAuth.AccessToken == "" && key.OAuth.RefreshToken != "" {
builder.Claim(provider.Name(), map[string]interface{}{
oauthRefreshTokenField: key.OAuth.RefreshToken,
})
}
}

if key.ID != "" {
Expand Down Expand Up @@ -184,31 +192,24 @@ func parseOAuthToken(ctx context.Context, provider oauth2.Provider, token jwt.To
if !ok {
return auth.Key{}, errors.Wrap(ErrParseToken, fmt.Errorf("invalid claims for %s token", provider.Name()))
}
accessToken, ok := claims[oauthAccessTokenField].(string)
if !ok {
return auth.Key{}, errors.Wrap(ErrParseToken, fmt.Errorf("invalid access token claim for %s token", provider.Name()))
}
refreshToken, ok := claims[oauthRefreshTokenField].(string)
if !ok {
return auth.Key{}, errors.Wrap(ErrParseToken, fmt.Errorf("invalid refresh token claim for %s token", provider.Name()))
}

// access token and refresh token are not mandatory either of them can be present
accessToken, _ := claims[oauthAccessTokenField].(string)
refreshToken, _ := claims[oauthRefreshTokenField].(string)

switch provider.Validate(ctx, accessToken) {
case nil:
key.OAuth.AccessToken = accessToken
key.OAuth.RefreshToken = refreshToken
default:
token, err := provider.Refresh(ctx, refreshToken)
if err != nil {
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
key.OAuth.AccessToken = token.AccessToken
key.OAuth.RefreshToken = token.RefreshToken

return key, nil
}

key.OAuth.RefreshToken = refreshToken

return key, nil
}

Expand Down
10 changes: 9 additions & 1 deletion auth/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,15 @@ func (svc service) accessKey(ctx context.Context, key Key) (Token, error) {
return Token{}, errors.Wrap(svcerr.ErrAuthorization, err)
}

oauthRefresh := key.OAuth.RefreshToken
key.OAuth.RefreshToken = ""
access, err := svc.tokenizer.Issue(key)
if err != nil {
return Token{}, errors.Wrap(errIssueTmp, err)
}

key.OAuth.AccessToken = ""
key.OAuth.RefreshToken = oauthRefresh
key.ExpiresAt = time.Now().Add(svc.refreshDuration)
key.Type = RefreshKey
refresh, err := svc.tokenizer.Issue(key)
Expand Down Expand Up @@ -430,7 +435,7 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token

key.OAuth.Provider = k.OAuth.Provider
key.OAuth.AccessToken = k.OAuth.AccessToken
key.OAuth.RefreshToken = k.OAuth.RefreshToken
key.OAuth.RefreshToken = ""

key.Subject, err = svc.checkUserDomain(ctx, key)
if err != nil {
Expand All @@ -442,6 +447,9 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token
if err != nil {
return Token{}, errors.Wrap(errIssueTmp, err)
}

key.OAuth.AccessToken = ""
key.OAuth.RefreshToken = k.OAuth.RefreshToken
key.ExpiresAt = time.Now().Add(svc.refreshDuration)
key.Type = RefreshKey
refresh, err := svc.tokenizer.Issue(key)
Expand Down

0 comments on commit 20e15f0

Please sign in to comment.