diff --git a/auth/jwt/token_test.go b/auth/jwt/token_test.go index 9a15e0483d2..dd17de7e078 100644 --- a/auth/jwt/token_test.go +++ b/auth/jwt/token_test.go @@ -208,8 +208,7 @@ func TestParseOAuthToken(t *testing.T) { key: validKey, validateErr: svcerr.ErrAuthentication, refreshToken: oauth2.Token{ - AccessToken: strings.Repeat("a", 10), - RefreshToken: strings.Repeat("b", 10), + AccessToken: strings.Repeat("a", 10), }, refreshErr: nil, err: nil, @@ -226,7 +225,7 @@ func TestParseOAuthToken(t *testing.T) { }, { desc: "parse invalid key with different provider", - issuedToken: invalidOauthToken(t, invalidKey, "invalid", "a", "b"), + issuedToken: invalidOauthToken(t, invalidKey, "different", "a", nil), err: svcerr.ErrAuthentication, }, { @@ -241,7 +240,7 @@ func TestParseOAuthToken(t *testing.T) { }, { desc: "parse invalid key with invalid provider", - issuedToken: invalidOauthToken(t, invalidKey, "test", "a", "b"), + issuedToken: invalidOauthToken(t, invalidKey, "invalid", "a", nil), err: svcerr.ErrAuthentication, }, } @@ -289,9 +288,8 @@ func oauthKey(t *testing.T) auth.Key { IssuedAt: time.Now().UTC().Add(-10 * time.Second).Round(time.Second), ExpiresAt: time.Now().UTC().Add(10 * time.Minute).Round(time.Second), OAuth: auth.OAuthToken{ - Provider: "test", - AccessToken: strings.Repeat("a", 10), - RefreshToken: strings.Repeat("b", 10), + Provider: "test", + AccessToken: strings.Repeat("a", 10), }, } } diff --git a/auth/jwt/tokenizer.go b/auth/jwt/tokenizer.go index 5627d0609a0..66cd9b332ff 100644 --- a/auth/jwt/tokenizer.go +++ b/auth/jwt/tokenizer.go @@ -84,10 +84,18 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) { return "", errors.Wrap(svcerr.ErrAuthentication, errInvalidProvider) } builder.Claim(oauthProviderField, provider.Name()) - builder.Claim(provider.Name(), map[string]interface{}{ - oauthAccessTokenField: key.OAuth.AccessToken, - oauthRefreshTokenField: key.OAuth.RefreshToken, - }) + + // when issuing an access token, the refresh token is not present and vice versa + if key.OAuth.AccessToken != "" && key.OAuth.RefreshToken == "" { + builder.Claim(provider.Name(), map[string]interface{}{ + oauthAccessTokenField: key.OAuth.AccessToken, + }) + } + if key.OAuth.AccessToken == "" && key.OAuth.RefreshToken != "" { + builder.Claim(provider.Name(), map[string]interface{}{ + oauthRefreshTokenField: key.OAuth.RefreshToken, + }) + } } if key.ID != "" { @@ -184,18 +192,15 @@ func parseOAuthToken(ctx context.Context, provider oauth2.Provider, token jwt.To if !ok { return auth.Key{}, errors.Wrap(ErrParseToken, fmt.Errorf("invalid claims for %s token", provider.Name())) } - accessToken, ok := claims[oauthAccessTokenField].(string) - if !ok { - return auth.Key{}, errors.Wrap(ErrParseToken, fmt.Errorf("invalid access token claim for %s token", provider.Name())) - } - refreshToken, ok := claims[oauthRefreshTokenField].(string) - if !ok { - return auth.Key{}, errors.Wrap(ErrParseToken, fmt.Errorf("invalid refresh token claim for %s token", provider.Name())) - } + + // access token and refresh token are not mandatory either of them can be present + accessToken, _ := claims[oauthAccessTokenField].(string) + refreshToken, _ := claims[oauthRefreshTokenField].(string) switch provider.Validate(ctx, accessToken) { case nil: key.OAuth.AccessToken = accessToken + key.OAuth.RefreshToken = refreshToken default: token, err := provider.Refresh(ctx, refreshToken) if err != nil { @@ -203,12 +208,8 @@ func parseOAuthToken(ctx context.Context, provider oauth2.Provider, token jwt.To } key.OAuth.AccessToken = token.AccessToken key.OAuth.RefreshToken = token.RefreshToken - - return key, nil } - key.OAuth.RefreshToken = refreshToken - return key, nil } diff --git a/auth/service.go b/auth/service.go index cdead702276..9b81d1d709d 100644 --- a/auth/service.go +++ b/auth/service.go @@ -381,10 +381,15 @@ func (svc service) accessKey(ctx context.Context, key Key) (Token, error) { return Token{}, errors.Wrap(svcerr.ErrAuthorization, err) } + oauthRefresh := key.OAuth.RefreshToken + key.OAuth.RefreshToken = "" access, err := svc.tokenizer.Issue(key) if err != nil { return Token{}, errors.Wrap(errIssueTmp, err) } + + key.OAuth.AccessToken = "" + key.OAuth.RefreshToken = oauthRefresh key.ExpiresAt = time.Now().Add(svc.refreshDuration) key.Type = RefreshKey refresh, err := svc.tokenizer.Issue(key) @@ -430,7 +435,7 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token key.OAuth.Provider = k.OAuth.Provider key.OAuth.AccessToken = k.OAuth.AccessToken - key.OAuth.RefreshToken = k.OAuth.RefreshToken + key.OAuth.RefreshToken = "" key.Subject, err = svc.checkUserDomain(ctx, key) if err != nil { @@ -442,6 +447,9 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token if err != nil { return Token{}, errors.Wrap(errIssueTmp, err) } + + key.OAuth.AccessToken = "" + key.OAuth.RefreshToken = k.OAuth.RefreshToken key.ExpiresAt = time.Now().Add(svc.refreshDuration) key.Type = RefreshKey refresh, err := svc.tokenizer.Issue(key)