diff --git a/cmd/auth.go b/cmd/auth.go index 914fa81..77fbaf7 100644 --- a/cmd/auth.go +++ b/cmd/auth.go @@ -2,6 +2,7 @@ package cmd import ( "errors" + "math" "syscall" "github.com/jsdelivr/globalping-cli/globalping" @@ -89,7 +90,7 @@ func (r *Root) RunAuthStatus(cmd *cobra.Command, args []string) error { res, err := r.client.TokenIntrospection("") if err != nil { e, ok := err.(*globalping.AuthorizeError) - if ok && e.ErrorType == "not_authorized" { + if ok && e.ErrorType == globalping.ErrTypeNotAuthorized { r.printer.Println("Not logged in.") return nil } @@ -131,6 +132,7 @@ func (r *Root) loginWithToken() error { profile := r.storage.GetProfile() profile.Token = &globalping.Token{ AccessToken: token, + Expiry: r.utils.Now().Add(math.MaxInt64), } err = r.storage.SaveConfig() if err != nil { diff --git a/cmd/auth_test.go b/cmd/auth_test.go index 4259ed9..b27e8c0 100644 --- a/cmd/auth_test.go +++ b/cmd/auth_test.go @@ -3,6 +3,7 @@ package cmd import ( "bytes" "context" + "math" "os" "syscall" "testing" @@ -23,6 +24,9 @@ func Test_Auth_Login_WithToken(t *testing.T) { gbMock := mocks.NewMockClient(ctrl) + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + w := new(bytes.Buffer) r := new(bytes.Buffer) r.WriteString("token\n") @@ -39,7 +43,7 @@ func Test_Auth_Login_WithToken(t *testing.T) { RefreshToken: "oldRefreshToken", } - root := NewRoot(printer, ctx, nil, nil, gbMock, nil, _storage) + root := NewRoot(printer, ctx, nil, utilsMock, gbMock, nil, _storage) gbMock.EXPECT().TokenIntrospection("token").Return(&globalping.IntrospectionResponse{ Active: true, @@ -59,6 +63,7 @@ Logged in as test. assert.Equal(t, &storage.Profile{ Token: &globalping.Token{ AccessToken: "token", + Expiry: defaultCurrentTime.Add(math.MaxInt64), }, }, profile) } diff --git a/cmd/root.go b/cmd/root.go index 3ca665d..76d245b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,6 +1,7 @@ package cmd import ( + "math" "os" "os/signal" "syscall" @@ -49,12 +50,19 @@ func Execute() { Limit: 1, } t := time.NewTicker(10 * time.Second) + token := profile.Token + if config.GlobalpingToken != "" { + token = &globalping.Token{ + AccessToken: config.GlobalpingToken, + ExpiresIn: math.MaxInt64, + Expiry: time.Now().Add(math.MaxInt64), + } + } globalpingClient := globalping.NewClientWithCacheCleanup(globalping.Config{ - APIURL: config.GlobalpingAPIURL, - AuthURL: config.GlobalpingAuthURL, - DashboardURL: config.GlobalpingDashboardURL, - AuthAccessToken: config.GlobalpingToken, - AuthToken: profile.Token, + APIURL: config.GlobalpingAPIURL, + AuthURL: config.GlobalpingAuthURL, + DashboardURL: config.GlobalpingDashboardURL, + AuthToken: token, OnTokenRefresh: func(token *globalping.Token) { profile.Token = token err := localStorage.SaveConfig() diff --git a/globalping/auth.go b/globalping/auth.go index 0ca540c..b714354 100644 --- a/globalping/auth.go +++ b/globalping/auth.go @@ -1,7 +1,9 @@ package globalping import ( - "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" "encoding/json" "net" "net/http" @@ -9,30 +11,25 @@ import ( "strconv" "strings" "time" - - "golang.org/x/oauth2" ) -type Token struct { - // AccessToken is the token that authorizes and authenticates - // the requests. - AccessToken string `json:"access_token"` - - // TokenType is the type of token. - // The Type method returns either this or "Bearer", the default. - TokenType string `json:"token_type,omitempty"` +var timeNow = time.Now - // RefreshToken is a token that's used by the application - // (as opposed to the user) to refresh the access token - // if it expires. - RefreshToken string `json:"refresh_token,omitempty"` +var ( + ErrTypeExchangeFailed = "exchange_failed" + ErrTypeRefreshFailed = "refresh_failed" + ErrTypeRevokeFailed = "revoke_failed" + ErrTypeIntrospectionFailed = "introspection_failed" + ErrTypeInvalidGrant = "invalid_grant" + ErrTypeNotAuthorized = "not_authorized" +) - // Expiry is the optional expiration time of the access token. - // - // If zero, TokenSource implementations will reuse the same - // token forever and RefreshToken or equivalent - // mechanisms for that TokenSource will not be used. - Expiry time.Time `json:"expiry,omitempty"` +type Token struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + Expiry time.Time `json:"expiry,omitempty"` } type AuthorizeError struct { @@ -51,7 +48,7 @@ type AuthorizeResponse struct { } func (c *client) Authorize(callback func(error)) (*AuthorizeResponse, error) { - pkce := oauth2.GenerateVerifier() + verifier := generateVerifier() mux := http.NewServeMux() server := &http.Server{ Handler: mux, @@ -59,7 +56,7 @@ func (c *client) Authorize(callback func(error)) (*AuthorizeResponse, error) { callbackURL := "" mux.HandleFunc("/callback", func(w http.ResponseWriter, req *http.Request) { req.ParseForm() - token, err := c.exchange(req.Form, pkce, callbackURL) + token, err := c.exchange(req.Form, verifier, callbackURL) if err != nil { http.Redirect(w, req, c.dashboardURL+"/authorize/error", http.StatusFound) } else { @@ -68,10 +65,7 @@ func (c *client) Authorize(callback func(error)) (*AuthorizeResponse, error) { go func() { server.Shutdown(req.Context()) if err == nil { - c.token.Store(token) - if c.onTokenRefresh != nil { - c.onTokenRefresh(mapToken(token)) - } + c.updateToken(token) } callback(err) }() @@ -97,26 +91,35 @@ func (c *client) Authorize(callback func(error)) (*AuthorizeResponse, error) { } }() callbackURL = "http://localhost:" + port + "/callback" + q := url.Values{} + q.Set("client_id", c.authClientId) + q.Set("code_challenge", generateS256Challenge(verifier)) + q.Set("code_challenge_method", "S256") + q.Set("response_type", "code") + q.Set("scope", "measurements") + return &AuthorizeResponse{ - AuthorizeURL: c.oauth2.AuthCodeURL("", oauth2.S256ChallengeOption(pkce)), + AuthorizeURL: c.authURL + "/oauth/authorize?" + q.Encode(), CallbackURL: callbackURL, }, nil } func (c *client) TokenIntrospection(token string) (*IntrospectionResponse, error) { if token == "" { - var err error - token, _, err = c.accessToken() + t, err := c.getToken() if err != nil { return nil, &AuthorizeError{ - ErrorType: "not_authorized", + ErrorType: ErrTypeNotAuthorized, Description: err.Error(), } } + if t != nil { + token = t.AccessToken + } } if token == "" { return nil, &AuthorizeError{ - ErrorType: "not_authorized", + ErrorType: ErrTypeNotAuthorized, Description: "client is not authorized", } } @@ -124,7 +127,9 @@ func (c *client) TokenIntrospection(token string) (*IntrospectionResponse, error } func (c *client) Logout() error { - t := c.token.Load() + c.mu.RLock() + t := c.token + c.mu.RUnlock() if t == nil { return nil } @@ -132,17 +137,11 @@ func (c *client) Logout() error { if err != nil { return err } - c.mu.Lock() - defer c.mu.Unlock() - c.tokenSource = nil - c.token.Store(nil) - if c.onTokenRefresh != nil { - c.onTokenRefresh(nil) - } + c.updateToken(nil) return nil } -func (c *client) exchange(form url.Values, pkce string, redirect string) (*oauth2.Token, error) { +func (c *client) exchange(form url.Values, verifier string, redirect string) (*Token, error) { if form.Get("error") != "" { return nil, &AuthorizeError{ ErrorType: form.Get("error"), @@ -156,36 +155,191 @@ func (c *client) exchange(form url.Values, pkce string, redirect string) (*oauth Description: "missing code in response", } } - return c.oauth2.Exchange( - context.Background(), - code, - oauth2.VerifierOption(pkce), - oauth2.SetAuthURLParam("redirect_uri", redirect), - ) + q := url.Values{} + q.Set("client_id", c.authClientId) + q.Set("client_secret", c.authClientSecret) + q.Set("code", code) + q.Set("code_verifier", verifier) + q.Set("grant_type", "authorization_code") + q.Set("redirect_uri", redirect) + req, err := http.NewRequest("POST", c.authURL+"/oauth/token", strings.NewReader(q.Encode())) + if err != nil { + return nil, &AuthorizeError{ + ErrorType: ErrTypeExchangeFailed, + Description: err.Error(), + } + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Content-Length", strconv.Itoa(len(q.Encode()))) + resp, err := c.http.Do(req) + if err != nil { + return nil, &AuthorizeError{ + ErrorType: ErrTypeExchangeFailed, + Description: err.Error(), + } + } + if resp.StatusCode != http.StatusOK { + err := &AuthorizeError{ + Code: resp.StatusCode, + ErrorType: ErrTypeExchangeFailed, + Description: resp.Status, + } + json.NewDecoder(resp.Body).Decode(err) + return nil, err + } + t := &Token{} + err = json.NewDecoder(resp.Body).Decode(t) + if err != nil { + return nil, &AuthorizeError{ + ErrorType: ErrTypeExchangeFailed, + Description: err.Error(), + } + } + if t.TokenType == "" { + t.TokenType = "Bearer" + } + if t.ExpiresIn != 0 { + t.Expiry = timeNow().Add(time.Duration(t.ExpiresIn) * time.Second) + } + return t, nil } -func (c *client) accessToken() (string, string, error) { +func (c *client) getToken() (*Token, error) { c.mu.RLock() defer c.mu.RUnlock() - if c.tokenSource == nil { - return "", "", nil + if c.token == nil { + return nil, nil + } + if !c.token.Expiry.Before(timeNow()) { + return c.token, nil + } + if c.token.RefreshToken == "" { + return nil, &AuthorizeError{ + ErrorType: "refresh_failed", + Description: "empty refresh token", + } } - token, err := c.tokenSource.Token() + t, err := c.refreshToken(c.token.RefreshToken) if err != nil { - e, ok := err.(*oauth2.RetrieveError) - if ok && e.ErrorCode == "invalid_grant" && c.onTokenRefresh != nil { + e, ok := err.(*AuthorizeError) + if ok && e.ErrorType == ErrTypeInvalidGrant && c.onTokenRefresh != nil { + c.onTokenRefresh(nil) + } + return nil, err + } + c.token = t + if c.onTokenRefresh != nil { + c.onTokenRefresh(&Token{ + AccessToken: t.AccessToken, + TokenType: t.TokenType, + RefreshToken: t.RefreshToken, + ExpiresIn: t.ExpiresIn, + Expiry: t.Expiry, + }) + } + return t, nil +} + +func (c *client) updateToken(t *Token) { + c.mu.Lock() + defer c.mu.Unlock() + c.token = t + if c.onTokenRefresh != nil { + if t == nil { c.onTokenRefresh(nil) + } else { + c.onTokenRefresh(&Token{ + AccessToken: t.AccessToken, + TokenType: t.TokenType, + RefreshToken: t.RefreshToken, + ExpiresIn: t.ExpiresIn, + Expiry: t.Expiry, + }) + } + } +} + +func (c *client) tryToRefreshToken(refreshToken string) bool { + c.mu.Lock() + defer c.mu.Unlock() + if c.token == nil { + return false + } + // must have been called by a different goroutine + if c.token.RefreshToken != refreshToken { + return false + } + token, err := c.refreshToken(c.token.RefreshToken) + if err != nil { + e, ok := err.(*AuthorizeError) + // If the refresh token is invalid, clear the token + if ok && e.ErrorType == ErrTypeInvalidGrant && c.onTokenRefresh != nil { + c.token = nil + if c.onTokenRefresh != nil { + c.onTokenRefresh(nil) + } + } + return false + } + c.token = token + if c.onTokenRefresh != nil { + c.onTokenRefresh(&Token{ + AccessToken: token.AccessToken, + TokenType: token.TokenType, + RefreshToken: token.RefreshToken, + ExpiresIn: token.ExpiresIn, + Expiry: token.Expiry, + }) + } + return true +} + +func (c *client) refreshToken(token string) (*Token, error) { + q := url.Values{} + q.Set("client_id", c.authClientId) + q.Set("client_secret", c.authClientSecret) + q.Set("refresh_token", token) + q.Set("grant_type", "refresh_token") + req, err := http.NewRequest("POST", c.authURL+"/oauth/token", strings.NewReader(q.Encode())) + if err != nil { + return nil, &AuthorizeError{ + ErrorType: ErrTypeRefreshFailed, + Description: err.Error(), + } + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Content-Length", strconv.Itoa(len(q.Encode()))) + resp, err := c.http.Do(req) + if err != nil { + return nil, &AuthorizeError{ + ErrorType: ErrTypeRefreshFailed, + Description: err.Error(), + } + } + if resp.StatusCode != http.StatusOK { + err := &AuthorizeError{ + Code: resp.StatusCode, + ErrorType: ErrTypeRefreshFailed, + Description: resp.Status, } - return "", "", err + json.NewDecoder(resp.Body).Decode(err) + return nil, err } - curr := c.token.Load() - if curr != nil && token.AccessToken != curr.AccessToken { - c.token.Store(token) - if c.onTokenRefresh != nil { - c.onTokenRefresh(mapToken(token)) + t := &Token{} + err = json.NewDecoder(resp.Body).Decode(t) + if err != nil { + return nil, &AuthorizeError{ + ErrorType: ErrTypeRefreshFailed, + Description: err.Error(), } } - return token.AccessToken, token.Type(), nil + if t.TokenType == "" { + t.TokenType = "Bearer" + } + if t.ExpiresIn != 0 { + t.Expiry = timeNow().Add(time.Duration(t.ExpiresIn) * time.Second) + } + return t, nil } // https://datatracker.ietf.org/doc/html/rfc7662#section-2.1 @@ -212,7 +366,7 @@ func (c *client) introspection(token string) (*IntrospectionResponse, error) { req, err := http.NewRequest("POST", c.authURL+"/oauth/token/introspect", strings.NewReader(form)) if err != nil { return nil, &AuthorizeError{ - ErrorType: "introspection_failed", + ErrorType: ErrTypeIntrospectionFailed, Description: err.Error(), } } @@ -221,14 +375,14 @@ func (c *client) introspection(token string) (*IntrospectionResponse, error) { resp, err := c.http.Do(req) if err != nil { return nil, &AuthorizeError{ - ErrorType: "introspection_failed", + ErrorType: ErrTypeIntrospectionFailed, Description: err.Error(), } } if resp.StatusCode != http.StatusOK { err := &AuthorizeError{ Code: resp.StatusCode, - ErrorType: "introspection_failed", + ErrorType: ErrTypeIntrospectionFailed, Description: resp.Status, } json.NewDecoder(resp.Body).Decode(err) @@ -238,7 +392,7 @@ func (c *client) introspection(token string) (*IntrospectionResponse, error) { err = json.NewDecoder(resp.Body).Decode(ires) if err != nil { return nil, &AuthorizeError{ - ErrorType: "introspection_failed", + ErrorType: ErrTypeIntrospectionFailed, Description: err.Error(), } } @@ -253,7 +407,7 @@ func (c *client) RevokeToken(token string) error { req, err := http.NewRequest("POST", c.authURL+"/oauth/token/revoke", strings.NewReader(form)) if err != nil { return &AuthorizeError{ - ErrorType: "revoke_failed", + ErrorType: ErrTypeRevokeFailed, Description: err.Error(), } } @@ -262,14 +416,14 @@ func (c *client) RevokeToken(token string) error { resp, err := c.http.Do(req) if err != nil { return &AuthorizeError{ - ErrorType: "revoke_failed", + ErrorType: ErrTypeRevokeFailed, Description: err.Error(), } } if resp.StatusCode != http.StatusOK { err := &AuthorizeError{ Code: resp.StatusCode, - ErrorType: "revoke_failed", + ErrorType: ErrTypeRevokeFailed, Description: resp.Status, } json.NewDecoder(resp.Body).Decode(err) @@ -278,11 +432,15 @@ func (c *client) RevokeToken(token string) error { return nil } -func mapToken(t *oauth2.Token) *Token { - return &Token{ - AccessToken: t.AccessToken, - TokenType: t.TokenType, - RefreshToken: t.RefreshToken, - Expiry: t.Expiry, +func generateVerifier() string { + data := make([]byte, 32) + if _, err := rand.Read(data); err != nil { + panic(err) } + return base64.RawURLEncoding.EncodeToString(data) +} + +func generateS256Challenge(verifier string) string { + sha := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(sha[:]) } diff --git a/globalping/auth_test.go b/globalping/auth_test.go index ddef713..8f0c9f0 100644 --- a/globalping/auth_test.go +++ b/globalping/auth_test.go @@ -40,9 +40,7 @@ func Test_Authorize(t *testing.T) { w.Header().Set("Content-Type", "application/json") _, err = w.Write(getTokenJSON()) - if err != nil { - t.Fatal(err) - } + assert.Nil(t, err) return } t.Fatalf("unexpected request to %s", r.URL.Path) @@ -58,6 +56,7 @@ func Test_Authorize(t *testing.T) { AccessToken: "token", TokenType: "bearer", RefreshToken: "refresh", + ExpiresIn: 3600, Expiry: _token.Expiry, }, _token) }, @@ -217,6 +216,7 @@ func Test_TokenIntrospection_Token_Refreshed(t *testing.T) { AccessToken: "new_token", TokenType: "bearer", RefreshToken: "new_refresh_token", + ExpiresIn: 3600, Expiry: token.Expiry, }, token) } @@ -275,6 +275,16 @@ func Test_TokenIntrospection_With_Token(t *testing.T) { assert.False(t, onTokenRefreshCalled) } +func Test_TokenIntrospection_No_Token(t *testing.T) { + client := NewClient(Config{}) + res, err := client.TokenIntrospection("") + assert.Nil(t, res) + e, ok := err.(*AuthorizeError) + assert.True(t, ok) + assert.Equal(t, ErrTypeNotAuthorized, e.ErrorType) + assert.Equal(t, "client is not authorized", e.Description) +} + func Test_Logout(t *testing.T) { isCalled := false now := time.Now() @@ -312,9 +322,7 @@ func Test_Logout(t *testing.T) { }, }) err := client.Logout() - if err != nil { - t.Fatal(err) - } + assert.Nil(t, err) assert.True(t, isCalled) assert.True(t, onTokenRefreshCalled) } @@ -370,9 +378,7 @@ func Test_Logout_No_RefreshToken(t *testing.T) { }, }) err := client.Logout() - if err != nil { - t.Fatal(err) - } + assert.Nil(t, err) assert.True(t, onTokenRefreshCalled) } @@ -388,9 +394,12 @@ func Test_Logout_AccessToken_Is_Set(t *testing.T) { AuthClientSecret: "", AuthURL: server.URL, DashboardURL: server.URL, - AuthAccessToken: "tok3n", + AuthToken: &Token{ + AccessToken: "tok3n", + Expiry: time.Now().Add(time.Hour), + }, OnTokenRefresh: func(token *Token) { - onTokenRefreshCalled = true + assert.Nil(t, token) }, }) err := client.Logout() diff --git a/globalping/client.go b/globalping/client.go index 1148074..edc309e 100644 --- a/globalping/client.go +++ b/globalping/client.go @@ -1,13 +1,9 @@ package globalping import ( - "context" "net/http" "sync" - "sync/atomic" "time" - - "golang.org/x/oauth2" ) type Client interface { @@ -50,7 +46,6 @@ type Config struct { AuthURL string AuthClientID string AuthClientSecret string - AuthAccessToken string // If set, this token will be used for API requests AuthToken *Token OnTokenRefresh func(*Token) @@ -68,10 +63,10 @@ type client struct { http *http.Client cache map[string]*CacheEntry - oauth2 *oauth2.Config - token atomic.Pointer[oauth2.Token] - tokenSource oauth2.TokenSource - onTokenRefresh func(*Token) + authClientId string + authClientSecret string + token *Token + onTokenRefresh func(*Token) apiURL string authURL string @@ -85,23 +80,15 @@ type client struct { // If you want a cache cleanup goroutine, use NewClientWithCacheCleanup. func NewClient(config Config) Client { c := &client{ - mu: sync.RWMutex{}, - oauth2: &oauth2.Config{ - ClientID: config.AuthClientID, - ClientSecret: config.AuthClientSecret, - Scopes: []string{"measurements"}, - Endpoint: oauth2.Endpoint{ - AuthURL: config.AuthURL + "/oauth/authorize", - TokenURL: config.AuthURL + "/oauth/token", - AuthStyle: oauth2.AuthStyleInParams, - }, - }, - onTokenRefresh: config.OnTokenRefresh, - apiURL: config.APIURL, - authURL: config.AuthURL, - dashboardURL: config.DashboardURL, - userAgent: config.UserAgent, - cache: map[string]*CacheEntry{}, + mu: sync.RWMutex{}, + authClientId: config.AuthClientID, + authClientSecret: config.AuthClientSecret, + onTokenRefresh: config.OnTokenRefresh, + apiURL: config.APIURL, + authURL: config.AuthURL, + dashboardURL: config.DashboardURL, + userAgent: config.UserAgent, + cache: map[string]*CacheEntry{}, } if config.HTTPClient != nil { c.http = config.HTTPClient @@ -110,21 +97,16 @@ func NewClient(config Config) Client { Timeout: 30 * time.Second, } } - if config.AuthAccessToken != "" { - c.tokenSource = oauth2.StaticTokenSource(&oauth2.Token{AccessToken: config.AuthAccessToken}) - } else if config.AuthToken != nil { - t := &oauth2.Token{ + if config.AuthToken != nil { + c.token = &Token{ AccessToken: config.AuthToken.AccessToken, TokenType: config.AuthToken.TokenType, RefreshToken: config.AuthToken.RefreshToken, + ExpiresIn: config.AuthToken.ExpiresIn, Expiry: config.AuthToken.Expiry, } - c.token.Store(t) - if config.AuthToken.RefreshToken == "" { - c.tokenSource = oauth2.StaticTokenSource(&oauth2.Token{AccessToken: config.AuthToken.AccessToken}) - } else { - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, c.http) - c.tokenSource = c.oauth2.TokenSource(ctx, t) + if c.token.TokenType == "" { + c.token.TokenType = "Bearer" } } return c diff --git a/globalping/limits.go b/globalping/limits.go index e107c67..f89dbaa 100644 --- a/globalping/limits.go +++ b/globalping/limits.go @@ -56,12 +56,12 @@ func (c *client) Limits() (*LimitsResponse, error) { if err != nil { return nil, &LimitsError{Message: "failed to create request - please report this bug"} } - token, tokenType, err := c.accessToken() + token, err := c.getToken() if err != nil { return nil, &LimitsError{Message: "failed to get token: " + err.Error()} } - if token != "" { - req.Header.Set("Authorization", tokenType+" "+token) + if token != nil { + req.Header.Set("Authorization", token.TokenType+" "+token.AccessToken) } resp, err := c.http.Do(req) if err != nil { diff --git a/globalping/limits_test.go b/globalping/limits_test.go index 1d0539f..6f4896b 100644 --- a/globalping/limits_test.go +++ b/globalping/limits_test.go @@ -49,6 +49,7 @@ func Test_Limits(t *testing.T) { APIURL: server.URL, AuthToken: &Token{ AccessToken: "tok3n", + TokenType: "Bearer", Expiry: time.Now().Add(time.Hour), }, }) diff --git a/globalping/measurements.go b/globalping/measurements.go index 8471dac..e43f7dc 100644 --- a/globalping/measurements.go +++ b/globalping/measurements.go @@ -17,6 +17,8 @@ var ( moreCreditsRequiredAuthErr = "You only have %s remaining, and %d were required. Try requesting fewer probes or wait %s for the rate limit to reset. You can get higher limits by sponsoring us or hosting probes." noCreditsNoAuthErr = "You have run out of credits for this session. You can wait %s for the rate limit to reset or get higher limits by creating an account. Sign up at https://globalping.io" noCreditsAuthErr = "You have run out of credits for this session. You can wait %s for the rate limit to reset or get higher limits by sponsoring us or hosting probes." + invalidRefreshTokenErr = "You have been signed out by the API. Please try signing in again." + invalidTokenErr = "Your access token has been rejected by the API. Try signing in with a new token." ) var ( @@ -37,12 +39,12 @@ func (c *client) CreateMeasurement(measurement *MeasurementCreate) (*Measurement req.Header.Set("Accept-Encoding", "br") req.Header.Set("Content-Type", "application/json") - token, tokenType, err := c.accessToken() + token, err := c.getToken() if err != nil { return nil, &MeasurementError{Message: "failed to get token: " + err.Error()} } - if token != "" { - req.Header.Set("Authorization", tokenType+" "+token) + if token != nil { + req.Header.Set("Authorization", token.TokenType+" "+token.AccessToken) } resp, err := c.http.Do(req) @@ -74,11 +76,19 @@ func (c *client) CreateMeasurement(measurement *MeasurementCreate) (*Measurement } if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { - token, _, e := c.accessToken() - if e == nil && token != "" { - err.Code = StatusUnauthorizedWithTokenRefreshed + if token != nil { + if token.RefreshToken == "" { + err.Message = invalidTokenErr + return nil, err + } + if c.tryToRefreshToken(token.RefreshToken) { + err.Code = StatusUnauthorizedWithTokenRefreshed + return nil, err + } + err.Message = invalidRefreshTokenErr + return nil, err } - err.Message = "unauthorized: " + data.Error.Message + err.Message = data.Error.Message return nil, err } @@ -93,7 +103,7 @@ func (c *client) CreateMeasurement(measurement *MeasurementCreate) (*Measurement creditsRemaining, _ := strconv.ParseInt(resp.Header.Get("X-Credits-Remaining"), 10, 64) requestCost, _ := strconv.ParseInt(resp.Header.Get("X-Request-Cost"), 10, 64) remaining := rateLimitRemaining + creditsRemaining - if token == "" { + if token == nil { if remaining > 0 { err.Message = fmt.Sprintf(moreCreditsRequiredNoAuthErr, utils.Pluralize(remaining, "credit"), requestCost, utils.FormatSeconds(rateLimitReset)) return nil, err diff --git a/globalping/measurements_test.go b/globalping/measurements_test.go index e16f6c1..0b855ef 100644 --- a/globalping/measurements_test.go +++ b/globalping/measurements_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/andybalholm/brotli" @@ -30,8 +31,11 @@ func Test_CreateMeasurement_Authorized(t *testing.T) { server := generateServerAuthorized(`{"id":"abcd","probesCount":1}`) defer server.Close() client := NewClient(Config{ - AuthAccessToken: "secret", - APIURL: server.URL, + AuthToken: &Token{ + AccessToken: "secret", + Expiry: time.Now().Add(1 * time.Hour), + }, + APIURL: server.URL, }) opts := &MeasurementCreate{} @@ -53,7 +57,223 @@ func Test_CreateMeasurement_AuthorizedError(t *testing.T) { res, err := client.CreateMeasurement(opts) assert.Nil(t, res) - assert.EqualError(t, err, "unauthorized: Unauthorized.") + assert.EqualError(t, err, "Unauthorized.") +} + +func Test_CreateMeasurement_TokenRefreshed(t *testing.T) { + now := time.Now() + timeNow = func() time.Time { + return now + } + defer func() { + timeNow = time.Now + }() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/oauth/token" { + if r.Method != http.MethodPost { + t.Fatalf("expected POST request, got %s", r.Method) + } + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "", r.Form.Get("client_id")) + assert.Equal(t, "", r.Form.Get("client_secret")) + assert.Equal(t, "refresh_token", r.Form.Get("grant_type")) + assert.Equal(t, "refresh_tok3n", r.Form.Get("refresh_token")) + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write([]byte(`{"access_token":"new_token","token_type":"Bearer","refresh_token":"new_refresh_token","expires_in":3600}`)) + if err != nil { + t.Fatal(err) + } + return + } + if r.URL.Path == "/measurements" { + if r.Method != http.MethodPost { + t.Fatalf("expected POST request, got %s", r.Method) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + _, err := w.Write([]byte(`{"id":"abcd","probesCount":1}`)) + if err != nil { + t.Fatal(err) + } + return + } + t.Fatalf("unexpected request to %s", r.URL.Path) + })) + defer server.Close() + client := NewClient(Config{ + APIURL: server.URL, + AuthURL: server.URL, + AuthClientID: "", + AuthClientSecret: "", + AuthToken: &Token{ + AccessToken: "access_token", + RefreshToken: "refresh_tok3n", + Expiry: time.Now().Add(-1 * time.Hour), + }, + OnTokenRefresh: func(_t *Token) { + assert.Equal(t, &Token{ + AccessToken: "new_token", + TokenType: "Bearer", + RefreshToken: "new_refresh_token", + ExpiresIn: 3600, + Expiry: now.Add(3600 * time.Second), + }, _t) + }, + }) + + opts := &MeasurementCreate{} + res, err := client.CreateMeasurement(opts) + assert.Nil(t, err) + assert.Equal(t, "abcd", res.ID) +} + +func Test_CreateMeasurement_Unauthorized_TokenRefreshed(t *testing.T) { + now := time.Now() + timeNow = func() time.Time { + return now + } + defer func() { + timeNow = time.Now + }() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/oauth/token" { + if r.Method != http.MethodPost { + t.Fatalf("expected POST request, got %s", r.Method) + } + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "", r.Form.Get("client_id")) + assert.Equal(t, "", r.Form.Get("client_secret")) + assert.Equal(t, "refresh_token", r.Form.Get("grant_type")) + assert.Equal(t, "refresh_tok3n", r.Form.Get("refresh_token")) + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write([]byte(`{"access_token":"new_token","token_type":"Bearer","refresh_token":"new_refresh_token","expires_in":3600}`)) + if err != nil { + t.Fatal(err) + } + return + } + if r.URL.Path == "/measurements" { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": {"type": "unauthorized", "message": "Unauthorized."}}`)) + return + } + t.Fatalf("unexpected request to %s", r.URL.Path) + })) + defer server.Close() + client := NewClient(Config{ + APIURL: server.URL, + AuthURL: server.URL, + AuthClientID: "", + AuthClientSecret: "", + AuthToken: &Token{ + AccessToken: "access_token", + RefreshToken: "refresh_tok3n", + Expiry: time.Now().Add(1 * time.Hour), + }, + OnTokenRefresh: func(_t *Token) { + assert.Equal(t, &Token{ + AccessToken: "new_token", + TokenType: "Bearer", + RefreshToken: "new_refresh_token", + ExpiresIn: 3600, + Expiry: now.Add(3600 * time.Second), + }, _t) + }, + }) + + opts := &MeasurementCreate{} + res, err := client.CreateMeasurement(opts) + assert.Nil(t, res) + e, ok := err.(*MeasurementError) + assert.True(t, ok) + assert.Equal(t, StatusUnauthorizedWithTokenRefreshed, e.Code) + assert.Equal(t, "Unauthorized.", e.Message) +} + +func Test_CreateMeasurement_Unauthorized_Token_Not_Refreshed(t *testing.T) { + now := time.Now() + timeNow = func() time.Time { + return now + } + defer func() { + timeNow = time.Now + }() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/oauth/token" { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": "invalid_grant", "error_description": "Invalid refresh token."}`)) + return + } + if r.URL.Path == "/measurements" { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": {"type": "unauthorized", "message": "Unauthorized."}}`)) + return + } + t.Fatalf("unexpected request to %s", r.URL.Path) + })) + defer server.Close() + isOnTokenRefreshCalled := false + client := NewClient(Config{ + APIURL: server.URL, + AuthURL: server.URL, + AuthClientID: "", + AuthClientSecret: "", + AuthToken: &Token{ + AccessToken: "access_token", + RefreshToken: "refresh_tok3n", + Expiry: time.Now().Add(1 * time.Hour), + }, + OnTokenRefresh: func(_t *Token) { + isOnTokenRefreshCalled = true + assert.Nil(t, _t) + }, + }) + + opts := &MeasurementCreate{} + res, err := client.CreateMeasurement(opts) + assert.Nil(t, res) + assert.EqualError(t, err, "You have been signed out by the API. Please try signing in again.") + assert.True(t, isOnTokenRefreshCalled) +} +func Test_CreateMeasurement_Unauthorized_NoRefreshToken(t *testing.T) { + now := time.Now() + timeNow = func() time.Time { + return now + } + defer func() { + timeNow = time.Now + }() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": {"type": "unauthorized", "message": "Unauthorized."}}`)) + })) + defer server.Close() + client := NewClient(Config{ + APIURL: server.URL, + AuthURL: server.URL, + AuthClientID: "", + AuthClientSecret: "", + AuthToken: &Token{ + AccessToken: "access_token", + Expiry: time.Now().Add(1 * time.Hour), + }, + OnTokenRefresh: func(_t *Token) { + t.Fatal("should not be called") + }, + }) + + opts := &MeasurementCreate{} + res, err := client.CreateMeasurement(opts) + assert.Nil(t, res) + assert.EqualError(t, err, invalidTokenErr) } func Test_CreateMeasurement_MoreCreditsRequiredNoAuthError(t *testing.T) { @@ -105,8 +325,11 @@ func Test_CreateMeasurement_MoreCreditsRequiredAuthError(t *testing.T) { defer server.Close() client := NewClient(Config{ - AuthAccessToken: "secret", - APIURL: server.URL, + AuthToken: &Token{ + AccessToken: "secret", + Expiry: time.Now().Add(1 * time.Hour), + }, + APIURL: server.URL, }) opts := &MeasurementCreate{} @@ -160,8 +383,11 @@ func Test_CreateMeasurement_NoCreditsAuthError(t *testing.T) { defer server.Close() client := NewClient(Config{ - AuthAccessToken: "secret", - APIURL: server.URL, + AuthToken: &Token{ + AccessToken: "secret", + Expiry: time.Now().Add(1 * time.Hour), + }, + APIURL: server.URL, }) opts := &MeasurementCreate{} _, err := client.CreateMeasurement(opts) diff --git a/go.mod b/go.mod index 4741805..16313e9 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,6 @@ require ( github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 go.uber.org/mock v0.4.0 - golang.org/x/oauth2 v0.23.0 golang.org/x/term v0.18.0 ) diff --git a/go.sum b/go.sum index f1416ac..9f59a1a 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,6 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/icza/backscanner v0.0.0-20240221180818-f23e3ba0e79f h1:EKPpaKkARuHjoV/ZKzk3vqbSJXULRSivDCQhL+tF77Y= github.com/icza/backscanner v0.0.0-20240221180818-f23e3ba0e79f/go.mod h1:GYeBD1CF7AqnKZK+UCytLcY3G+UKo0ByXX/3xfdNyqQ= github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k= @@ -51,8 +49,6 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= -golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= -golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=