Skip to content

Commit f1e53cf

Browse files
authored
Extend the generic storage service to support custom names for names (#52826)
* Extend the generic storage service to support custom names for names * Drop the custom per-resource name as it's unused
1 parent fdc8492 commit f1e53cf

23 files changed

+96
-132
lines changed

lib/services/local/access_graph.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ type AccessGraphSecretsService struct {
4848
// SSH Keys. Future implementations might extend them.
4949
func NewAccessGraphSecretsService(b backend.Backend) (*AccessGraphSecretsService, error) {
5050
authorizedKeysSvc, err := generic.NewServiceWrapper(
51-
generic.ServiceWrapperConfig[*accessgraphsecretspb.AuthorizedKey]{
51+
generic.ServiceConfig[*accessgraphsecretspb.AuthorizedKey]{
5252
Backend: b,
5353
ResourceKind: types.KindAccessGraphSecretAuthorizedKey,
5454
BackendPrefix: backend.NewKey(authorizedKeysPrefix),
@@ -60,7 +60,7 @@ func NewAccessGraphSecretsService(b backend.Backend) (*AccessGraphSecretsService
6060
}
6161

6262
privateKeysSvc, err := generic.NewServiceWrapper(
63-
generic.ServiceWrapperConfig[*accessgraphsecretspb.PrivateKey]{
63+
generic.ServiceConfig[*accessgraphsecretspb.PrivateKey]{
6464
Backend: b,
6565
ResourceKind: types.KindAccessGraphSecretPrivateKey,
6666
BackendPrefix: backend.NewKey(privateKeysPrefix),

lib/services/local/access_list_test.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ func TestAccessListDedupeOwnersBackwardsCompat(t *testing.T) {
407407
accessListDuplicateOwners.Spec.Owners = append(accessListDuplicateOwners.Spec.Owners, accessListDuplicateOwners.Spec.Owners[0])
408408
require.Len(t, accessListDuplicateOwners.Spec.Owners, 3)
409409

410-
item, err := service.service.MakeBackendItem(accessListDuplicateOwners, accessListDuplicateOwners.GetName())
410+
item, err := service.service.MakeBackendItem(accessListDuplicateOwners)
411411
require.NoError(t, err)
412412
_, err = mem.Put(ctx, item)
413413
require.NoError(t, err)
@@ -490,7 +490,6 @@ func TestAccessListUpsertWithMembers(t *testing.T) {
490490
require.NoError(t, err)
491491
require.Empty(t, members)
492492
})
493-
494493
}
495494

496495
func TestAccessListMembersCRUD(t *testing.T) {

lib/services/local/access_monitoring_rules.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ type AccessMonitoringRulesService struct {
4141
// NewAccessMonitoringRulesService creates a new AccessMonitoringRulesService.
4242
func NewAccessMonitoringRulesService(b backend.Backend) (*AccessMonitoringRulesService, error) {
4343
service, err := generic.NewServiceWrapper(
44-
generic.ServiceWrapperConfig[*accessmonitoringrulesv1.AccessMonitoringRule]{
44+
generic.ServiceConfig[*accessmonitoringrulesv1.AccessMonitoringRule]{
4545
Backend: b,
4646
ResourceKind: types.KindAccessMonitoringRule,
4747
BackendPrefix: backend.NewKey(accessMonitoringRulesPrefix),

lib/services/local/autoupdate.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -47,44 +47,44 @@ type AutoUpdateService struct {
4747
// NewAutoUpdateService returns a new AutoUpdateService.
4848
func NewAutoUpdateService(b backend.Backend) (*AutoUpdateService, error) {
4949
config, err := generic.NewServiceWrapper(
50-
generic.ServiceWrapperConfig[*autoupdate.AutoUpdateConfig]{
50+
generic.ServiceConfig[*autoupdate.AutoUpdateConfig]{
5151
Backend: b,
5252
ResourceKind: types.KindAutoUpdateConfig,
5353
BackendPrefix: backend.NewKey(autoUpdateConfigPrefix),
5454
MarshalFunc: services.MarshalProtoResource[*autoupdate.AutoUpdateConfig],
5555
UnmarshalFunc: services.UnmarshalProtoResource[*autoupdate.AutoUpdateConfig],
5656
ValidateFunc: update.ValidateAutoUpdateConfig,
57-
KeyFunc: func(*autoupdate.AutoUpdateConfig) string {
57+
NameKeyFunc: func(string) string {
5858
return types.MetaNameAutoUpdateConfig
5959
},
6060
})
6161
if err != nil {
6262
return nil, trace.Wrap(err)
6363
}
6464
version, err := generic.NewServiceWrapper(
65-
generic.ServiceWrapperConfig[*autoupdate.AutoUpdateVersion]{
65+
generic.ServiceConfig[*autoupdate.AutoUpdateVersion]{
6666
Backend: b,
6767
ResourceKind: types.KindAutoUpdateVersion,
6868
BackendPrefix: backend.NewKey(autoUpdateVersionPrefix),
6969
MarshalFunc: services.MarshalProtoResource[*autoupdate.AutoUpdateVersion],
7070
UnmarshalFunc: services.UnmarshalProtoResource[*autoupdate.AutoUpdateVersion],
7171
ValidateFunc: update.ValidateAutoUpdateVersion,
72-
KeyFunc: func(version *autoupdate.AutoUpdateVersion) string {
72+
NameKeyFunc: func(string) string {
7373
return types.MetaNameAutoUpdateVersion
7474
},
7575
})
7676
if err != nil {
7777
return nil, trace.Wrap(err)
7878
}
7979
rollout, err := generic.NewServiceWrapper(
80-
generic.ServiceWrapperConfig[*autoupdate.AutoUpdateAgentRollout]{
80+
generic.ServiceConfig[*autoupdate.AutoUpdateAgentRollout]{
8181
Backend: b,
8282
ResourceKind: types.KindAutoUpdateAgentRollout,
8383
BackendPrefix: backend.NewKey(autoUpdateAgentRolloutPrefix),
8484
MarshalFunc: services.MarshalProtoResource[*autoupdate.AutoUpdateAgentRollout],
8585
UnmarshalFunc: services.UnmarshalProtoResource[*autoupdate.AutoUpdateAgentRollout],
8686
ValidateFunc: update.ValidateAutoUpdateAgentRollout,
87-
KeyFunc: func(_ *autoupdate.AutoUpdateAgentRollout) string {
87+
NameKeyFunc: func(string) string {
8888
return types.MetaNameAutoUpdateAgentRollout
8989
},
9090
})
@@ -222,7 +222,7 @@ func itemFromAutoUpdateConfig(config *autoupdate.AutoUpdateConfig) (*backend.Ite
222222
if err != nil {
223223
return nil, trace.Wrap(err)
224224
}
225-
value, err := services.MarshalProtoResource[*autoupdate.AutoUpdateConfig](config)
225+
value, err := services.MarshalProtoResource(config)
226226
if err != nil {
227227
return nil, trace.Wrap(err)
228228
}
@@ -248,7 +248,7 @@ func itemFromAutoUpdateVersion(version *autoupdate.AutoUpdateVersion) (*backend.
248248
if err != nil {
249249
return nil, trace.Wrap(err)
250250
}
251-
value, err := services.MarshalProtoResource[*autoupdate.AutoUpdateVersion](version)
251+
value, err := services.MarshalProtoResource(version)
252252
if err != nil {
253253
return nil, trace.Wrap(err)
254254
}

lib/services/local/bot_instance.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ type BotInstanceService struct {
4545
// NewBotInstanceService creates a new BotInstanceService with the given backend.
4646
func NewBotInstanceService(b backend.Backend, clock clockwork.Clock) (*BotInstanceService, error) {
4747
service, err := generic.NewServiceWrapper(
48-
generic.ServiceWrapperConfig[*machineidv1.BotInstance]{
48+
generic.ServiceConfig[*machineidv1.BotInstance]{
4949
Backend: b,
5050
ResourceKind: types.KindBotInstance,
5151
BackendPrefix: backend.NewKey(botInstancePrefix),

lib/services/local/crown_jewels.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ const crownJewelsKey = "crown_jewels"
3939
// NewCrownJewelsService creates a new CrownJewelsService.
4040
func NewCrownJewelsService(b backend.Backend) (*CrownJewelsService, error) {
4141
service, err := generic.NewServiceWrapper(
42-
generic.ServiceWrapperConfig[*crownjewelv1.CrownJewel]{
42+
generic.ServiceConfig[*crownjewelv1.CrownJewel]{
4343
Backend: b,
4444
ResourceKind: types.KindCrownJewel,
4545
BackendPrefix: backend.NewKey(crownJewelsKey),

lib/services/local/databaseobject.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ const (
7474

7575
func NewDatabaseObjectService(b backend.Backend) (*DatabaseObjectService, error) {
7676
service, err := generic.NewServiceWrapper(
77-
generic.ServiceWrapperConfig[*dbobjectv1.DatabaseObject]{
77+
generic.ServiceConfig[*dbobjectv1.DatabaseObject]{
7878
Backend: b,
7979
ResourceKind: types.KindDatabaseObject,
8080
BackendPrefix: backend.NewKey(databaseObjectPrefix),

lib/services/local/databaseobjectimportrule.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ const (
7272

7373
func NewDatabaseObjectImportRuleService(b backend.Backend) (services.DatabaseObjectImportRules, error) {
7474
service, err := generic.NewServiceWrapper(
75-
generic.ServiceWrapperConfig[*databaseobjectimportrulev1.DatabaseObjectImportRule]{
75+
generic.ServiceConfig[*databaseobjectimportrulev1.DatabaseObjectImportRule]{
7676
Backend: b,
7777
ResourceKind: types.KindDatabaseObjectImportRule,
7878
BackendPrefix: backend.NewKey(databaseObjectImportRulePrefix),

lib/services/local/generic/generic.go

+27-31
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ type MarshalFunc[T any] func(T, ...services.MarshalOption) ([]byte, error)
4444
type UnmarshalFunc[T any] func([]byte, ...services.MarshalOption) (T, error)
4545

4646
// ServiceConfig is the configuration for the service configuration.
47-
type ServiceConfig[T Resource] struct {
47+
type ServiceConfig[T any] struct {
4848
// Backend used to persist the resource.
4949
Backend backend.Backend
5050
// ResourceKind is the friendly name of the resource.
@@ -64,9 +64,10 @@ type ServiceConfig[T Resource] struct {
6464
// If set to 0, the default interval of 250ms will be used.
6565
// WARNING: If set to a negative value, the RunWhileLocked function will retry immediately.
6666
RunWhileLockedRetryInterval time.Duration
67-
// KeyFunc optionally allows resource to have a custom key. If not provided the
68-
// name of the resource will be used.
69-
KeyFunc func(T) string
67+
// NameKeyFunc optionally allows resources to have a custom key suffix, by
68+
// transforming the name of the resource or the input given to methods that
69+
// take a resource name. If unset, the name is used without changes.
70+
NameKeyFunc func(name string) string
7071
}
7172

7273
func (c *ServiceConfig[T]) CheckAndSetDefaults() error {
@@ -95,10 +96,6 @@ func (c *ServiceConfig[T]) CheckAndSetDefaults() error {
9596
c.ValidateFunc = func(t T) error { return nil }
9697
}
9798

98-
if c.KeyFunc == nil {
99-
c.KeyFunc = func(t T) string { return t.GetName() }
100-
}
101-
10299
return nil
103100
}
104101

@@ -112,7 +109,7 @@ type Service[T Resource] struct {
112109
unmarshalFunc UnmarshalFunc[T]
113110
validateFunc func(T) error
114111
runWhileLockedRetryInterval time.Duration
115-
keyFunc func(T) string
112+
nameKeyFunc func(name string) string
116113
}
117114

118115
// NewService will return a new generic service with the given config. This will
@@ -131,7 +128,7 @@ func NewService[T Resource](cfg *ServiceConfig[T]) (*Service[T], error) {
131128
unmarshalFunc: cfg.UnmarshalFunc,
132129
validateFunc: cfg.ValidateFunc,
133130
runWhileLockedRetryInterval: cfg.RunWhileLockedRetryInterval,
134-
keyFunc: cfg.KeyFunc,
131+
nameKeyFunc: cfg.NameKeyFunc,
135132
}, nil
136133
}
137134

@@ -140,18 +137,16 @@ func (s *Service[T]) WithPrefix(parts ...string) *Service[T] {
140137
if len(parts) == 0 {
141138
return s
142139
}
140+
s2 := *s
141+
s2.backendPrefix = s2.backendPrefix.AppendKey(backend.NewKey(parts...))
142+
return &s2
143+
}
143144

144-
return &Service[T]{
145-
backend: s.backend,
146-
resourceKind: s.resourceKind,
147-
pageLimit: s.pageLimit,
148-
backendPrefix: s.backendPrefix.AppendKey(backend.NewKey(parts...)),
149-
marshalFunc: s.marshalFunc,
150-
unmarshalFunc: s.unmarshalFunc,
151-
validateFunc: s.validateFunc,
152-
runWhileLockedRetryInterval: s.runWhileLockedRetryInterval,
153-
keyFunc: s.keyFunc,
145+
func (s *Service[T]) nameKey(name string) string {
146+
if s.nameKeyFunc != nil {
147+
return s.nameKeyFunc(name)
154148
}
149+
return name
155150
}
156151

157152
// CountResources will return a count of all resources in the prefix range.
@@ -204,6 +199,7 @@ func (s *Service[T]) ListResourcesReturnNextResource(ctx context.Context, pageSi
204199
resources, next, _, err := s.listResourcesReturnNextResourceWithKey(ctx, pageSize, pageToken)
205200
return resources, next, trace.Wrap(err)
206201
}
202+
207203
func (s *Service[T]) listResourcesReturnNextResourceWithKey(ctx context.Context, pageSize int, pageToken string) ([]T, *T, string, error) {
208204
rangeStart := s.backendPrefix.AppendKey(backend.KeyFromString(pageToken))
209205
rangeEnd := backend.RangeEnd(s.backendPrefix.ExactKey())
@@ -291,12 +287,11 @@ func (s *Service[T]) ListResourcesWithFilter(ctx context.Context, pageSize int,
291287
}
292288

293289
return resources, nextKey, nil
294-
295290
}
296291

297292
// GetResource returns the specified resource.
298293
func (s *Service[T]) GetResource(ctx context.Context, name string) (resource T, err error) {
299-
item, err := s.backend.Get(ctx, s.MakeKey(backend.NewKey(name)))
294+
item, err := s.backend.Get(ctx, s.MakeKey(backend.NewKey(s.nameKey(name))))
300295
if err != nil {
301296
if trace.IsNotFound(err) {
302297
return resource, trace.NotFound("%s %q doesn't exist", s.resourceKind, name)
@@ -315,7 +310,7 @@ func (s *Service[T]) CreateResource(ctx context.Context, resource T) (T, error)
315310
return t, trace.Wrap(err)
316311
}
317312

318-
item, err := s.MakeBackendItem(resource, s.keyFunc(resource))
313+
item, err := s.MakeBackendItem(resource)
319314
if err != nil {
320315
return t, trace.Wrap(err)
321316
}
@@ -340,7 +335,7 @@ func (s *Service[T]) UpdateResource(ctx context.Context, resource T) (T, error)
340335
return t, trace.Wrap(err)
341336
}
342337

343-
item, err := s.MakeBackendItem(resource, s.keyFunc(resource))
338+
item, err := s.MakeBackendItem(resource)
344339
if err != nil {
345340
return t, trace.Wrap(err)
346341
}
@@ -365,7 +360,7 @@ func (s *Service[T]) ConditionalUpdateResource(ctx context.Context, resource T)
365360
return t, trace.Wrap(err)
366361
}
367362

368-
item, err := s.MakeBackendItem(resource, s.keyFunc(resource))
363+
item, err := s.MakeBackendItem(resource)
369364
if err != nil {
370365
return t, trace.Wrap(err)
371366
}
@@ -390,7 +385,7 @@ func (s *Service[T]) UpsertResource(ctx context.Context, resource T) (T, error)
390385
return t, trace.Wrap(err)
391386
}
392387

393-
item, err := s.MakeBackendItem(resource, s.keyFunc(resource))
388+
item, err := s.MakeBackendItem(resource)
394389
if err != nil {
395390
return t, trace.Wrap(err)
396391
}
@@ -406,7 +401,7 @@ func (s *Service[T]) UpsertResource(ctx context.Context, resource T) (T, error)
406401

407402
// DeleteResource removes the specified resource.
408403
func (s *Service[T]) DeleteResource(ctx context.Context, name string) error {
409-
err := s.backend.Delete(ctx, s.MakeKey(backend.NewKey(name)))
404+
err := s.backend.Delete(ctx, s.MakeKey(backend.NewKey(s.nameKey(name))))
410405
if err != nil {
411406
if trace.IsNotFound(err) {
412407
return trace.NotFound("%s %q doesn't exist", s.resourceKind, name)
@@ -425,7 +420,7 @@ func (s *Service[T]) DeleteAllResources(ctx context.Context) error {
425420
// UpdateAndSwapResource will get the resource from the backend, modify it, and swap the new value into the backend.
426421
func (s *Service[T]) UpdateAndSwapResource(ctx context.Context, name string, modify func(T) error) (T, error) {
427422
var t T
428-
existingItem, err := s.backend.Get(ctx, s.MakeKey(backend.NewKey(name)))
423+
existingItem, err := s.backend.Get(ctx, s.MakeKey(backend.NewKey(s.nameKey(name))))
429424
if err != nil {
430425
if trace.IsNotFound(err) {
431426
return t, trace.NotFound("%s %q doesn't exist", s.resourceKind, name)
@@ -447,7 +442,7 @@ func (s *Service[T]) UpdateAndSwapResource(ctx context.Context, name string, mod
447442
return t, trace.Wrap(err)
448443
}
449444

450-
replacementItem, err := s.MakeBackendItem(resource, name)
445+
replacementItem, err := s.MakeBackendItem(resource)
451446
if err != nil {
452447
return t, trace.Wrap(err)
453448
}
@@ -462,7 +457,8 @@ func (s *Service[T]) UpdateAndSwapResource(ctx context.Context, name string, mod
462457
}
463458

464459
// MakeBackendItem will check and make the backend item.
465-
func (s *Service[T]) MakeBackendItem(resource T, name string) (backend.Item, error) {
460+
func (s *Service[T]) MakeBackendItem(resource T, _ ...any) (backend.Item, error) {
461+
// TODO(espadolini): clean up unused variadic after teleport.e is updated
466462
if err := services.CheckAndSetDefaults(resource); err != nil {
467463
return backend.Item{}, trace.Wrap(err)
468464
}
@@ -477,7 +473,7 @@ func (s *Service[T]) MakeBackendItem(resource T, name string) (backend.Item, err
477473
return backend.Item{}, trace.Wrap(err)
478474
}
479475
item := backend.Item{
480-
Key: s.MakeKey(backend.NewKey(name)),
476+
Key: s.MakeKey(backend.NewKey(s.nameKey(resource.GetName()))),
481477
Value: value,
482478
Revision: rev,
483479
}

lib/services/local/generic/generic_test.go

+19-4
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,6 @@ func TestGenericListResourcesWithFilterForScale(t *testing.T) {
540540
}
541541

542542
func TestGenericValidation(t *testing.T) {
543-
544543
ctx := context.Background()
545544

546545
memBackend, err := memory.New(memory.Config{
@@ -571,16 +570,18 @@ func TestGenericValidation(t *testing.T) {
571570

572571
_, err = service.UpsertResource(ctx, r1)
573572
require.ErrorIs(t, err, validationErr)
574-
575573
}
576574

577575
func TestGenericKeyOverride(t *testing.T) {
578-
ctx := context.Background()
576+
ctx, cancel := context.WithCancel(context.Background())
577+
defer cancel()
578+
579579
memBackend, err := memory.New(memory.Config{
580580
Context: ctx,
581581
Clock: clockwork.NewFakeClock(),
582582
})
583583
require.NoError(t, err)
584+
defer memBackend.Close()
584585

585586
service, err := NewService(&ServiceConfig[*testResource]{
586587
Backend: memBackend,
@@ -589,7 +590,7 @@ func TestGenericKeyOverride(t *testing.T) {
589590
BackendPrefix: backend.NewKey("generic_prefix"),
590591
UnmarshalFunc: unmarshalResource,
591592
MarshalFunc: marshalResource,
592-
KeyFunc: func(tr *testResource) string { return "llama" },
593+
NameKeyFunc: func(string) string { return "llama" },
593594
})
594595
require.NoError(t, err)
595596

@@ -654,4 +655,18 @@ func TestGenericKeyOverride(t *testing.T) {
654655
item, err = memBackend.Get(ctx, backend.NewKey("generic_prefix", r1.GetName()))
655656
require.Error(t, err)
656657
require.Nil(t, item)
658+
659+
// Validate that getting the resource through the service uses the overridden name
660+
_, err = service.GetResource(ctx, r1.GetName())
661+
require.NoError(t, err)
662+
_, err = service.GetResource(ctx, "llama")
663+
require.NoError(t, err)
664+
_, err = service.GetResource(ctx, "notllama")
665+
require.NoError(t, err)
666+
667+
// Validate that deleting the resource also uses the overridden name
668+
err = service.DeleteResource(ctx, "notllama")
669+
require.NoError(t, err)
670+
_, err = memBackend.Get(ctx, backend.NewKey("generic_prefix", "llama"))
671+
require.ErrorAs(t, err, new(*trace.NotFoundError))
657672
}

0 commit comments

Comments
 (0)