Skip to content

Commit

Permalink
Remove cache from SubscriptionStateMuxValidator
Browse files Browse the repository at this point in the history
* Leverage DBClient instead
* Allows us to start deprecating this cache
* Stores the entire subscription in the request context
* Adds methods to retrieve the subscription state and tenantID from the
  request context.

Signed-off-by: Michael Shen <mshen@redhat.com>
  • Loading branch information
mjlshen committed Jun 6, 2024
1 parent b4fa654 commit a8290ad
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 58 deletions.
8 changes: 0 additions & 8 deletions frontend/pkg/frontend/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,3 @@ func (c *Cache) GetSubscription(id string) (*arm.Subscription, bool) {
subscription, found := c.subscription[id]
return subscription, found
}

func (c *Cache) SetSubscription(id string, subscription *arm.Subscription) {
c.subscription[id] = subscription
}

func (c *Cache) DeleteSubscription(id string) {
delete(c.subscription, id)
}
34 changes: 25 additions & 9 deletions frontend/pkg/frontend/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const (
contextKeyVersion
contextKeyCorrelationData
contextKeySystemData
contextKeySubscriptionState
contextKeySubscription
)

func ContextWithOriginalPath(ctx context.Context, originalPath string) context.Context {
Expand Down Expand Up @@ -123,17 +123,33 @@ func SystemDataFromContext(ctx context.Context) (*arm.SystemData, error) {
return systemData, nil
}

func ContextWithSubscriptionState(ctx context.Context, subscriptionState arm.RegistrationState) context.Context {
return context.WithValue(ctx, contextKeySubscriptionState, subscriptionState)
func ContextWithSubscription(ctx context.Context, subscription arm.Subscription) context.Context {
return context.WithValue(ctx, contextKeySubscription, subscription)
}

func SubscriptionStateFromContext(ctx context.Context) (arm.RegistrationState, error) {
subscriptionState, ok := ctx.Value(contextKeySubscriptionState).(arm.RegistrationState)
func SubscriptionFromContext(ctx context.Context) (arm.Subscription, error) {
sub, ok := ctx.Value(contextKeySubscription).(arm.Subscription)
if !ok {
err := &ContextError{
got: subscriptionState,
return arm.Subscription{}, &ContextError{
got: sub,
}
}
return sub, nil
}

func TenantIDFromContext(ctx context.Context) (string, error) {
sub, ok := ctx.Value(contextKeySubscription).(arm.Subscription)
if !ok {
return "", &ContextError{
got: sub,
}
return subscriptionState, err
}
return subscriptionState, nil

if sub.Properties == nil || sub.Properties.TenantId == nil {
return "", &ContextError{
got: sub,
}
}

return *sub.Properties.TenantId, nil
}
55 changes: 55 additions & 0 deletions frontend/pkg/frontend/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package frontend

import (
"context"
"testing"

"github.com/Azure/ARO-HCP/internal/api/arm"
)

func stringPtr(s string) *string {
return &s
}

func TestTenantIDFromContext(t *testing.T) {
tests := []struct {
name string
sub arm.Subscription
expected string
expectErr bool
}{
{
name: "Valid",
sub: arm.Subscription{
Properties: &arm.Properties{
TenantId: stringPtr("tenant-id"),
},
},
expected: "tenant-id",
expectErr: false,
},
{
name: "Missing tenantId",
sub: arm.Subscription{},
expectErr: true,
},
}

for _, test := range tests {
ctx := ContextWithSubscription(context.Background(), test.sub)
actual, err := TenantIDFromContext(ctx)
if err != nil {
if !test.expectErr {
t.Errorf("expected err to be nil, got %v", err)
}
} else {
if test.expectErr {
t.Error("expected err to be non-nil")
}

if actual != test.expected {
t.Errorf("expected %s, got %s", test.expected, actual)
}
}
}
}
3 changes: 1 addition & 2 deletions frontend/pkg/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func NewFrontend(logger *slog.Logger, listener net.Listener, emitter Emitter, db
region: region,
}

subscriptionStateMuxValidator := NewSubscriptionStateMuxValidator(&f.cache)
subscriptionStateMuxValidator := NewSubscriptionStateMuxValidator(f.dbClient)

// Setup metrics middleware
metricsMiddleware := MetricsMiddleware{cache: &f.cache, Emitter: emitter}
Expand Down Expand Up @@ -488,7 +488,6 @@ func (f *Frontend) ArmSubscriptionPut(writer http.ResponseWriter, request *http.
}

subscriptionID := request.PathValue(PathSegmentSubscriptionID)
f.cache.SetSubscription(subscriptionID, &subscription)

// Emit the subscription state metric
f.metrics.EmitGauge("subscription_lifecycle", 1, map[string]string{
Expand Down
23 changes: 12 additions & 11 deletions frontend/pkg/frontend/middleware_validatesubscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package frontend
import (
"net/http"

"github.com/Azure/ARO-HCP/frontend/pkg/database"
"github.com/Azure/ARO-HCP/internal/api/arm"
)

Expand All @@ -16,12 +17,12 @@ const (
)

type SubscriptionStateMuxValidator struct {
cache *Cache
dbClient database.DBClient
}

func NewSubscriptionStateMuxValidator(c *Cache) *SubscriptionStateMuxValidator {
func NewSubscriptionStateMuxValidator(dbClient database.DBClient) *SubscriptionStateMuxValidator {
return &SubscriptionStateMuxValidator{
cache: c,
dbClient: dbClient,
}
}

Expand All @@ -38,9 +39,10 @@ func (s *SubscriptionStateMuxValidator) MiddlewareValidateSubscriptionState(w ht
return
}

sub, exists := s.cache.GetSubscription(subscriptionId)

if !exists {
// TODO: Ideally, we don't want to have to hit the database in this middleware
// Currently, we are using the database to retrieve the subscription's tenantID and state
sub, err := s.dbClient.GetSubscriptionDoc(r.Context(), subscriptionId)
if err != nil {
arm.WriteError(
w, http.StatusBadRequest,
arm.CloudErrorInvalidSubscriptionState, "",
Expand All @@ -49,10 +51,9 @@ func (s *SubscriptionStateMuxValidator) MiddlewareValidateSubscriptionState(w ht
return
}

// the subscription exists, store its current state as context
ctx := ContextWithSubscriptionState(r.Context(), sub.State)
ctx := ContextWithSubscription(r.Context(), *sub.Subscription)
r = r.WithContext(ctx)
switch sub.State {
switch sub.Subscription.State {
case arm.Registered:
next(w, r)
case arm.Unregistered:
Expand All @@ -66,7 +67,7 @@ func (s *SubscriptionStateMuxValidator) MiddlewareValidateSubscriptionState(w ht
arm.WriteError(w, http.StatusConflict,
arm.CloudErrorInvalidSubscriptionState, "",
InvalidSubscriptionStateMessage,
sub.State)
sub.Subscription.State)
return
}
next(w, r)
Expand All @@ -75,6 +76,6 @@ func (s *SubscriptionStateMuxValidator) MiddlewareValidateSubscriptionState(w ht
w, http.StatusBadRequest,
arm.CloudErrorInvalidSubscriptionState, "",
InvalidSubscriptionStateMessage,
sub.State)
sub.Subscription.State)
}
}
65 changes: 37 additions & 28 deletions frontend/pkg/frontend/middleware_validatesubscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package frontend
// Licensed under the Apache License 2.0.

import (
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -15,23 +16,22 @@ import (

"github.com/google/go-cmp/cmp"

"github.com/Azure/ARO-HCP/frontend/pkg/database"
"github.com/Azure/ARO-HCP/internal/api/arm"
)

func TestMiddlewareValidateSubscription(t *testing.T) {
subscriptionId := "1234-5678"
subscriptionId := "sub-1234-5678"
tenantId := "tenant-1234-5678"
defaultRequestPath := fmt.Sprintf("subscriptions/%s/resourceGroups/xyz", subscriptionId)
cache := NewCache()
middleware := NewSubscriptionStateMuxValidator(cache)

tests := []struct {
name string
subscriptionId string
cachedState arm.RegistrationState
expectedState arm.RegistrationState
httpMethod string
requestPath string
expectedError *arm.CloudError
name string
cachedState arm.RegistrationState
expectedState arm.RegistrationState
httpMethod string
requestPath string
expectedError *arm.CloudError
}{
{
name: "subscription is already registered",
Expand Down Expand Up @@ -155,9 +155,21 @@ func TestMiddlewareValidateSubscription(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dbClient := database.NewCache()
middleware := NewSubscriptionStateMuxValidator(dbClient)

if tt.cachedState != "" {
cache.SetSubscription(subscriptionId, &arm.Subscription{State: tt.cachedState})
if err := dbClient.SetSubscriptionDoc(context.Background(), &database.SubscriptionDocument{
PartitionKey: subscriptionId,
Subscription: &arm.Subscription{
State: tt.cachedState,
Properties: &arm.Properties{
TenantId: &tenantId,
},
},
}); err != nil {
t.Fatal(err)
}
}

writer := httptest.NewRecorder()
Expand All @@ -178,25 +190,22 @@ func TestMiddlewareValidateSubscription(t *testing.T) {
}

middleware.MiddlewareValidateSubscriptionState(writer, request, next)

// clear the cache for the next test
cache.DeleteSubscription(subscriptionId)

result, err := SubscriptionStateFromContext(request.Context())
if err == nil {
if !reflect.DeepEqual(result, tt.expectedState) {
t.Error(cmp.Diff(result, tt.expectedState))
sub, err := SubscriptionFromContext(request.Context())
if err != nil {
if tt.expectedError != nil {
var actualError *arm.CloudError
body, _ := io.ReadAll(http.MaxBytesReader(writer, writer.Result().Body, 4*megabyte))
_ = json.Unmarshal(body, &actualError)
if (writer.Result().StatusCode != tt.expectedError.StatusCode) || actualError.Code != tt.expectedError.Code || actualError.Message != tt.expectedError.Message {
t.Errorf("unexpected CloudError, wanted %v, got %v", tt.expectedError, actualError)
}
} else {
t.Errorf("expected CloudError, wanted %v, got %v", tt.expectedError, err)
}
} else if tt.expectedState != "" {
t.Errorf("Expected RegistrationState %s in request context", tt.expectedState)
}
if tt.expectedError != nil {
var actualError *arm.CloudError
body, _ := io.ReadAll(http.MaxBytesReader(writer, writer.Result().Body, 4*megabyte))
_ = json.Unmarshal(body, &actualError)
if (writer.Result().StatusCode != tt.expectedError.StatusCode) || actualError.Code != tt.expectedError.Code || actualError.Message != tt.expectedError.Message {
t.Errorf("unexpected CloudError, wanted %v, got %v", tt.expectedError, actualError)
}

if !reflect.DeepEqual(sub.State, tt.expectedState) {
t.Error(cmp.Diff(sub.State, tt.expectedState))
}
})
}
Expand Down

0 comments on commit a8290ad

Please sign in to comment.