diff --git a/auth/jwt/token_test.go b/auth/jwt/token_test.go index dd17de7e078..0252f231776 100644 --- a/auth/jwt/token_test.go +++ b/auth/jwt/token_test.go @@ -201,18 +201,6 @@ func TestParseOAuthToken(t *testing.T) { refreshErr: nil, err: nil, }, - { - desc: "parse invalid key but refreshed", - token: validKey, - issuedToken: "", - key: validKey, - validateErr: svcerr.ErrAuthentication, - refreshToken: oauth2.Token{ - AccessToken: strings.Repeat("a", 10), - }, - refreshErr: nil, - err: nil, - }, { desc: "parse invalid key but not refreshed", token: validKey, @@ -225,22 +213,22 @@ func TestParseOAuthToken(t *testing.T) { }, { desc: "parse invalid key with different provider", - issuedToken: invalidOauthToken(t, invalidKey, "different", "a", nil), + issuedToken: invalidOauthToken(t, invalidKey, "different"), err: svcerr.ErrAuthentication, }, { desc: "parse invalid key with invalid access token", - issuedToken: invalidOauthToken(t, invalidKey, "invalid", 123, "b"), + issuedToken: invalidOauthToken(t, invalidKey, "invalid"), err: svcerr.ErrAuthentication, }, { desc: "parse invalid key with invalid refresh token", - issuedToken: invalidOauthToken(t, invalidKey, "invalid", "a", 123), + issuedToken: invalidOauthToken(t, invalidKey, "invalid"), err: svcerr.ErrAuthentication, }, { desc: "parse invalid key with invalid provider", - issuedToken: invalidOauthToken(t, invalidKey, "invalid", "a", nil), + issuedToken: invalidOauthToken(t, invalidKey, "invalid"), err: svcerr.ErrAuthentication, }, } @@ -294,7 +282,7 @@ func oauthKey(t *testing.T) auth.Key { } } -func invalidOauthToken(t *testing.T, key auth.Key, provider, accessToken, refreshToken interface{}) string { +func invalidOauthToken(t *testing.T, key auth.Key, provider interface{}) string { builder := jwt.NewBuilder() builder. Issuer(issuerName). @@ -306,12 +294,7 @@ func invalidOauthToken(t *testing.T, key auth.Key, provider, accessToken, refres builder.Claim(domainField, key.Domain) if provider != nil { builder.Claim("oauth_provider", provider) - if accessToken != nil { - builder.Claim(provider.(string), map[string]interface{}{"access_token": accessToken}) - } - if refreshToken != nil { - builder.Claim(provider.(string), map[string]interface{}{"refresh_token": refreshToken}) - } + builder.Claim(provider.(string), key.OAuth) } if key.ID != "" { builder.JwtID(key.ID) diff --git a/auth/jwt/tokenizer.go b/auth/jwt/tokenizer.go index 66cd9b332ff..055e75ccfd9 100644 --- a/auth/jwt/tokenizer.go +++ b/auth/jwt/tokenizer.go @@ -84,18 +84,7 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) { return "", errors.Wrap(svcerr.ErrAuthentication, errInvalidProvider) } builder.Claim(oauthProviderField, provider.Name()) - - // when issuing an access token, the refresh token is not present and vice versa - if key.OAuth.AccessToken != "" && key.OAuth.RefreshToken == "" { - builder.Claim(provider.Name(), map[string]interface{}{ - oauthAccessTokenField: key.OAuth.AccessToken, - }) - } - if key.OAuth.AccessToken == "" && key.OAuth.RefreshToken != "" { - builder.Claim(provider.Name(), map[string]interface{}{ - oauthRefreshTokenField: key.OAuth.RefreshToken, - }) - } + builder.Claim(provider.Name(), key.OAuth) } if key.ID != "" { @@ -188,26 +177,27 @@ func (tok *tokenizer) Parse(token string) (auth.Key, error) { func parseOAuthToken(ctx context.Context, provider oauth2.Provider, token jwt.Token, key auth.Key) (auth.Key, error) { oauthToken, ok := token.Get(provider.Name()) if ok { - claims, ok := oauthToken.(map[string]interface{}) - if !ok { - return auth.Key{}, errors.Wrap(ErrParseToken, fmt.Errorf("invalid claims for %s token", provider.Name())) - } - - // access token and refresh token are not mandatory either of them can be present - accessToken, _ := claims[oauthAccessTokenField].(string) - refreshToken, _ := claims[oauthRefreshTokenField].(string) + var claims auth.OAuthToken + claims.FromInterface(oauthToken) - switch provider.Validate(ctx, accessToken) { - case nil: - key.OAuth.AccessToken = accessToken - key.OAuth.RefreshToken = refreshToken - default: - token, err := provider.Refresh(ctx, refreshToken) - if err != nil { + switch key.Type { + case auth.AccessKey: + if err := provider.Validate(ctx, claims.AccessToken); err != nil { return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) } - key.OAuth.AccessToken = token.AccessToken - key.OAuth.RefreshToken = token.RefreshToken + key.OAuth.AccessToken = claims.AccessToken + case auth.RefreshKey: + if err := provider.Validate(ctx, claims.RefreshToken); err != nil { + token, err := provider.Refresh(ctx, claims.RefreshToken) + if err != nil { + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) + } + key.OAuth.RefreshToken = token.RefreshToken + key.OAuth.AccessToken = token.AccessToken + + return key, nil + } + key.OAuth.RefreshToken = claims.RefreshToken } return key, nil diff --git a/auth/keys.go b/auth/keys.go index b5801b04439..6c01baa9f06 100644 --- a/auth/keys.go +++ b/auth/keys.go @@ -65,6 +65,24 @@ type OAuthToken struct { RefreshToken string `json:"refresh_token,omitempty"` } +func (o *OAuthToken) FromInterface(i interface{}) { + m, ok := i.(map[string]interface{}) + if ok { + provider, ok := m["provider"].(string) + if ok { + o.Provider = provider + } + accessToken, ok := m["access_token"].(string) + if ok { + o.AccessToken = accessToken + } + refreshToken, ok := m["refresh_token"].(string) + if ok { + o.RefreshToken = refreshToken + } + } +} + // Key represents API key. type Key struct { ID string `json:"id,omitempty"`