Skip to content

Commit

Permalink
Introduce cache for tokens requested from identity service (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
nenaraab authored Dec 10, 2021
1 parent 10c242f commit 4547401
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 25 deletions.
2 changes: 1 addition & 1 deletion auth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestEnd2End(t *testing.T) {
name: "before validity",
header: oidcMockServer.DefaultHeaders(),
claims: mocks.NewOIDCClaimsBuilder(oidcMockServer.DefaultClaims()).
NotBefore(time.Now().Add(1 * time.Minute)).
NotBefore(time.Now().Add(2 * time.Minute)).
Build(),
wantErr: true,
}, {
Expand Down
87 changes: 74 additions & 13 deletions tokenclient/tokenFlows.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/patrickmn/go-cache"
"github.com/sap/cloud-security-client-go/env"
"github.com/sap/cloud-security-client-go/httpclient"
"io/ioutil"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
)

// Options allows configuration http(s) client
Expand All @@ -31,6 +34,27 @@ type TokenFlows struct {
identity env.Identity
Options Options
tokenURI string
cache *cache.Cache
}

type request struct {
http.Request
key string
}

func (r *request) cacheKey() (string, error) {
if r.key == "" {
bodyReader, err := r.GetBody()
if err != nil {
return "", fmt.Errorf("unexpected error, can't read request body: %w", err)
}
params, err := io.ReadAll(bodyReader)
if err != nil {
return "", fmt.Errorf("unexpected error, can't read request body: %w", err)
}
r.key = fmt.Sprintf("%v?%v", r.URL, string(params))
}
return r.key, nil
}

// RequestFailedError represents a HTTP server error
Expand Down Expand Up @@ -68,6 +92,7 @@ func NewTokenFlows(identity env.Identity, options Options) (*TokenFlows, error)
identity: identity,
tokenURI: identity.GetURL() + tokenEndpoint,
Options: options,
cache: cache.New(15*time.Minute, 10*time.Minute), //nolint:gomnd
}
if options.HTTPClient == nil {
tlsConfig, err := httpclient.DefaultTLSConfig(identity)
Expand All @@ -86,7 +111,7 @@ func NewTokenFlows(identity env.Identity, options Options) (*TokenFlows, error)
//
// ctx carries the request context like the deadline or other values that should be shared across API boundaries.
// customerTenantURL like "https://custom.accounts400.ondemand.com" gives the host of the customers ias tenant
// Options allows to provide a request context and optionally additional request parameters
// options allows to provide additional request parameters
func (t *TokenFlows) ClientCredentials(ctx context.Context, customerTenantURL string, options RequestOptions) (string, error) {
data := url.Values{}
data.Set(clientIDParameter, t.identity.GetClientID())
Expand All @@ -107,14 +132,7 @@ func (t *TokenFlows) ClientCredentials(ctx context.Context, customerTenantURL st
}
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")

var tokenRes tokenResponse
err = t.performRequest(r, &tokenRes)
if err != nil {
return "", err
} else if tokenRes.Token == "" {
return "", fmt.Errorf("error parsing requested client credentials token: no 'access_token' property provided")
}
return tokenRes.Token, nil
return t.getOrRequestToken(request{Request: *r})
}

func (t *TokenFlows) getURL(customerTenantURL string) (string, error) {
Expand All @@ -128,14 +146,57 @@ func (t *TokenFlows) getURL(customerTenantURL string) (string, error) {
return "", fmt.Errorf("customer tenant url '%v' can't be parsed: %w", customerTenantURL, err)
}

func (t *TokenFlows) performRequest(r *http.Request, v interface{}) error {
res, err := t.Options.HTTPClient.Do(r)
func (t *TokenFlows) getOrRequestToken(r request) (string, error) {
// token cached?
cachedToken := t.readFromCache(&r)
if cachedToken != "" {
return cachedToken, nil
}

// request token
var tokenRes tokenResponse
err := t.performRequest(r, &tokenRes)
if err != nil {
return "", err
}
if tokenRes.Token == "" {
return "", fmt.Errorf("error parsing requested client credentials token: no 'access_token' property provided")
}

// cache and return retrieved token
t.writeToCache(r, tokenRes.Token)
return tokenRes.Token, err
}

func (t *TokenFlows) readFromCache(r *request) string {
cacheKey, err := r.cacheKey()
if err != nil {
return ""
}
cachedEncodedToken, found := t.cache.Get(cacheKey)
if !found {
return ""
}
return fmt.Sprintf("%v", cachedEncodedToken)
}

func (t *TokenFlows) writeToCache(r request, token string) {
cacheKey, err := r.cacheKey()
if err != nil {
log.Fatalf("Write to Cache is skipped. Unexpected error to determine cache key: %s", err.Error())
return
}
t.cache.SetDefault(cacheKey, token)
}

func (t *TokenFlows) performRequest(r request, v interface{}) error {
res, err := t.Options.HTTPClient.Do(&r.Request)
if err != nil {
return fmt.Errorf("request to '%v' failed: %w", r.URL, err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
body, _ := ioutil.ReadAll(res.Body)
body, _ := io.ReadAll(res.Body)
return &RequestFailedError{res.StatusCode, *r.URL, string(body)}
}
if err = json.NewDecoder(res.Body).Decode(v); err != nil {
Expand Down
58 changes: 47 additions & 11 deletions tokenclient/tokenFlows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import (
"time"
)

var tokenRequestHandlerHitCounter int
var dummyToken = "eyJhbGciOiJIUzI1NiJ9.e30.ZRrHA1JJJW8opsbCGfG_HACGpVUMN_a9IV7pAx_Zmeo" //nolint:gosec

var clientSecretConfig = &env.DefaultIdentity{
ClientID: "09932670-9440-445d-be3e-432a97d7e2ef",
ClientSecret: "[the_CLIENT.secret:3[/abc",
Expand All @@ -35,7 +38,7 @@ func TestNewTokenFlows_setupDefaultHttpsClientFails(t *testing.T) {
}

func TestClientCredentialsTokenFlow_FailsWithTimeout(t *testing.T) {
server := setupNewTLSServer(tokenHandler)
server := setupNewTLSServer(t, tokenHandler)
defer server.Close()
tokenFlows, _ := NewTokenFlows(mTLSConfig, Options{HTTPClient: server.Client()})

Expand All @@ -46,7 +49,7 @@ func TestClientCredentialsTokenFlow_FailsWithTimeout(t *testing.T) {
}

func TestClientCredentialsTokenFlow_FailsNoData(t *testing.T) {
server := setupNewTLSServer(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("no json")) })
server := setupNewTLSServer(t, func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("no json")) })
defer server.Close()
tokenFlows, _ := NewTokenFlows(mTLSConfig, Options{HTTPClient: server.Client()})

Expand All @@ -55,7 +58,7 @@ func TestClientCredentialsTokenFlow_FailsNoData(t *testing.T) {
}

func TestClientCredentialsTokenFlow_FailsNoJson(t *testing.T) {
server := setupNewTLSServer(func(w http.ResponseWriter, r *http.Request) {})
server := setupNewTLSServer(t, func(w http.ResponseWriter, r *http.Request) {})
defer server.Close()
tokenFlows, _ := NewTokenFlows(mTLSConfig, Options{HTTPClient: server.Client()})

Expand All @@ -64,7 +67,7 @@ func TestClientCredentialsTokenFlow_FailsNoJson(t *testing.T) {
}

func TestClientCredentialsTokenFlow_FailsUnexpectedJson(t *testing.T) {
server := setupNewTLSServer(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("{\"a\":\"b\"}")) })
server := setupNewTLSServer(t, func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("{\"a\":\"b\"}")) })
defer server.Close()
tokenFlows, _ := NewTokenFlows(mTLSConfig, Options{HTTPClient: server.Client()})

Expand All @@ -73,9 +76,10 @@ func TestClientCredentialsTokenFlow_FailsUnexpectedJson(t *testing.T) {
}

func TestClientCredentialsTokenFlow_FailsWithUnauthenticated(t *testing.T) {
server := setupNewTLSServer(func(w http.ResponseWriter, r *http.Request) {
server := setupNewTLSServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte("unauthenticated client")) //nolint:errcheck
tokenRequestHandlerHitCounter++
})
defer server.Close()
tokenFlows, _ := NewTokenFlows(mTLSConfig, Options{HTTPClient: server.Client()})
Expand All @@ -86,10 +90,15 @@ func TestClientCredentialsTokenFlow_FailsWithUnauthenticated(t *testing.T) {
if !errors.As(err, &requestFailed) || requestFailed.StatusCode != 401 {
assert.Fail(t, "error not of type ClientError")
}
assert.Equal(t, 1, tokenRequestHandlerHitCounter)
assert.Equal(t, 0, tokenFlows.cache.ItemCount())

_, _ = tokenFlows.ClientCredentials(context.TODO(), server.URL, RequestOptions{})
assert.Equal(t, 2, tokenRequestHandlerHitCounter)
}

func TestClientCredentialsTokenFlow_FailsWithCustomerUrlWithoutScheme(t *testing.T) {
server := setupNewTLSServer(tokenHandler)
server := setupNewTLSServer(t, tokenHandler)
defer server.Close()
tokenFlows, _ := NewTokenFlows(clientSecretConfig, Options{HTTPClient: server.Client()})

Expand All @@ -99,7 +108,7 @@ func TestClientCredentialsTokenFlow_FailsWithCustomerUrlWithoutScheme(t *testing
}

func TestClientCredentialsTokenFlow_FailsWithInvalidCustomerUrl(t *testing.T) {
server := setupNewTLSServer(tokenHandler)
server := setupNewTLSServer(t, tokenHandler)
defer server.Close()
tokenFlows, _ := NewTokenFlows(clientSecretConfig, Options{HTTPClient: server.Client()})

Expand All @@ -109,12 +118,34 @@ func TestClientCredentialsTokenFlow_FailsWithInvalidCustomerUrl(t *testing.T) {
}

func TestClientCredentialsTokenFlow_Succeeds(t *testing.T) {
server := setupNewTLSServer(tokenHandler)
server := setupNewTLSServer(t, tokenHandler)
tokenFlows, _ := NewTokenFlows(&env.DefaultIdentity{
ClientID: "09932670-9440-445d-be3e-432a97d7e2ef"}, Options{HTTPClient: server.Client()})

token, err := tokenFlows.ClientCredentials(context.TODO(), server.URL, RequestOptions{})
assertToken(t, "eyJhbGciOiJIUzI1NiJ9.e30.ZRrHA1JJJW8opsbCGfG_HACGpVUMN_a9IV7pAx_Zmeo", token, err)
assertToken(t, dummyToken, token, err)
}

func TestClientCredentialsTokenFlow_ReadFromCache(t *testing.T) {
server := setupNewTLSServer(t, tokenHandler)
tokenFlows, _ := NewTokenFlows(&env.DefaultIdentity{
ClientID: "09932670-9440-445d-be3e-432a97d7e2ef"}, Options{HTTPClient: server.Client()})

assert.Equal(t, 0, tokenRequestHandlerHitCounter)
assert.Equal(t, 0, tokenFlows.cache.ItemCount())

token, err := tokenFlows.ClientCredentials(context.TODO(), server.URL, RequestOptions{})
assert.Equal(t, 1, tokenRequestHandlerHitCounter)
assert.Equal(t, 1, tokenFlows.cache.ItemCount())
assertToken(t, dummyToken, token, err)

token, err = tokenFlows.ClientCredentials(context.TODO(), server.URL, RequestOptions{})
assert.Equal(t, 1, tokenRequestHandlerHitCounter)
assert.Equal(t, 1, tokenFlows.cache.ItemCount())
assertToken(t, dummyToken, token, err)
cachedToken, ok := tokenFlows.cache.Get(server.URL + "/oauth2/token?client_id=09932670-9440-445d-be3e-432a97d7e2ef&grant_type=client_credentials")
assert.True(t, ok)
assert.Equal(t, dummyToken, cachedToken)
}

func TestClientCredentialsTokenFlow_UsingMockServer_Succeeds(t *testing.T) {
Expand All @@ -127,9 +158,13 @@ func TestClientCredentialsTokenFlow_UsingMockServer_Succeeds(t *testing.T) {
assertToken(t, "eyJhbGciOiJIUzI1NiJ9.e30.ZRrHA1JJJW8opsbCGfG_HACGpVUMN_a9IV7pAx_Zmeo", token, err)
}

func setupNewTLSServer(f func(http.ResponseWriter, *http.Request)) *httptest.Server {
func setupNewTLSServer(t *testing.T, f func(http.ResponseWriter, *http.Request)) *httptest.Server {
r := mux.NewRouter()
r.HandleFunc("/oauth2/token", f).Methods(http.MethodPost).Headers("Content-Type", "application/x-www-form-urlencoded")

t.Cleanup(func() {
tokenRequestHandlerHitCounter = 0
})
return httptest.NewTLSServer(r)
}

Expand All @@ -140,10 +175,11 @@ func tokenHandler(w http.ResponseWriter, r *http.Request) {
newStr := buf.String()
if newStr == "client_id=09932670-9440-445d-be3e-432a97d7e2ef&grant_type=client_credentials" {
payload, _ := json.Marshal(tokenResponse{
Token: "eyJhbGciOiJIUzI1NiJ9.e30.ZRrHA1JJJW8opsbCGfG_HACGpVUMN_a9IV7pAx_Zmeo",
Token: dummyToken,
})
_, _ = w.Write(payload)
}
tokenRequestHandlerHitCounter++
}

func assertToken(t assert.TestingT, expectedToken, actualToken string, actualError error) {
Expand Down

0 comments on commit 4547401

Please sign in to comment.