From 4547401b425a4c2af2994780c3b08ba07f386ac9 Mon Sep 17 00:00:00 2001 From: nenaraab Date: Fri, 10 Dec 2021 10:30:46 +0100 Subject: [PATCH] Introduce cache for tokens requested from identity service (#46) --- auth/middleware_test.go | 2 +- tokenclient/tokenFlows.go | 87 +++++++++++++++++++++++++++++----- tokenclient/tokenFlows_test.go | 58 ++++++++++++++++++----- 3 files changed, 122 insertions(+), 25 deletions(-) diff --git a/auth/middleware_test.go b/auth/middleware_test.go index da7baeb..f76a292 100644 --- a/auth/middleware_test.go +++ b/auth/middleware_test.go @@ -85,7 +85,7 @@ func TestEnd2End(t *testing.T) { name: "before validity", header: oidcMockServer.DefaultHeaders(), claims: mocks.NewOIDCClaimsBuilder(oidcMockServer.DefaultClaims()). - NotBefore(time.Now().Add(1 * time.Minute)). + NotBefore(time.Now().Add(2 * time.Minute)). Build(), wantErr: true, }, { diff --git a/tokenclient/tokenFlows.go b/tokenclient/tokenFlows.go index 7d5ff4c..9d21f8f 100644 --- a/tokenclient/tokenFlows.go +++ b/tokenclient/tokenFlows.go @@ -7,12 +7,15 @@ import ( "context" "encoding/json" "fmt" + "github.com/patrickmn/go-cache" "github.com/sap/cloud-security-client-go/env" "github.com/sap/cloud-security-client-go/httpclient" - "io/ioutil" + "io" + "log" "net/http" "net/url" "strings" + "time" ) // Options allows configuration http(s) client @@ -31,6 +34,27 @@ type TokenFlows struct { identity env.Identity Options Options tokenURI string + cache *cache.Cache +} + +type request struct { + http.Request + key string +} + +func (r *request) cacheKey() (string, error) { + if r.key == "" { + bodyReader, err := r.GetBody() + if err != nil { + return "", fmt.Errorf("unexpected error, can't read request body: %w", err) + } + params, err := io.ReadAll(bodyReader) + if err != nil { + return "", fmt.Errorf("unexpected error, can't read request body: %w", err) + } + r.key = fmt.Sprintf("%v?%v", r.URL, string(params)) + } + return r.key, nil } // RequestFailedError represents a HTTP server error @@ -68,6 +92,7 @@ func NewTokenFlows(identity env.Identity, options Options) (*TokenFlows, error) identity: identity, tokenURI: identity.GetURL() + tokenEndpoint, Options: options, + cache: cache.New(15*time.Minute, 10*time.Minute), //nolint:gomnd } if options.HTTPClient == nil { tlsConfig, err := httpclient.DefaultTLSConfig(identity) @@ -86,7 +111,7 @@ func NewTokenFlows(identity env.Identity, options Options) (*TokenFlows, error) // // ctx carries the request context like the deadline or other values that should be shared across API boundaries. // customerTenantURL like "https://custom.accounts400.ondemand.com" gives the host of the customers ias tenant -// Options allows to provide a request context and optionally additional request parameters +// options allows to provide additional request parameters func (t *TokenFlows) ClientCredentials(ctx context.Context, customerTenantURL string, options RequestOptions) (string, error) { data := url.Values{} data.Set(clientIDParameter, t.identity.GetClientID()) @@ -107,14 +132,7 @@ func (t *TokenFlows) ClientCredentials(ctx context.Context, customerTenantURL st } r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - var tokenRes tokenResponse - err = t.performRequest(r, &tokenRes) - if err != nil { - return "", err - } else if tokenRes.Token == "" { - return "", fmt.Errorf("error parsing requested client credentials token: no 'access_token' property provided") - } - return tokenRes.Token, nil + return t.getOrRequestToken(request{Request: *r}) } func (t *TokenFlows) getURL(customerTenantURL string) (string, error) { @@ -128,14 +146,57 @@ func (t *TokenFlows) getURL(customerTenantURL string) (string, error) { return "", fmt.Errorf("customer tenant url '%v' can't be parsed: %w", customerTenantURL, err) } -func (t *TokenFlows) performRequest(r *http.Request, v interface{}) error { - res, err := t.Options.HTTPClient.Do(r) +func (t *TokenFlows) getOrRequestToken(r request) (string, error) { + // token cached? + cachedToken := t.readFromCache(&r) + if cachedToken != "" { + return cachedToken, nil + } + + // request token + var tokenRes tokenResponse + err := t.performRequest(r, &tokenRes) + if err != nil { + return "", err + } + if tokenRes.Token == "" { + return "", fmt.Errorf("error parsing requested client credentials token: no 'access_token' property provided") + } + + // cache and return retrieved token + t.writeToCache(r, tokenRes.Token) + return tokenRes.Token, err +} + +func (t *TokenFlows) readFromCache(r *request) string { + cacheKey, err := r.cacheKey() + if err != nil { + return "" + } + cachedEncodedToken, found := t.cache.Get(cacheKey) + if !found { + return "" + } + return fmt.Sprintf("%v", cachedEncodedToken) +} + +func (t *TokenFlows) writeToCache(r request, token string) { + cacheKey, err := r.cacheKey() + if err != nil { + log.Fatalf("Write to Cache is skipped. Unexpected error to determine cache key: %s", err.Error()) + return + } + t.cache.SetDefault(cacheKey, token) +} + +func (t *TokenFlows) performRequest(r request, v interface{}) error { + res, err := t.Options.HTTPClient.Do(&r.Request) if err != nil { return fmt.Errorf("request to '%v' failed: %w", r.URL, err) } defer res.Body.Close() if res.StatusCode != http.StatusOK { - body, _ := ioutil.ReadAll(res.Body) + body, _ := io.ReadAll(res.Body) return &RequestFailedError{res.StatusCode, *r.URL, string(body)} } if err = json.NewDecoder(res.Body).Decode(v); err != nil { diff --git a/tokenclient/tokenFlows_test.go b/tokenclient/tokenFlows_test.go index 1f119e6..267fd7a 100644 --- a/tokenclient/tokenFlows_test.go +++ b/tokenclient/tokenFlows_test.go @@ -18,6 +18,9 @@ import ( "time" ) +var tokenRequestHandlerHitCounter int +var dummyToken = "eyJhbGciOiJIUzI1NiJ9.e30.ZRrHA1JJJW8opsbCGfG_HACGpVUMN_a9IV7pAx_Zmeo" //nolint:gosec + var clientSecretConfig = &env.DefaultIdentity{ ClientID: "09932670-9440-445d-be3e-432a97d7e2ef", ClientSecret: "[the_CLIENT.secret:3[/abc", @@ -35,7 +38,7 @@ func TestNewTokenFlows_setupDefaultHttpsClientFails(t *testing.T) { } func TestClientCredentialsTokenFlow_FailsWithTimeout(t *testing.T) { - server := setupNewTLSServer(tokenHandler) + server := setupNewTLSServer(t, tokenHandler) defer server.Close() tokenFlows, _ := NewTokenFlows(mTLSConfig, Options{HTTPClient: server.Client()}) @@ -46,7 +49,7 @@ func TestClientCredentialsTokenFlow_FailsWithTimeout(t *testing.T) { } func TestClientCredentialsTokenFlow_FailsNoData(t *testing.T) { - server := setupNewTLSServer(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("no json")) }) + server := setupNewTLSServer(t, func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("no json")) }) defer server.Close() tokenFlows, _ := NewTokenFlows(mTLSConfig, Options{HTTPClient: server.Client()}) @@ -55,7 +58,7 @@ func TestClientCredentialsTokenFlow_FailsNoData(t *testing.T) { } func TestClientCredentialsTokenFlow_FailsNoJson(t *testing.T) { - server := setupNewTLSServer(func(w http.ResponseWriter, r *http.Request) {}) + server := setupNewTLSServer(t, func(w http.ResponseWriter, r *http.Request) {}) defer server.Close() tokenFlows, _ := NewTokenFlows(mTLSConfig, Options{HTTPClient: server.Client()}) @@ -64,7 +67,7 @@ func TestClientCredentialsTokenFlow_FailsNoJson(t *testing.T) { } func TestClientCredentialsTokenFlow_FailsUnexpectedJson(t *testing.T) { - server := setupNewTLSServer(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("{\"a\":\"b\"}")) }) + server := setupNewTLSServer(t, func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("{\"a\":\"b\"}")) }) defer server.Close() tokenFlows, _ := NewTokenFlows(mTLSConfig, Options{HTTPClient: server.Client()}) @@ -73,9 +76,10 @@ func TestClientCredentialsTokenFlow_FailsUnexpectedJson(t *testing.T) { } func TestClientCredentialsTokenFlow_FailsWithUnauthenticated(t *testing.T) { - server := setupNewTLSServer(func(w http.ResponseWriter, r *http.Request) { + server := setupNewTLSServer(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(401) w.Write([]byte("unauthenticated client")) //nolint:errcheck + tokenRequestHandlerHitCounter++ }) defer server.Close() tokenFlows, _ := NewTokenFlows(mTLSConfig, Options{HTTPClient: server.Client()}) @@ -86,10 +90,15 @@ func TestClientCredentialsTokenFlow_FailsWithUnauthenticated(t *testing.T) { if !errors.As(err, &requestFailed) || requestFailed.StatusCode != 401 { assert.Fail(t, "error not of type ClientError") } + assert.Equal(t, 1, tokenRequestHandlerHitCounter) + assert.Equal(t, 0, tokenFlows.cache.ItemCount()) + + _, _ = tokenFlows.ClientCredentials(context.TODO(), server.URL, RequestOptions{}) + assert.Equal(t, 2, tokenRequestHandlerHitCounter) } func TestClientCredentialsTokenFlow_FailsWithCustomerUrlWithoutScheme(t *testing.T) { - server := setupNewTLSServer(tokenHandler) + server := setupNewTLSServer(t, tokenHandler) defer server.Close() tokenFlows, _ := NewTokenFlows(clientSecretConfig, Options{HTTPClient: server.Client()}) @@ -99,7 +108,7 @@ func TestClientCredentialsTokenFlow_FailsWithCustomerUrlWithoutScheme(t *testing } func TestClientCredentialsTokenFlow_FailsWithInvalidCustomerUrl(t *testing.T) { - server := setupNewTLSServer(tokenHandler) + server := setupNewTLSServer(t, tokenHandler) defer server.Close() tokenFlows, _ := NewTokenFlows(clientSecretConfig, Options{HTTPClient: server.Client()}) @@ -109,12 +118,34 @@ func TestClientCredentialsTokenFlow_FailsWithInvalidCustomerUrl(t *testing.T) { } func TestClientCredentialsTokenFlow_Succeeds(t *testing.T) { - server := setupNewTLSServer(tokenHandler) + server := setupNewTLSServer(t, tokenHandler) tokenFlows, _ := NewTokenFlows(&env.DefaultIdentity{ ClientID: "09932670-9440-445d-be3e-432a97d7e2ef"}, Options{HTTPClient: server.Client()}) token, err := tokenFlows.ClientCredentials(context.TODO(), server.URL, RequestOptions{}) - assertToken(t, "eyJhbGciOiJIUzI1NiJ9.e30.ZRrHA1JJJW8opsbCGfG_HACGpVUMN_a9IV7pAx_Zmeo", token, err) + assertToken(t, dummyToken, token, err) +} + +func TestClientCredentialsTokenFlow_ReadFromCache(t *testing.T) { + server := setupNewTLSServer(t, tokenHandler) + tokenFlows, _ := NewTokenFlows(&env.DefaultIdentity{ + ClientID: "09932670-9440-445d-be3e-432a97d7e2ef"}, Options{HTTPClient: server.Client()}) + + assert.Equal(t, 0, tokenRequestHandlerHitCounter) + assert.Equal(t, 0, tokenFlows.cache.ItemCount()) + + token, err := tokenFlows.ClientCredentials(context.TODO(), server.URL, RequestOptions{}) + assert.Equal(t, 1, tokenRequestHandlerHitCounter) + assert.Equal(t, 1, tokenFlows.cache.ItemCount()) + assertToken(t, dummyToken, token, err) + + token, err = tokenFlows.ClientCredentials(context.TODO(), server.URL, RequestOptions{}) + assert.Equal(t, 1, tokenRequestHandlerHitCounter) + assert.Equal(t, 1, tokenFlows.cache.ItemCount()) + assertToken(t, dummyToken, token, err) + cachedToken, ok := tokenFlows.cache.Get(server.URL + "/oauth2/token?client_id=09932670-9440-445d-be3e-432a97d7e2ef&grant_type=client_credentials") + assert.True(t, ok) + assert.Equal(t, dummyToken, cachedToken) } func TestClientCredentialsTokenFlow_UsingMockServer_Succeeds(t *testing.T) { @@ -127,9 +158,13 @@ func TestClientCredentialsTokenFlow_UsingMockServer_Succeeds(t *testing.T) { assertToken(t, "eyJhbGciOiJIUzI1NiJ9.e30.ZRrHA1JJJW8opsbCGfG_HACGpVUMN_a9IV7pAx_Zmeo", token, err) } -func setupNewTLSServer(f func(http.ResponseWriter, *http.Request)) *httptest.Server { +func setupNewTLSServer(t *testing.T, f func(http.ResponseWriter, *http.Request)) *httptest.Server { r := mux.NewRouter() r.HandleFunc("/oauth2/token", f).Methods(http.MethodPost).Headers("Content-Type", "application/x-www-form-urlencoded") + + t.Cleanup(func() { + tokenRequestHandlerHitCounter = 0 + }) return httptest.NewTLSServer(r) } @@ -140,10 +175,11 @@ func tokenHandler(w http.ResponseWriter, r *http.Request) { newStr := buf.String() if newStr == "client_id=09932670-9440-445d-be3e-432a97d7e2ef&grant_type=client_credentials" { payload, _ := json.Marshal(tokenResponse{ - Token: "eyJhbGciOiJIUzI1NiJ9.e30.ZRrHA1JJJW8opsbCGfG_HACGpVUMN_a9IV7pAx_Zmeo", + Token: dummyToken, }) _, _ = w.Write(payload) } + tokenRequestHandlerHitCounter++ } func assertToken(t assert.TestingT, expectedToken, actualToken string, actualError error) {