Skip to content

Commit

Permalink
Refactor OAuth token parsing and handling
Browse files Browse the repository at this point in the history
Signed-off-by: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com>
  • Loading branch information
rodneyosodo committed Mar 7, 2024
1 parent 781a5dc commit 8a9c14c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 52 deletions.
29 changes: 6 additions & 23 deletions auth/jwt/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,18 +201,6 @@ func TestParseOAuthToken(t *testing.T) {
refreshErr: nil,
err: nil,
},
{
desc: "parse invalid key but refreshed",
token: validKey,
issuedToken: "",
key: validKey,
validateErr: svcerr.ErrAuthentication,
refreshToken: oauth2.Token{
AccessToken: strings.Repeat("a", 10),
},
refreshErr: nil,
err: nil,
},
{
desc: "parse invalid key but not refreshed",
token: validKey,
Expand All @@ -225,22 +213,22 @@ func TestParseOAuthToken(t *testing.T) {
},
{
desc: "parse invalid key with different provider",
issuedToken: invalidOauthToken(t, invalidKey, "different", "a", nil),
issuedToken: invalidOauthToken(t, invalidKey, "different"),
err: svcerr.ErrAuthentication,
},
{
desc: "parse invalid key with invalid access token",
issuedToken: invalidOauthToken(t, invalidKey, "invalid", 123, "b"),
issuedToken: invalidOauthToken(t, invalidKey, "invalid"),
err: svcerr.ErrAuthentication,
},
{
desc: "parse invalid key with invalid refresh token",
issuedToken: invalidOauthToken(t, invalidKey, "invalid", "a", 123),
issuedToken: invalidOauthToken(t, invalidKey, "invalid"),
err: svcerr.ErrAuthentication,
},
{
desc: "parse invalid key with invalid provider",
issuedToken: invalidOauthToken(t, invalidKey, "invalid", "a", nil),
issuedToken: invalidOauthToken(t, invalidKey, "invalid"),
err: svcerr.ErrAuthentication,
},
}
Expand Down Expand Up @@ -294,7 +282,7 @@ func oauthKey(t *testing.T) auth.Key {
}
}

func invalidOauthToken(t *testing.T, key auth.Key, provider, accessToken, refreshToken interface{}) string {
func invalidOauthToken(t *testing.T, key auth.Key, provider interface{}) string {
builder := jwt.NewBuilder()
builder.
Issuer(issuerName).
Expand All @@ -306,12 +294,7 @@ func invalidOauthToken(t *testing.T, key auth.Key, provider, accessToken, refres
builder.Claim(domainField, key.Domain)
if provider != nil {
builder.Claim("oauth_provider", provider)
if accessToken != nil {
builder.Claim(provider.(string), map[string]interface{}{"access_token": accessToken})
}
if refreshToken != nil {
builder.Claim(provider.(string), map[string]interface{}{"refresh_token": refreshToken})
}
builder.Claim(provider.(string), key.OAuth)
}
if key.ID != "" {
builder.JwtID(key.ID)
Expand Down
48 changes: 19 additions & 29 deletions auth/jwt/tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,7 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) {
return "", errors.Wrap(svcerr.ErrAuthentication, errInvalidProvider)
}
builder.Claim(oauthProviderField, provider.Name())

// when issuing an access token, the refresh token is not present and vice versa
if key.OAuth.AccessToken != "" && key.OAuth.RefreshToken == "" {
builder.Claim(provider.Name(), map[string]interface{}{
oauthAccessTokenField: key.OAuth.AccessToken,
})
}
if key.OAuth.AccessToken == "" && key.OAuth.RefreshToken != "" {
builder.Claim(provider.Name(), map[string]interface{}{
oauthRefreshTokenField: key.OAuth.RefreshToken,
})
}
builder.Claim(provider.Name(), key.OAuth)
}

if key.ID != "" {
Expand Down Expand Up @@ -188,26 +177,27 @@ func (tok *tokenizer) Parse(token string) (auth.Key, error) {
func parseOAuthToken(ctx context.Context, provider oauth2.Provider, token jwt.Token, key auth.Key) (auth.Key, error) {
oauthToken, ok := token.Get(provider.Name())
if ok {
claims, ok := oauthToken.(map[string]interface{})
if !ok {
return auth.Key{}, errors.Wrap(ErrParseToken, fmt.Errorf("invalid claims for %s token", provider.Name()))
}

// access token and refresh token are not mandatory either of them can be present
accessToken, _ := claims[oauthAccessTokenField].(string)
refreshToken, _ := claims[oauthRefreshTokenField].(string)
var claims auth.OAuthToken
claims.FromInterface(oauthToken)

switch provider.Validate(ctx, accessToken) {
case nil:
key.OAuth.AccessToken = accessToken
key.OAuth.RefreshToken = refreshToken
default:
token, err := provider.Refresh(ctx, refreshToken)
if err != nil {
switch key.Type {
case auth.AccessKey:
if err := provider.Validate(ctx, claims.AccessToken); err != nil {
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
key.OAuth.AccessToken = token.AccessToken
key.OAuth.RefreshToken = token.RefreshToken
key.OAuth.AccessToken = claims.AccessToken
case auth.RefreshKey:
if err := provider.Validate(ctx, claims.RefreshToken); err != nil {
token, err := provider.Refresh(ctx, claims.RefreshToken)
if err != nil {
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
key.OAuth.RefreshToken = token.RefreshToken
key.OAuth.AccessToken = token.AccessToken

return key, nil
}
key.OAuth.RefreshToken = claims.RefreshToken
}

return key, nil
Expand Down
18 changes: 18 additions & 0 deletions auth/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ type OAuthToken struct {
RefreshToken string `json:"refresh_token,omitempty"`
}

func (o *OAuthToken) FromInterface(i interface{}) {
m, ok := i.(map[string]interface{})
if ok {
provider, ok := m["provider"].(string)
if ok {
o.Provider = provider
}
accessToken, ok := m["access_token"].(string)
if ok {
o.AccessToken = accessToken
}
refreshToken, ok := m["refresh_token"].(string)
if ok {
o.RefreshToken = refreshToken
}
}
}

// Key represents API key.
type Key struct {
ID string `json:"id,omitempty"`
Expand Down

0 comments on commit 8a9c14c

Please sign in to comment.