Skip to content

Commit

Permalink
database: Add partition key parameter to operation methods
Browse files Browse the repository at this point in the history
This parameter is not used yet, this is just to get some API
changes out of the way.

The parameter type is azcosmos.PartitionKey instead of string for
type safety. If the partition key and item ID parameters were both
strings, there would be a risk of callers mixing up the arguments.
  • Loading branch information
Matthew Barnes committed Feb 3, 2025
1 parent 3a0828e commit 3bf6678
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 32 deletions.
10 changes: 7 additions & 3 deletions backend/operations_scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
ocmsdk "github.com/openshift-online/ocm-sdk-go"
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
ocmerrors "github.com/openshift-online/ocm-sdk-go/errors"
Expand All @@ -32,6 +33,7 @@ const (

type operation struct {
id string
pk azcosmos.PartitionKey
doc *database.OperationDocument
logger *slog.Logger
}
Expand Down Expand Up @@ -207,6 +209,8 @@ func (s *OperationsScanner) processSubscriptions(logger *slog.Logger) {
func (s *OperationsScanner) processOperations(ctx context.Context, subscriptionID string, logger *slog.Logger) {
var numProcessed int

pk := database.NewPartitionKey(subscriptionID)

iterator := s.dbClient.ListOperationDocs(subscriptionID)

for operationID, operationDoc := range iterator.Items(ctx) {
Expand All @@ -216,7 +220,7 @@ func (s *OperationsScanner) processOperations(ctx context.Context, subscriptionI
"operation_id", operationID,
"resource_id", operationDoc.ExternalID.String(),
"internal_id", operationDoc.InternalID.String())
op := operation{operationID, operationDoc, operationLogger}
op := operation{operationID, pk, operationDoc, operationLogger}

switch operationDoc.InternalID.Kind() {
case cmv1.ClusterKind:
Expand Down Expand Up @@ -302,7 +306,7 @@ func (s *OperationsScanner) deleteOperationCompleted(ctx context.Context, op ope

// Save a final "succeeded" operation status until TTL expires.
const opStatus arm.ProvisioningState = arm.ProvisioningStateSucceeded
updated, err := s.dbClient.UpdateOperationDoc(ctx, op.id, func(updateDoc *database.OperationDocument) bool {
updated, err := s.dbClient.UpdateOperationDoc(ctx, op.pk, op.id, func(updateDoc *database.OperationDocument) bool {
return updateDoc.UpdateStatus(opStatus, nil)
})
if err != nil {
Expand All @@ -318,7 +322,7 @@ func (s *OperationsScanner) deleteOperationCompleted(ctx context.Context, op ope

// updateOperationStatus updates Cosmos DB to reflect an updated resource status.
func (s *OperationsScanner) updateOperationStatus(ctx context.Context, op operation, opStatus arm.ProvisioningState, opError *arm.CloudErrorBody) error {
updated, err := s.dbClient.UpdateOperationDoc(ctx, op.id, func(updateDoc *database.OperationDocument) bool {
updated, err := s.dbClient.UpdateOperationDoc(ctx, op.pk, op.id, func(updateDoc *database.OperationDocument) bool {
return updateDoc.UpdateStatus(opStatus, opError)
})
if err != nil {
Expand Down
10 changes: 6 additions & 4 deletions backend/operations_scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"testing"

azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
"go.uber.org/mock/gomock"

Expand Down Expand Up @@ -88,6 +89,7 @@ func TestDeleteOperationCompleted(t *testing.T) {

op := operation{
id: "this operation",
pk: database.NewPartitionKey("00000000-0000-0000-0000-000000000000"),
doc: operationDoc,
logger: slog.Default(),
}
Expand All @@ -100,8 +102,8 @@ func TestDeleteOperationCompleted(t *testing.T) {
resourceDocDeleted = tt.resourceDocPresent
})
mockDBClient.EXPECT().
UpdateOperationDoc(gomock.Any(), op.id, gomock.Any()).
DoAndReturn(func(ctx context.Context, operationID string, callback func(*database.OperationDocument) bool) (bool, error) {
UpdateOperationDoc(gomock.Any(), op.pk, op.id, gomock.Any()).
DoAndReturn(func(ctx context.Context, pk azcosmos.PartitionKey, operationID string, callback func(*database.OperationDocument) bool) (bool, error) {
return callback(operationDoc), nil
})

Expand Down Expand Up @@ -259,8 +261,8 @@ func TestUpdateOperationStatus(t *testing.T) {
}

mockDBClient.EXPECT().
UpdateOperationDoc(gomock.Any(), op.id, gomock.Any()).
DoAndReturn(func(ctx context.Context, operationID string, callback func(*database.OperationDocument) bool) (bool, error) {
UpdateOperationDoc(gomock.Any(), op.pk, op.id, gomock.Any()).
DoAndReturn(func(ctx context.Context, pk azcosmos.PartitionKey, operationID string, callback func(*database.OperationDocument) bool) (bool, error) {
return callback(operationDoc), nil
})
mockDBClient.EXPECT().
Expand Down
12 changes: 8 additions & 4 deletions frontend/pkg/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,8 @@ func (f *Frontend) ArmResourceCreateOrUpdate(writer http.ResponseWriter, request
return
}

err = f.ExposeOperation(writer, request, operationID)
pk := database.NewPartitionKey(resourceID.SubscriptionID)
err = f.ExposeOperation(writer, request, pk, operationID)
if err != nil {
logger.Error(err.Error())
arm.WriteInternalServerError(writer)
Expand Down Expand Up @@ -637,7 +638,8 @@ func (f *Frontend) ArmResourceDelete(writer http.ResponseWriter, request *http.R
return
}

err = f.ExposeOperation(writer, request, operationID)
pk := database.NewPartitionKey(resourceID.SubscriptionID)
err = f.ExposeOperation(writer, request, pk, operationID)
if err != nil {
logger.Error(err.Error())
arm.WriteInternalServerError(writer)
Expand Down Expand Up @@ -888,7 +890,8 @@ func (f *Frontend) OperationStatus(writer http.ResponseWriter, request *http.Req
return
}

doc, err := f.dbClient.GetOperationDoc(ctx, resourceID.Name)
pk := database.NewPartitionKey(resourceID.SubscriptionID)
doc, err := f.dbClient.GetOperationDoc(ctx, pk, resourceID.Name)
if err != nil {
logger.Error(err.Error())
if errors.Is(err, database.ErrNotFound) {
Expand Down Expand Up @@ -977,7 +980,8 @@ func (f *Frontend) OperationResult(writer http.ResponseWriter, request *http.Req
return
}

doc, err := f.dbClient.GetOperationDoc(ctx, resourceID.Name)
pk := database.NewPartitionKey(resourceID.SubscriptionID)
doc, err := f.dbClient.GetOperationDoc(ctx, pk, resourceID.Name)
if err != nil {
logger.Error(err.Error())
if errors.Is(err, database.ErrNotFound) {
Expand Down
3 changes: 2 additions & 1 deletion frontend/pkg/frontend/node_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ func (f *Frontend) CreateOrUpdateNodePool(writer http.ResponseWriter, request *h
return
}

err = f.ExposeOperation(writer, request, operationID)
pk := database.NewPartitionKey(resourceID.SubscriptionID)
err = f.ExposeOperation(writer, request, pk, operationID)
if err != nil {
logger.Error(err.Error())
arm.WriteInternalServerError(writer)
Expand Down
2 changes: 1 addition & 1 deletion frontend/pkg/frontend/node_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func TestCreateNodePool(t *testing.T) {
CreateOperationDoc(gomock.Any(), gomock.Any())
// ExposeOperation
mockDBClient.EXPECT().
UpdateOperationDoc(gomock.Any(), gomock.Any(), gomock.Any())
UpdateOperationDoc(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
// CreateOrUpdateNodePool
mockDBClient.EXPECT().
CreateResourceDoc(gomock.Any(), gomock.Any())
Expand Down
8 changes: 5 additions & 3 deletions frontend/pkg/frontend/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"

azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"

"github.com/Azure/ARO-HCP/internal/api"
"github.com/Azure/ARO-HCP/internal/api/arm"
Expand Down Expand Up @@ -87,10 +88,10 @@ func (f *Frontend) AddLocationHeader(writer http.ResponseWriter, request *http.R

// ExposeOperation fully initiates a new asynchronous operation by enriching
// the operation database item and adding the necessary response headers.
func (f *Frontend) ExposeOperation(writer http.ResponseWriter, request *http.Request, operationID string) error {
func (f *Frontend) ExposeOperation(writer http.ResponseWriter, request *http.Request, pk azcosmos.PartitionKey, operationID string) error {
ctx := request.Context()

_, err := f.dbClient.UpdateOperationDoc(ctx, operationID, func(updateDoc *database.OperationDocument) bool {
_, err := f.dbClient.UpdateOperationDoc(ctx, pk, operationID, func(updateDoc *database.OperationDocument) bool {
// There is no way to propagate a parse error here but it should
// never fail since we are building a trusted resource ID string.
operationID, err := azcorearm.ParseResourceID(path.Join("/",
Expand Down Expand Up @@ -137,7 +138,8 @@ func (f *Frontend) ExposeOperation(writer http.ResponseWriter, request *http.Req
// CancelActiveOperation marks the status of any active operation on the resource as canceled.
func (f *Frontend) CancelActiveOperation(ctx context.Context, resourceDoc *database.ResourceDocument) error {
if resourceDoc.ActiveOperationID != "" {
updated, err := f.dbClient.UpdateOperationDoc(ctx, resourceDoc.ActiveOperationID, func(updateDoc *database.OperationDocument) bool {
pk := database.NewPartitionKey(resourceDoc.ResourceID.SubscriptionID)
updated, err := f.dbClient.UpdateOperationDoc(ctx, pk, resourceDoc.ActiveOperationID, func(updateDoc *database.OperationDocument) bool {
return updateDoc.UpdateStatus(arm.ProvisioningStateCanceled, nil)
})
// Disregard "not found" errors; a missing operation is effectively canceled.
Expand Down
16 changes: 8 additions & 8 deletions internal/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ type DBClient interface {
DeleteResourceDoc(ctx context.Context, resourceID *azcorearm.ResourceID) error
ListResourceDocs(prefix *azcorearm.ResourceID, maxItems int32, continuationToken *string) DBClientIterator[ResourceDocument]

GetOperationDoc(ctx context.Context, operationID string) (*OperationDocument, error)
GetOperationDoc(ctx context.Context, pk azcosmos.PartitionKey, operationID string) (*OperationDocument, error)
CreateOperationDoc(ctx context.Context, doc *OperationDocument) (string, error)
UpdateOperationDoc(ctx context.Context, operationID string, callback func(*OperationDocument) bool) (bool, error)
UpdateOperationDoc(ctx context.Context, pk azcosmos.PartitionKey, operationID string, callback func(*OperationDocument) bool) (bool, error)
ListOperationDocs(subscriptionID string) DBClientIterator[OperationDocument]

// GetSubscriptionDoc retrieves a subscription from the database given the subscriptionID.
Expand Down Expand Up @@ -327,11 +327,11 @@ func (d *cosmosDBClient) ListResourceDocs(prefix *azcorearm.ResourceID, maxItems
}
}

func (d *cosmosDBClient) getOperationDoc(ctx context.Context, operationID string) (*typedDocument[OperationDocument], *OperationDocument, error) {
func (d *cosmosDBClient) getOperationDoc(ctx context.Context, pk azcosmos.PartitionKey, operationID string) (*typedDocument[OperationDocument], *OperationDocument, error) { //nolint:staticcheck
// Make sure lookup keys are lowercase.
operationID = strings.ToLower(operationID)

pk := NewPartitionKey(operationsPartitionKey)
pk = NewPartitionKey(operationsPartitionKey) //nolint:staticcheck

response, err := d.operations.ReadItem(ctx, pk, operationID, nil)
if err != nil {
Expand All @@ -351,8 +351,8 @@ func (d *cosmosDBClient) getOperationDoc(ctx context.Context, operationID string

// GetOperationDoc retrieves the asynchronous operation document for the given
// operation ID from the "operations" container
func (d *cosmosDBClient) GetOperationDoc(ctx context.Context, operationID string) (*OperationDocument, error) {
_, innerDoc, err := d.getOperationDoc(ctx, operationID)
func (d *cosmosDBClient) GetOperationDoc(ctx context.Context, pk azcosmos.PartitionKey, operationID string) (*OperationDocument, error) {
_, innerDoc, err := d.getOperationDoc(ctx, pk, operationID)
return innerDoc, err
}

Expand Down Expand Up @@ -382,7 +382,7 @@ func (d *cosmosDBClient) CreateOperationDoc(ctx context.Context, doc *OperationD
// The callback function should return true if modifications were applied, signaling to proceed
// with the document replacement. The boolean return value reflects this: returning true if the
// document was successfully replaced, or false with or without an error to indicate no change.
func (d *cosmosDBClient) UpdateOperationDoc(ctx context.Context, operationID string, callback func(*OperationDocument) bool) (bool, error) {
func (d *cosmosDBClient) UpdateOperationDoc(ctx context.Context, pk azcosmos.PartitionKey, operationID string, callback func(*OperationDocument) bool) (bool, error) {
var err error

options := &azcosmos.ItemOptions{}
Expand All @@ -392,7 +392,7 @@ func (d *cosmosDBClient) UpdateOperationDoc(ctx context.Context, operationID str
var innerDoc *OperationDocument
var data []byte

typedDoc, innerDoc, err = d.getOperationDoc(ctx, operationID)
typedDoc, innerDoc, err = d.getOperationDoc(ctx, pk, operationID)
if err != nil {
return false, err
}
Expand Down
17 changes: 9 additions & 8 deletions internal/mocks/dbclient.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 3bf6678

Please sign in to comment.