diff --git a/cmd/lakefs/cmd/run.go b/cmd/lakefs/cmd/run.go index aa3dce2c442..20c0086a5a0 100644 --- a/cmd/lakefs/cmd/run.go +++ b/cmd/lakefs/cmd/run.go @@ -18,6 +18,7 @@ import ( "github.com/go-co-op/gocron" "github.com/spf13/cobra" "github.com/spf13/viper" + "github.com/treeverse/lakefs/contrib/auth/acl" "github.com/treeverse/lakefs/pkg/actions" "github.com/treeverse/lakefs/pkg/api" "github.com/treeverse/lakefs/pkg/auth" @@ -89,7 +90,7 @@ func NewAuthService(ctx context.Context, cfg *config.Config, logger logging.Logg } return auth.NewMonitoredAuthServiceAndInviter(apiService) } - authService := auth.NewAuthService( + authService := acl.NewAuthService( kvStore, crypt.NewSecretStore([]byte(cfg.Auth.Encrypt.SecretKey)), authparams.ServiceCache(cfg.Auth.Cache), diff --git a/cmd/lakefs/cmd/setup.go b/cmd/lakefs/cmd/setup.go index b15c61a6c6e..1883f2830ba 100644 --- a/cmd/lakefs/cmd/setup.go +++ b/cmd/lakefs/cmd/setup.go @@ -6,6 +6,7 @@ import ( "os" "github.com/spf13/cobra" + "github.com/treeverse/lakefs/contrib/auth/acl" "github.com/treeverse/lakefs/pkg/auth" "github.com/treeverse/lakefs/pkg/auth/crypt" "github.com/treeverse/lakefs/pkg/auth/model" @@ -74,7 +75,7 @@ var setupCmd = &cobra.Command{ defer kvStore.Close() logger := logging.ContextUnavailable() authLogger := logger.WithField("service", "auth_service") - authService = auth.NewAuthService(kvStore, crypt.NewSecretStore([]byte(cfg.Auth.Encrypt.SecretKey)), authparams.ServiceCache(cfg.Auth.Cache), authLogger) + authService = acl.NewAuthService(kvStore, crypt.NewSecretStore([]byte(cfg.Auth.Encrypt.SecretKey)), authparams.ServiceCache(cfg.Auth.Cache), authLogger) metadataManager = auth.NewKVMetadataManager(version.Version, cfg.Installation.FixedID, cfg.Database.Type, kvStore) cloudMetadataProvider := stats.BuildMetadataProvider(logger, cfg) diff --git a/cmd/lakefs/cmd/superuser.go b/cmd/lakefs/cmd/superuser.go index 5e35c991003..816d452f560 100644 --- a/cmd/lakefs/cmd/superuser.go +++ b/cmd/lakefs/cmd/superuser.go @@ -7,6 +7,7 @@ import ( "time" "github.com/spf13/cobra" + "github.com/treeverse/lakefs/contrib/auth/acl" "github.com/treeverse/lakefs/pkg/auth" "github.com/treeverse/lakefs/pkg/auth/crypt" "github.com/treeverse/lakefs/pkg/auth/model" @@ -59,7 +60,7 @@ var superuserCmd = &cobra.Command{ fmt.Printf("Failed to open KV store: %s\n", err) os.Exit(1) } - authService := auth.NewAuthService(kvStore, crypt.NewSecretStore([]byte(cfg.Auth.Encrypt.SecretKey)), authparams.ServiceCache(cfg.Auth.Cache), logger.WithField("service", "auth_service")) + authService := acl.NewAuthService(kvStore, crypt.NewSecretStore([]byte(cfg.Auth.Encrypt.SecretKey)), authparams.ServiceCache(cfg.Auth.Cache), logger.WithField("service", "auth_service")) authMetadataManager := auth.NewKVMetadataManager(version.Version, cfg.Installation.FixedID, cfg.Database.Type, kvStore) metadataProvider := stats.BuildMetadataProvider(logger, cfg) diff --git a/contrib/auth/acl/service.go b/contrib/auth/acl/service.go new file mode 100644 index 00000000000..3e19c38fba3 --- /dev/null +++ b/contrib/auth/acl/service.go @@ -0,0 +1,1000 @@ +package acl + +import ( + "context" + "errors" + "fmt" + "sort" + "strings" + "time" + + "github.com/treeverse/lakefs/pkg/auth" + "github.com/treeverse/lakefs/pkg/auth/crypt" + "github.com/treeverse/lakefs/pkg/auth/keys" + "github.com/treeverse/lakefs/pkg/auth/model" + "github.com/treeverse/lakefs/pkg/auth/params" + "github.com/treeverse/lakefs/pkg/kv" + "github.com/treeverse/lakefs/pkg/logging" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type AuthService struct { + store kv.Store + secretStore crypt.SecretStore + cache auth.Cache + log logging.Logger +} + +func NewAuthService(store kv.Store, secretStore crypt.SecretStore, cacheConf params.ServiceCache, logger logging.Logger) *AuthService { + logger.Info("initialized Auth service") + var cache auth.Cache + if cacheConf.Enabled { + cache = auth.NewLRUCache(cacheConf.Size, cacheConf.TTL, cacheConf.Jitter) + } else { + cache = &auth.DummyCache{} + } + res := &AuthService{ + store: store, + secretStore: secretStore, + cache: cache, + log: logger, + } + return res +} + +func (s *AuthService) ListKVPaged(ctx context.Context, protoType protoreflect.MessageType, params *model.PaginationParams, prefix []byte, secondary bool) ([]proto.Message, *model.Paginator, error) { + var ( + it kv.MessageIterator + err error + after []byte + ) + if params.After != "" { + after = make([]byte, len(prefix)+len(params.After)) + + l := copy(after, prefix) + _ = copy(after[l:], params.After) + } + if secondary { + it, err = kv.NewSecondaryIterator(ctx, s.store, protoType, model.PartitionKey, prefix, after) + } else { + it, err = kv.NewPrimaryIterator(ctx, s.store, protoType, model.PartitionKey, prefix, kv.IteratorOptionsAfter(after)) + } + if err != nil { + return nil, nil, fmt.Errorf("scan prefix(%s): %w", prefix, err) + } + defer it.Close() + + amount := auth.MaxPage + if params.Amount >= 0 && params.Amount < auth.MaxPage { + amount = params.Amount + } + + entries := make([]proto.Message, 0) + p := &model.Paginator{} + for len(entries) < amount && it.Next() { + entry := it.Entry() + // skip nil entries (deleted), kv can hold nil values + if entry == nil { + continue + } + entries = append(entries, entry.Value) + if len(entries) == amount { + p.NextPageToken = strings.TrimPrefix(string(entry.Key), string(prefix)) + break + } + } + if err = it.Err(); err != nil { + return nil, nil, fmt.Errorf("list DB: %w", err) + } + p.Amount = len(entries) + return entries, p, nil +} + +func (s *AuthService) SecretStore() crypt.SecretStore { + return s.secretStore +} + +func (s *AuthService) Cache() auth.Cache { + return s.cache +} + +func (s *AuthService) CreateUser(ctx context.Context, user *model.User) (string, error) { + if err := model.ValidateAuthEntityID(user.Username); err != nil { + return auth.InvalidUserID, err + } + userKey := model.UserPath(user.Username) + + err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, userKey, model.ProtoFromUser(user), nil) + if err != nil { + if errors.Is(err, kv.ErrPredicateFailed) { + err = auth.ErrAlreadyExists + } + return "", fmt.Errorf("save user (auth.UserKey %s): %w", userKey, err) + } + return user.Username, err +} + +func (s *AuthService) DeleteUser(ctx context.Context, username string) error { + if _, err := s.GetUser(ctx, username); err != nil { + return err + } + userPath := model.UserPath(username) + + // delete policy attached to user + policiesKey := model.UserPolicyPath(username, "") + it, err := kv.NewSecondaryIterator(ctx, s.store, (&model.PolicyData{}).ProtoReflect().Type(), model.PartitionKey, policiesKey, []byte("")) + if err != nil { + return err + } + defer it.Close() + for it.Next() { + entry := it.Entry() + policy := entry.Value.(*model.PolicyData) + if err = s.DetachPolicyFromUserNoValidation(ctx, policy.DisplayName, username); err != nil { + return err + } + } + if err = it.Err(); err != nil { + return err + } + + // delete user membership of group + groupKey := model.GroupPath("") + itr, err := kv.NewPrimaryIterator(ctx, s.store, (&model.GroupData{}).ProtoReflect().Type(), model.PartitionKey, groupKey, kv.IteratorOptionsAfter([]byte(""))) + if err != nil { + return err + } + defer itr.Close() + for itr.Next() { + entry := itr.Entry() + group := entry.Value.(*model.GroupData) + if err = s.removeUserFromGroupNoValidation(ctx, username, group.DisplayName); err != nil { + return err + } + } + if err = itr.Err(); err != nil { + return err + } + + // delete user + err = s.store.Delete(ctx, []byte(model.PartitionKey), userPath) + if err != nil { + return fmt.Errorf("delete user (auth.UserKey %s): %w", userPath, err) + } + return err +} + +type UserPredicate func(u *model.UserData) bool + +func (s *AuthService) getUserByPredicate(ctx context.Context, key auth.UserKey, predicate UserPredicate) (*model.User, error) { + return s.cache.GetUser(key, func() (*model.User, error) { + m := &model.UserData{} + itr, err := kv.NewPrimaryIterator(ctx, s.store, m.ProtoReflect().Type(), model.PartitionKey, model.UserPath(""), kv.IteratorOptionsAfter([]byte(""))) + if err != nil { + return nil, fmt.Errorf("scan users: %w", err) + } + defer itr.Close() + for itr.Next() { + entry := itr.Entry() + value, ok := entry.Value.(*model.UserData) + if !ok { + return nil, fmt.Errorf("failed to cast: %w", err) + } + if predicate(value) { + return model.UserFromProto(value), nil + } + } + if itr.Err() != nil { + return nil, itr.Err() + } + return nil, auth.ErrNotFound + }) +} + +func (s *AuthService) GetUserByID(ctx context.Context, userID string) (*model.User, error) { + return s.GetUser(ctx, userID) +} + +func (s *AuthService) GetUser(ctx context.Context, username string) (*model.User, error) { + return s.cache.GetUser(auth.UserKey{Username: username}, func() (*model.User, error) { + userKey := model.UserPath(username) + m := model.UserData{} + _, err := kv.GetMsg(ctx, s.store, model.PartitionKey, userKey, &m) + if err != nil { + if errors.Is(err, kv.ErrNotFound) { + err = auth.ErrNotFound + } + return nil, fmt.Errorf("%s: %w", username, err) + } + return model.UserFromProto(&m), nil + }) +} + +func (s *AuthService) GetUserByEmail(ctx context.Context, email string) (*model.User, error) { + return s.getUserByPredicate(ctx, auth.UserKey{Email: email}, func(value *model.UserData) bool { + return value.Email == email + }) +} + +func (s *AuthService) GetUserByExternalID(ctx context.Context, externalID string) (*model.User, error) { + return s.getUserByPredicate(ctx, auth.UserKey{ExternalID: externalID}, func(value *model.UserData) bool { + return value.ExternalId == externalID + }) +} + +func (s *AuthService) ListUsers(ctx context.Context, params *model.PaginationParams) ([]*model.User, *model.Paginator, error) { + var user model.UserData + usersKey := model.UserPath(params.Prefix) + + msgs, paginator, err := s.ListKVPaged(ctx, (&user).ProtoReflect().Type(), params, usersKey, false) + if msgs == nil { + return nil, paginator, err + } + return model.ConvertUsersDataList(msgs), paginator, err +} + +func (s *AuthService) UpdateUserFriendlyName(_ context.Context, _ string, _ string) error { + return auth.ErrNotImplemented +} + +func (s *AuthService) ListUserCredentials(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Credential, *model.Paginator, error) { + var credential model.CredentialData + credentialsKey := model.CredentialPath(username, params.Prefix) + msgs, paginator, err := s.ListKVPaged(ctx, (&credential).ProtoReflect().Type(), params, credentialsKey, false) + if err != nil { + return nil, nil, err + } + creds, err := model.ConvertCredDataList(s.secretStore, msgs) + if err != nil { + return nil, nil, err + } + return creds, paginator, nil +} + +func (s *AuthService) AttachPolicyToUser(ctx context.Context, policyDisplayName string, username string) error { + if _, err := s.GetUser(ctx, username); err != nil { + return err + } + if _, err := s.GetPolicy(ctx, policyDisplayName); err != nil { + return err + } + + policyKey := model.PolicyPath(policyDisplayName) + pu := model.UserPolicyPath(username, policyDisplayName) + + err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, pu, &kv.SecondaryIndex{PrimaryKey: policyKey}, nil) + if err != nil { + if errors.Is(err, kv.ErrPredicateFailed) { + err = auth.ErrAlreadyExists + } + return fmt.Errorf("policy attachment to user: (key %s): %w", pu, err) + } + return nil +} + +func (s *AuthService) DetachPolicyFromUserNoValidation(ctx context.Context, policyDisplayName, username string) error { + pu := model.UserPolicyPath(username, policyDisplayName) + err := s.store.Delete(ctx, []byte(model.PartitionKey), pu) + if err != nil { + return fmt.Errorf("detaching policy: (key %s): %w", pu, err) + } + return nil +} + +func (s *AuthService) DetachPolicyFromUser(ctx context.Context, policyDisplayName, username string) error { + if _, err := s.GetUser(ctx, username); err != nil { + return err + } + if _, err := s.GetPolicy(ctx, policyDisplayName); err != nil { + return err + } + return s.DetachPolicyFromUserNoValidation(ctx, policyDisplayName, username) +} + +func (s *AuthService) ListUserPolicies(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) { + var policy model.PolicyData + userPolicyKey := model.UserPolicyPath(username, params.Prefix) + + msgs, paginator, err := s.ListKVPaged(ctx, (&policy).ProtoReflect().Type(), params, userPolicyKey, true) + if msgs == nil { + return nil, paginator, err + } + return model.ConvertPolicyDataList(msgs), paginator, err +} + +func (s *AuthService) getEffectivePolicies(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) { + if _, err := s.GetUser(ctx, username); err != nil { + return nil, nil, err + } + + hasMoreUserPolicy := true + afterUserPolicy := "" + amount := auth.MaxPage + policiesSet := make(map[string]*model.Policy) + // get policies attracted to user + for hasMoreUserPolicy { + policies, userPaginator, err := s.ListUserPolicies(ctx, username, &model.PaginationParams{ + After: afterUserPolicy, + Amount: amount, + }) + if err != nil { + return nil, nil, fmt.Errorf("list user policies: %w", err) + } + for _, policy := range policies { + policiesSet[policy.DisplayName] = policy + } + afterUserPolicy = userPaginator.NextPageToken + hasMoreUserPolicy = userPaginator.NextPageToken != "" + } + + hasMoreGroup := true + afterGroup := "" + for hasMoreGroup { + // get membership groups to user + groups, groupPaginator, err := s.ListUserGroups(ctx, username, &model.PaginationParams{ + After: afterGroup, + Amount: amount, + }) + if err != nil { + return nil, nil, err + } + for _, group := range groups { + // get policies attracted to group + hasMoreGroupPolicy := true + afterGroupPolicy := "" + for hasMoreGroupPolicy { + groupPolicies, groupPoliciesPaginator, err := s.ListGroupPolicies(ctx, group.DisplayName, &model.PaginationParams{ + After: afterGroupPolicy, + Amount: amount, + }) + if err != nil { + return nil, nil, fmt.Errorf("list group policies: %w", err) + } + for _, policy := range groupPolicies { + policiesSet[policy.DisplayName] = policy + } + afterGroupPolicy = groupPoliciesPaginator.NextPageToken + hasMoreGroupPolicy = groupPoliciesPaginator.NextPageToken != "" + } + } + afterGroup = groupPaginator.NextPageToken + hasMoreGroup = groupPaginator.NextPageToken != "" + } + + if params.Amount < 0 || params.Amount > auth.MaxPage { + params.Amount = auth.MaxPage + } + + var policiesArr []string + for k := range policiesSet { + policiesArr = append(policiesArr, k) + } + sort.Strings(policiesArr) + + var resPolicies []*model.Policy + resPaginator := model.Paginator{Amount: 0, NextPageToken: ""} + for _, p := range policiesArr { + if p > params.After { + resPolicies = append(resPolicies, policiesSet[p]) + if len(resPolicies) == params.Amount { + resPaginator.NextPageToken = p + break + } + } + } + resPaginator.Amount = len(resPolicies) + return resPolicies, &resPaginator, nil +} + +func (s *AuthService) ListEffectivePolicies(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) { + return ListEffectivePolicies(ctx, username, params, s.getEffectivePolicies, s.cache) +} + +type effectivePoliciesGetter func(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) + +func ListEffectivePolicies(ctx context.Context, username string, params *model.PaginationParams, getEffectivePolicies effectivePoliciesGetter, cache auth.Cache) ([]*model.Policy, *model.Paginator, error) { + if params.Amount == -1 { + // read through the cache when requesting the full list + policies, err := cache.GetUserPolicies(username, func() ([]*model.Policy, error) { + policies, _, err := getEffectivePolicies(ctx, username, params) + return policies, err + }) + if err != nil { + return nil, nil, err + } + return policies, &model.Paginator{Amount: len(policies)}, nil + } + + return getEffectivePolicies(ctx, username, params) +} + +func (s *AuthService) ListGroupPolicies(ctx context.Context, groupDisplayName string, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) { + var policy model.PolicyData + groupPolicyKey := model.GroupPolicyPath(groupDisplayName, params.Prefix) + + msgs, paginator, err := s.ListKVPaged(ctx, (&policy).ProtoReflect().Type(), params, groupPolicyKey, true) + if msgs == nil { + return nil, paginator, err + } + return model.ConvertPolicyDataList(msgs), paginator, err +} + +func (s *AuthService) CreateGroup(ctx context.Context, group *model.Group) (*model.Group, error) { + if err := model.ValidateAuthEntityID(group.DisplayName); err != nil { + return nil, err + } + + groupKey := model.GroupPath(group.DisplayName) + err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, groupKey, model.ProtoFromGroup(group), nil) + if err != nil { + if errors.Is(err, kv.ErrPredicateFailed) { + err = auth.ErrAlreadyExists + } + return nil, fmt.Errorf("save group (groupKey %s): %w", groupKey, err) + } + retGroup := &model.Group{ + DisplayName: group.DisplayName, + ID: group.DisplayName, + CreatedAt: group.CreatedAt, + } + return retGroup, nil +} + +func (s *AuthService) DeleteGroup(ctx context.Context, groupID string) error { + if _, err := s.GetGroup(ctx, groupID); err != nil { + return err + } + + // delete user membership to group + usersKey := model.GroupUserPath(groupID, "") + it, err := kv.NewSecondaryIterator(ctx, s.store, (&model.UserData{}).ProtoReflect().Type(), model.PartitionKey, usersKey, []byte("")) + if err != nil { + return err + } + defer it.Close() + for it.Next() { + entry := it.Entry() + user := entry.Value.(*model.UserData) + if err = s.removeUserFromGroupNoValidation(ctx, user.Username, groupID); err != nil { + return err + } + } + if err = it.Err(); err != nil { + return err + } + + // delete policy attachment to group + policiesKey := model.GroupPolicyPath(groupID, "") + itr, err := kv.NewSecondaryIterator(ctx, s.store, (&model.PolicyData{}).ProtoReflect().Type(), model.PartitionKey, policiesKey, []byte("")) + if err != nil { + return err + } + defer it.Close() + for itr.Next() { + entry := itr.Entry() + policy := entry.Value.(*model.PolicyData) + if err = s.DetachPolicyFromGroupNoValidation(ctx, policy.DisplayName, groupID); err != nil { + return err + } + } + if err = itr.Err(); err != nil { + return err + } + + // delete group + groupPath := model.GroupPath(groupID) + err = s.store.Delete(ctx, []byte(model.PartitionKey), groupPath) + if err != nil { + return fmt.Errorf("delete user (auth.UserKey %s): %w", groupPath, err) + } + return nil +} + +func (s *AuthService) GetGroup(ctx context.Context, groupID string) (*model.Group, error) { + groupKey := model.GroupPath(groupID) + m := model.GroupData{} + _, err := kv.GetMsg(ctx, s.store, model.PartitionKey, groupKey, &m) + if err != nil { + if errors.Is(err, kv.ErrNotFound) { + err = auth.ErrNotFound + } + return nil, fmt.Errorf("%s: %w", groupID, err) + } + return model.GroupFromProto(&m), nil +} + +func (s *AuthService) ListGroups(ctx context.Context, params *model.PaginationParams) ([]*model.Group, *model.Paginator, error) { + var group model.GroupData + groupKey := model.GroupPath(params.Prefix) + + msgs, paginator, err := s.ListKVPaged(ctx, (&group).ProtoReflect().Type(), params, groupKey, false) + if msgs == nil { + return nil, paginator, err + } + return model.ConvertGroupDataList(msgs), paginator, err +} + +func (s *AuthService) AddUserToGroup(ctx context.Context, username, groupDisplayName string) error { + if _, err := s.GetUser(ctx, username); err != nil { + return err + } + if _, err := s.GetGroup(ctx, groupDisplayName); err != nil { + return err + } + + userKey := model.UserPath(username) + gu := model.GroupUserPath(groupDisplayName, username) + err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, gu, &kv.SecondaryIndex{PrimaryKey: userKey}, nil) + if err != nil { + if errors.Is(err, kv.ErrPredicateFailed) { + err = auth.ErrAlreadyExists + } + return fmt.Errorf("add user to group: (key %s): %w", gu, err) + } + return nil +} + +func (s *AuthService) removeUserFromGroupNoValidation(ctx context.Context, username, groupID string) error { + gu := model.GroupUserPath(groupID, username) + err := s.store.Delete(ctx, []byte(model.PartitionKey), gu) + if err != nil { + return fmt.Errorf("remove user from group: (key %s): %w", gu, err) + } + return nil +} + +func (s *AuthService) RemoveUserFromGroup(ctx context.Context, username, groupID string) error { + if _, err := s.GetUser(ctx, username); err != nil { + return err + } + if _, err := s.GetGroup(ctx, groupID); err != nil { + return err + } + return s.removeUserFromGroupNoValidation(ctx, username, groupID) +} + +func (s *AuthService) ListUserGroups(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Group, *model.Paginator, error) { + if _, err := s.GetUser(ctx, username); err != nil { + return nil, nil, err + } + if params.Amount < 0 || params.Amount > auth.MaxPage { + params.Amount = auth.MaxPage + } + + hasMoreGroups := true + afterGroup := params.After + var userGroups []*model.Group + resPaginator := model.Paginator{Amount: 0, NextPageToken: ""} + for hasMoreGroups && len(userGroups) <= params.Amount { + groups, paginator, err := s.ListGroups(ctx, &model.PaginationParams{Prefix: params.Prefix, After: afterGroup, Amount: auth.MaxPage}) + if err != nil { + return nil, nil, err + } + for _, group := range groups { + path := model.GroupUserPath(group.DisplayName, username) + m := kv.SecondaryIndex{} + _, err := kv.GetMsg(ctx, s.store, model.PartitionKey, path, &m) + if err != nil && !errors.Is(err, kv.ErrNotFound) { + return nil, nil, err + } + if err == nil { + appendGroup := &model.Group{ + DisplayName: group.DisplayName, + ID: group.DisplayName, + CreatedAt: group.CreatedAt, + } + userGroups = append(userGroups, appendGroup) + } + if len(userGroups) == params.Amount { + resPaginator.NextPageToken = group.DisplayName + resPaginator.Amount = len(userGroups) + return userGroups, &resPaginator, nil + } + } + hasMoreGroups = paginator.NextPageToken != "" + afterGroup = paginator.NextPageToken + } + resPaginator.Amount = len(userGroups) + return userGroups, &resPaginator, nil +} + +func (s *AuthService) ListGroupUsers(ctx context.Context, groupID string, params *model.PaginationParams) ([]*model.User, *model.Paginator, error) { + if _, err := s.GetGroup(ctx, groupID); err != nil { + return nil, nil, err + } + var policy model.UserData + userGroupKey := model.GroupUserPath(groupID, params.Prefix) + + msgs, paginator, err := s.ListKVPaged(ctx, (&policy).ProtoReflect().Type(), params, userGroupKey, true) + if msgs == nil { + return nil, paginator, err + } + return model.ConvertUsersDataList(msgs), paginator, err +} + +func ValidatePolicy(policy *model.Policy) error { + if err := model.ValidateAuthEntityID(policy.DisplayName); err != nil { + return err + } + for _, stmt := range policy.Statement { + for _, action := range stmt.Action { + if err := model.ValidateActionName(action); err != nil { + return err + } + } + if err := model.ValidateArn(stmt.Resource); err != nil { + return err + } + if err := model.ValidateStatementEffect(stmt.Effect); err != nil { + return err + } + } + return nil +} + +func (s *AuthService) WritePolicy(ctx context.Context, policy *model.Policy, update bool) error { + if err := ValidatePolicy(policy); err != nil { + return err + } + policyKey := model.PolicyPath(policy.DisplayName) + m := model.ProtoFromPolicy(policy) + + if update { // update policy only if it already exists + err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, policyKey, m, kv.PrecondConditionalExists) + if err != nil { + if errors.Is(err, kv.ErrPredicateFailed) { + err = auth.ErrNotFound + } + return err + } + return nil + } + + // create policy only if it does not exist + err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, policyKey, m, nil) + if err != nil { + if errors.Is(err, kv.ErrPredicateFailed) { + err = auth.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *AuthService) GetPolicy(ctx context.Context, policyDisplayName string) (*model.Policy, error) { + policyKey := model.PolicyPath(policyDisplayName) + p := model.PolicyData{} + _, err := kv.GetMsg(ctx, s.store, model.PartitionKey, policyKey, &p) + if err != nil { + if errors.Is(err, kv.ErrNotFound) { + err = auth.ErrNotFound + } + return nil, fmt.Errorf("%s: %w", policyDisplayName, err) + } + return model.PolicyFromProto(&p), nil +} + +func (s *AuthService) DeletePolicy(ctx context.Context, policyDisplayName string) error { + if _, err := s.GetPolicy(ctx, policyDisplayName); err != nil { + return err + } + policyPath := model.PolicyPath(policyDisplayName) + + // delete policy attachment to user + usersKey := model.UserPath("") + it, err := kv.NewPrimaryIterator(ctx, s.store, (&model.UserData{}).ProtoReflect().Type(), model.PartitionKey, usersKey, kv.IteratorOptionsAfter([]byte(""))) + if err != nil { + return err + } + defer it.Close() + for it.Next() { + entry := it.Entry() + user := entry.Value.(*model.UserData) + if err = s.DetachPolicyFromUserNoValidation(ctx, policyDisplayName, user.Username); err != nil { + return err + } + } + + // delete policy attachment to group + groupKey := model.GroupPath("") + it, err = kv.NewPrimaryIterator(ctx, s.store, (&model.GroupData{}).ProtoReflect().Type(), model.PartitionKey, groupKey, kv.IteratorOptionsAfter([]byte(""))) + if err != nil { + return err + } + defer it.Close() + for it.Next() { + entry := it.Entry() + group := entry.Value.(*model.GroupData) + if err = s.DetachPolicyFromGroupNoValidation(ctx, policyDisplayName, group.DisplayName); err != nil { + return err + } + } + + // delete policy + err = s.store.Delete(ctx, []byte(model.PartitionKey), policyPath) + if err != nil { + return fmt.Errorf("delete policy (policyKey %s): %w", policyPath, err) + } + return nil +} + +func (s *AuthService) ListPolicies(ctx context.Context, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) { + var policy model.PolicyData + policyKey := model.PolicyPath(params.Prefix) + + msgs, paginator, err := s.ListKVPaged(ctx, (&policy).ProtoReflect().Type(), params, policyKey, false) + if msgs == nil { + return nil, paginator, err + } + return model.ConvertPolicyDataList(msgs), paginator, err +} + +func (s *AuthService) CreateCredentials(ctx context.Context, username string) (*model.Credential, error) { + accessKeyID := keys.GenAccessKeyID() + secretAccessKey := keys.GenSecretAccessKey() + return s.AddCredentials(ctx, username, accessKeyID, secretAccessKey) +} + +func (s *AuthService) AddCredentials(ctx context.Context, username, accessKeyID, secretAccessKey string) (*model.Credential, error) { + if !IsValidAccessKeyID(accessKeyID) { + return nil, auth.ErrInvalidAccessKeyID + } + if len(secretAccessKey) == 0 { + return nil, auth.ErrInvalidSecretAccessKey + } + now := time.Now() + encryptedKey, err := model.EncryptSecret(s.secretStore, secretAccessKey) + if err != nil { + return nil, err + } + user, err := s.GetUser(ctx, username) + if err != nil { + return nil, err + } + + c := &model.Credential{ + BaseCredential: model.BaseCredential{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + SecretAccessKeyEncryptedBytes: encryptedKey, + IssuedDate: now, + }, + Username: user.Username, + } + credentialsKey := model.CredentialPath(user.Username, c.AccessKeyID) + err = kv.SetMsgIf(ctx, s.store, model.PartitionKey, credentialsKey, model.ProtoFromCredential(c), nil) + if err != nil { + if errors.Is(err, kv.ErrPredicateFailed) { + err = auth.ErrAlreadyExists + } + return nil, fmt.Errorf("save credentials (credentialsKey %s): %w", credentialsKey, err) + } + + return c, nil +} + +func IsValidAccessKeyID(key string) bool { + l := len(key) + return l >= 3 && l <= 20 +} + +func (s *AuthService) DeleteCredentials(ctx context.Context, username, accessKeyID string) error { + if _, err := s.GetUser(ctx, username); err != nil { + return err + } + if _, err := s.GetCredentials(ctx, accessKeyID); err != nil { + return err + } + + credPath := model.CredentialPath(username, accessKeyID) + err := s.store.Delete(ctx, []byte(model.PartitionKey), credPath) + if err != nil { + return fmt.Errorf("delete credentials (credentialsKey %s): %w", credPath, err) + } + return nil +} + +func (s *AuthService) AttachPolicyToGroup(ctx context.Context, policyDisplayName, groupDisplayName string) error { + if _, err := s.GetGroup(ctx, groupDisplayName); err != nil { + return err + } + if _, err := s.GetPolicy(ctx, policyDisplayName); err != nil { + return err + } + + policyKey := model.PolicyPath(policyDisplayName) + pg := model.GroupPolicyPath(groupDisplayName, policyDisplayName) + + err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, pg, &kv.SecondaryIndex{PrimaryKey: policyKey}, nil) + if err != nil { + if errors.Is(err, kv.ErrPredicateFailed) { + err = auth.ErrAlreadyExists + } + return fmt.Errorf("policy attachment to group: (key %s): %w", pg, err) + } + return nil +} + +func (s *AuthService) DetachPolicyFromGroupNoValidation(ctx context.Context, policyDisplayName, groupDisplayName string) error { + pg := model.GroupPolicyPath(groupDisplayName, policyDisplayName) + err := s.store.Delete(ctx, []byte(model.PartitionKey), pg) + if err != nil { + return fmt.Errorf("policy detachment to group: (key %s): %w", pg, err) + } + return nil +} + +func (s *AuthService) DetachPolicyFromGroup(ctx context.Context, policyDisplayName, groupDisplayName string) error { + if _, err := s.GetGroup(ctx, groupDisplayName); err != nil { + return err + } + if _, err := s.GetPolicy(ctx, policyDisplayName); err != nil { + return err + } + return s.DetachPolicyFromGroupNoValidation(ctx, policyDisplayName, groupDisplayName) +} + +func (s *AuthService) GetCredentialsForUser(ctx context.Context, username, accessKeyID string) (*model.Credential, error) { + if _, err := s.GetUser(ctx, username); err != nil { + return nil, err + } + credentialsKey := model.CredentialPath(username, accessKeyID) + m := model.CredentialData{} + _, err := kv.GetMsg(ctx, s.store, model.PartitionKey, credentialsKey, &m) + if err != nil { + if errors.Is(err, kv.ErrNotFound) { + err = auth.ErrNotFound + } + return nil, err + } + + c, err := model.CredentialFromProto(s.secretStore, &m) + if err != nil { + return nil, err + } + c.SecretAccessKey = "" + return c, nil +} + +func (s *AuthService) GetCredentials(ctx context.Context, accessKeyID string) (*model.Credential, error) { + return s.cache.GetCredential(accessKeyID, func() (*model.Credential, error) { + m := &model.UserData{} + itr, err := kv.NewPrimaryIterator(ctx, s.store, m.ProtoReflect().Type(), model.PartitionKey, model.UserPath(""), kv.IteratorOptionsAfter([]byte(""))) + if err != nil { + return nil, fmt.Errorf("scan users: %w", err) + } + defer itr.Close() + + for itr.Next() { + entry := itr.Entry() + user, ok := entry.Value.(*model.UserData) + if !ok { + return nil, fmt.Errorf("failed to cast: %w", err) + } + c := model.CredentialData{} + credentialsKey := model.CredentialPath(user.Username, accessKeyID) + _, err := kv.GetMsg(ctx, s.store, model.PartitionKey, credentialsKey, &c) + if err != nil && !errors.Is(err, kv.ErrNotFound) { + return nil, err + } + if err == nil { + return model.CredentialFromProto(s.secretStore, &c) + } + } + if err = itr.Err(); err != nil { + return nil, err + } + return nil, fmt.Errorf("credentials %w", auth.ErrNotFound) + }) +} + +func (s *AuthService) Authorize(ctx context.Context, req *auth.AuthorizationRequest) (*auth.AuthorizationResponse, error) { + policies, _, err := s.ListEffectivePolicies(ctx, req.Username, &model.PaginationParams{ + After: "", // all + Amount: -1, // all + }) + if err != nil { + return nil, err + } + + allowed := auth.CheckPermissions(ctx, req.RequiredPermissions, req.Username, policies) + + if allowed != auth.CheckAllow { + return &auth.AuthorizationResponse{ + Allowed: false, + Error: auth.ErrInsufficientPermissions, + }, nil + } + + // we're allowed! + return &auth.AuthorizationResponse{Allowed: true}, nil +} + +func (s *AuthService) ClaimTokenIDOnce(ctx context.Context, tokenID string, expiresAt int64) error { + return claimTokenIDOnce(ctx, tokenID, expiresAt, s.markTokenSingleUse) +} + +func claimTokenIDOnce(ctx context.Context, tokenID string, expiresAt int64, markTokenSingleUse func(context.Context, string, time.Time) (bool, error)) error { + tokenExpiresAt := time.Unix(expiresAt, 0) + canUseToken, err := markTokenSingleUse(ctx, tokenID, tokenExpiresAt) + if err != nil { + return err + } + if !canUseToken { + return auth.ErrInvalidToken + } + return nil +} + +func (s *AuthService) IsExternalPrincipalsEnabled(_ context.Context) bool { + return false +} + +func (s *AuthService) CreateUserExternalPrincipal(_ context.Context, _, _ string) error { + return auth.ErrNotImplemented +} + +func (s *AuthService) DeleteUserExternalPrincipal(_ context.Context, _, _ string) error { + return auth.ErrNotImplemented +} + +func (s *AuthService) GetExternalPrincipal(_ context.Context, _ string) (*model.ExternalPrincipal, error) { + return nil, auth.ErrNotImplemented +} + +func (s *AuthService) ListUserExternalPrincipals(_ context.Context, _ string, _ *model.PaginationParams) ([]*model.ExternalPrincipal, *model.Paginator, error) { + return nil, nil, auth.ErrNotImplemented +} + +// markTokenSingleUse returns true if token is valid for single use +func (s *AuthService) markTokenSingleUse(ctx context.Context, tokenID string, tokenExpiresAt time.Time) (bool, error) { + tokenPath := model.ExpiredTokenPath(tokenID) + m := model.TokenData{TokenId: tokenID, ExpiredAt: timestamppb.New(tokenExpiresAt)} + err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, tokenPath, &m, nil) + if err != nil { + if errors.Is(err, kv.ErrPredicateFailed) { + return false, nil + } + return false, err + } + + if err := s.deleteTokens(ctx); err != nil { + s.log.WithError(err).Error("Failed to delete expired tokens") + } + return true, nil +} + +func (s *AuthService) deleteTokens(ctx context.Context) error { + it, err := kv.NewPrimaryIterator(ctx, s.store, (&model.TokenData{}).ProtoReflect().Type(), model.PartitionKey, model.ExpiredTokensPath(), kv.IteratorOptionsFrom([]byte(""))) + if err != nil { + return err + } + defer it.Close() + + deletionCutoff := time.Now() + for it.Next() { + msg := it.Entry() + if msg == nil { + return fmt.Errorf("nil token: %w", auth.ErrInvalidToken) + } + token, ok := msg.Value.(*model.TokenData) + if token == nil || !ok { + return fmt.Errorf("wrong token type: %w", auth.ErrInvalidToken) + } + + if token.ExpiredAt.AsTime().After(deletionCutoff) { + // reached a token with expiry greater than the cutoff, + // tokens are k-ordered (xid) hence we'll not find more expired tokens + return nil + } + + tokenPath := model.ExpiredTokenPath(token.TokenId) + if err := s.store.Delete(ctx, []byte(model.PartitionKey), tokenPath); err != nil { + return fmt.Errorf("deleting token: %w", err) + } + } + + return it.Err() +} diff --git a/contrib/auth/acl/service_test.go b/contrib/auth/acl/service_test.go new file mode 100644 index 00000000000..7b42cf6ec89 --- /dev/null +++ b/contrib/auth/acl/service_test.go @@ -0,0 +1,667 @@ +package acl_test + +import ( + "context" + "errors" + "fmt" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + authacl "github.com/treeverse/lakefs/contrib/auth/acl" + authtestutil "github.com/treeverse/lakefs/contrib/auth/acl/testutil" + "github.com/treeverse/lakefs/pkg/auth" + "github.com/treeverse/lakefs/pkg/auth/acl" + "github.com/treeverse/lakefs/pkg/auth/crypt" + "github.com/treeverse/lakefs/pkg/auth/model" + authparams "github.com/treeverse/lakefs/pkg/auth/params" + "github.com/treeverse/lakefs/pkg/kv/kvtest" + "github.com/treeverse/lakefs/pkg/logging" + "github.com/treeverse/lakefs/pkg/permissions" +) + +const creationDate = 12345678 + +var ( + someSecret = []byte("some secret") + + userPoliciesForTesting = []*model.Policy{ + { + Statement: model.Statements{ + { + Action: []string{"auth:DeleteUser"}, + Resource: "arn:lakefs:auth:::user/foobar", + Effect: model.StatementEffectAllow, + }, + { + Action: []string{"auth:*"}, + Resource: "*", + Effect: model.StatementEffectDeny, + }, + }, + }, + } +) + +func userWithPolicies(t testing.TB, s auth.Service, policies []*model.Policy) string { + t.Helper() + ctx := context.Background() + userName := uuid.New().String() + _, err := s.CreateUser(ctx, &model.User{ + Username: userName, + }) + if err != nil { + t.Fatal(err) + } + for _, policy := range policies { + if policy.DisplayName == "" { + policy.DisplayName = model.CreateID() + } + err := s.WritePolicy(ctx, policy, false) + if err != nil { + t.Fatal(err) + } + err = s.AttachPolicyToUser(ctx, policy.DisplayName, userName) + if err != nil { + t.Fatal(err) + } + } + + return userName +} + +func userWithACLs(t testing.TB, s auth.Service, a model.ACL) string { + t.Helper() + statements, err := acl.ACLToStatement(a) + if err != nil { + t.Fatal("ACLToStatement: ", err) + } + creationTime := time.Unix(creationDate, 0) + + policy := &model.Policy{ + CreatedAt: creationTime, + DisplayName: model.CreateID(), + Statement: statements, + ACL: a, + } + return userWithPolicies(t, s, []*model.Policy{policy}) +} + +// createInitialDataSet - +// Creates K users with 2 credentials each, L groups and M policies +// Add all users to all groups +// Attach M/2 of the policies to all K users and the other M-M/2 policies to all L groups +func createInitialDataSet(t *testing.T, ctx context.Context, svc auth.Service, userNames, groupNames, policyNames []string) { + for _, userName := range userNames { + if _, err := svc.CreateUser(ctx, &model.User{Username: userName}); err != nil { + t.Fatalf("CreateUser(%s): %s", userName, err) + } + for i := 0; i < 2; i++ { + _, err := svc.CreateCredentials(ctx, userName) + if err != nil { + t.Errorf("CreateCredentials(%s): %s", userName, err) + } + } + } + + for _, groupName := range groupNames { + if _, err := svc.CreateGroup(ctx, &model.Group{DisplayName: groupName}); err != nil { + t.Fatalf("CreateGroup(%s): %s", groupName, err) + } + for _, userName := range userNames { + if err := svc.AddUserToGroup(ctx, userName, groupName); err != nil { + t.Fatalf("AddUserToGroup(%s, %s): %s", userName, groupName, err) + } + } + } + + numPolicies := len(policyNames) + for i, policyName := range policyNames { + if err := svc.WritePolicy(ctx, &model.Policy{DisplayName: policyName, Statement: userPoliciesForTesting[0].Statement}, false); err != nil { + t.Fatalf("WritePolicy(%s): %s", policyName, err) + } + if i < numPolicies/2 { + for _, userName := range userNames { + if err := svc.AttachPolicyToUser(ctx, policyName, userName); err != nil { + t.Fatalf("AttachPolicyToUser(%s, %s): %s", policyName, userName, err) + } + } + } else { + for _, groupName := range groupNames { + if err := svc.AttachPolicyToGroup(ctx, policyName, groupName); err != nil { + t.Fatalf("AttachPolicyToGroup(%s, %s): %s", policyName, groupName, err) + } + } + } + } +} + +func describeAllowed(allowed bool) string { + if allowed { + return "allowed" + } + return "forbidden" +} + +func TestAuthService_ListUsers_PagedWithPrefix(t *testing.T) { + ctx := context.Background() + kvStore := kvtest.GetStore(ctx, t) + s := authacl.NewAuthService(kvStore, crypt.NewSecretStore(someSecret), authparams.ServiceCache{ + Enabled: false, + }, logging.ContextUnavailable()) + + users := []string{"bar", "barn", "baz", "foo", "foobar", "foobaz"} + for _, u := range users { + user := model.User{Username: u} + if _, err := s.CreateUser(ctx, &user); err != nil { + t.Fatalf("create user: %s", err) + } + } + + sizes := []int{10, 3, 2} + prefixes := []string{"b", "ba", "bar", "f", "foo", "foob", "foobar"} + for _, size := range sizes { + for _, p := range prefixes { + t.Run(fmt.Sprintf("Size:%d;Prefix:%s", size, p), func(t *testing.T) { + // Only count the correct number of entries were + // returned; values are tested below. + got := 0 + after := "" + for { + value, paginator, err := s.ListUsers(ctx, &model.PaginationParams{Amount: size, Prefix: p, After: after}) + if err != nil { + t.Fatal(err) + } + got += len(value) + after = paginator.NextPageToken + if after == "" { + break + } + } + // Verify got the right number of users + count := 0 + for _, u := range users { + if strings.HasPrefix(u, p) { + count++ + } + } + if got != count { + t.Errorf("Got %d users when expecting %d", got, count) + } + }) + } + } +} + +func TestAuthService_ListPaged(t *testing.T) { + ctx := context.Background() + kvStore := kvtest.GetStore(ctx, t) + s := authacl.NewAuthService(kvStore, crypt.NewSecretStore(someSecret), authparams.ServiceCache{ + Enabled: false, + }, logging.ContextUnavailable()) + + const chars = "abcdefghijklmnopqrstuvwxyz" + for _, c := range chars { + user := model.User{Username: string(c)} + if _, err := s.CreateUser(ctx, &user); err != nil { + t.Fatalf("create user: %s", err) + } + } + var userData model.UserData + + for size := 0; size <= len(chars)+1; size++ { + t.Run(fmt.Sprintf("PageSize%d", size), func(t *testing.T) { + pagination := &model.PaginationParams{Amount: size} + if size == 0 { // Overload to mean "don't paginate" + pagination.Amount = -1 + } + got := "" + for { + values, paginator, err := s.ListKVPaged(ctx, (&userData).ProtoReflect().Type(), pagination, model.UserPath(""), false) + if err != nil { + t.Errorf("ListPaged: %s", err) + break + } + if values == nil { + t.Fatalf("expected values for pagination %+v but got just paginator %+v", pagination, paginator) + } + letters := model.ConvertUsersDataList(values) + for _, c := range letters { + got = got + c.Username + } + if paginator.NextPageToken == "" { + if size > 0 && len(letters) > size { + t.Errorf("expected at most %d entries in last page but got %d", size, len(letters)) + } + break + } + if len(letters) != size { + t.Errorf("expected %d entries in page but got %d", size, len(letters)) + } + pagination.After = paginator.NextPageToken + } + if got != chars { + t.Errorf("Expected to read back \"%s\" but got \"%s\"", chars, got) + } + }) + } +} + +func BenchmarkKVAuthService_ListEffectivePolicies(b *testing.B) { + // setup user with policies for benchmark + ctx := context.Background() + kvStore := kvtest.GetStore(ctx, b) + + serviceWithoutCache := authacl.NewAuthService(kvStore, crypt.NewSecretStore(someSecret), authparams.ServiceCache{ + Enabled: false, + }, logging.ContextUnavailable()) + serviceWithCache := authacl.NewAuthService(kvStore, crypt.NewSecretStore(someSecret), authparams.ServiceCache{ + Enabled: true, + Size: 1024, + TTL: 20 * time.Second, + Jitter: 3 * time.Second, + }, logging.ContextUnavailable()) + serviceWithCacheLowTTL := authacl.NewAuthService(kvStore, crypt.NewSecretStore(someSecret), authparams.ServiceCache{ + Enabled: true, + Size: 1024, + TTL: 1 * time.Millisecond, + Jitter: 1 * time.Millisecond, + }, logging.ContextUnavailable()) + userName := userWithPolicies(b, serviceWithoutCache, userPoliciesForTesting) + + b.Run("without_cache", func(b *testing.B) { + benchmarkKVListEffectivePolicies(b, serviceWithoutCache, userName) + }) + b.Run("with_cache", func(b *testing.B) { + benchmarkKVListEffectivePolicies(b, serviceWithCache, userName) + }) + b.Run("without_cache_low_ttl", func(b *testing.B) { + benchmarkKVListEffectivePolicies(b, serviceWithCacheLowTTL, userName) + }) +} + +func benchmarkKVListEffectivePolicies(b *testing.B, s *authacl.AuthService, userName string) { + b.ResetTimer() + ctx := context.Background() + for n := 0; n < b.N; n++ { + _, _, err := s.ListEffectivePolicies(ctx, userName, &model.PaginationParams{Amount: -1}) + if err != nil { + b.Fatal("Failed to list effective policies", err) + } + } +} + +func TestAuthService_DeleteUserWithRelations(t *testing.T) { + userNames := []string{"first", "second"} + groupNames := []string{"groupA", "groupB"} + policyNames := []string{"policy01", "policy02", "policy03", "policy04"} + + ctx := context.Background() + authService, _ := authtestutil.SetupService(t, ctx, someSecret) + + // create initial data set and verify users groups and policies are created and related as expected + createInitialDataSet(t, ctx, authService, userNames, groupNames, policyNames) + users, _, err := authService.ListUsers(ctx, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, users) + require.Equal(t, len(userNames), len(users)) + for _, userName := range userNames { + user, err := authService.GetUser(ctx, userName) + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, userName, user.Username) + + groups, _, err := authService.ListUserGroups(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, groups) + require.Equal(t, len(groupNames), len(groups)) + + policies, _, err := authService.ListUserPolicies(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)/2, len(policies)) + + policies, _, err = authService.ListEffectivePolicies(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames), len(policies)) + } + for _, groupName := range groupNames { + users, _, err := authService.ListGroupUsers(ctx, groupName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, users) + require.Equal(t, len(userNames), len(users)) + } + + // delete a user + err = authService.DeleteUser(ctx, userNames[0]) + require.NoError(t, err) + + // verify user does not exist + user, err := authService.GetUser(ctx, userNames[0]) + require.Error(t, err) + require.Nil(t, user) + + // verify user is removed from all lists and relations + users, _, err = authService.ListUsers(ctx, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, users) + require.Equal(t, len(userNames)-1, len(users)) + + for _, groupName := range groupNames { + users, _, err := authService.ListGroupUsers(ctx, groupName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, users) + require.Equal(t, len(userNames)-1, len(users)) + for _, user := range users { + require.NotEqual(t, userNames[0], user.Username) + } + } +} + +func TestAuthService_DeleteGroupWithRelations(t *testing.T) { + userNames := []string{"first", "second", "third"} + groupNames := []string{"groupA", "groupB", "groupC"} + policyNames := []string{"policy01", "policy02", "policy03", "policy04"} + + ctx := context.Background() + authService, _ := authtestutil.SetupService(t, ctx, someSecret) + + // create initial data set and verify users groups and policies are created and related as expected + createInitialDataSet(t, ctx, authService, userNames, groupNames, policyNames) + groups, _, err := authService.ListGroups(ctx, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, groups) + require.Equal(t, len(groupNames), len(groups)) + for _, userName := range userNames { + user, err := authService.GetUser(ctx, userName) + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, userName, user.Username) + + groups, _, err := authService.ListUserGroups(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, groups) + require.Equal(t, len(groupNames), len(groups)) + + policies, _, err := authService.ListUserPolicies(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)/2, len(policies)) + + policies, _, err = authService.ListEffectivePolicies(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames), len(policies)) + } + for _, groupName := range groupNames { + group, err := authService.GetGroup(ctx, groupName) + require.NoError(t, err) + require.NotNil(t, group) + require.Equal(t, groupName, group.DisplayName) + + users, _, err := authService.ListGroupUsers(ctx, groupName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, users) + require.Equal(t, len(userNames), len(users)) + + policies, _, err := authService.ListGroupPolicies(ctx, groupName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)-len(policyNames)/2, len(policies)) + } + for _, userName := range userNames { + groups, _, err := authService.ListUserGroups(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, groups) + require.Equal(t, len(groupNames), len(groups)) + } + + // delete a group + err = authService.DeleteGroup(ctx, groupNames[1]) + require.NoError(t, err) + + // verify group does not exist + group, err := authService.GetGroup(ctx, groupNames[1]) + require.Error(t, err) + require.Nil(t, group) + + // verify group is removed from all lists and relations + groups, _, err = authService.ListGroups(ctx, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, groups) + require.Equal(t, len(groupNames)-1, len(groups)) + + for _, userName := range userNames { + groups, _, err := authService.ListUserGroups(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, groups) + require.Equal(t, len(userNames)-1, len(groups)) + for _, group := range groups { + require.NotEqual(t, groupNames[1], group.DisplayName) + } + } +} + +func TestAuthService_DeletePoliciesWithRelations(t *testing.T) { + userNames := []string{"first", "second", "third"} + groupNames := []string{"groupA", "groupB", "groupC"} + policyNames := []string{"policy01", "policy02", "policy03", "policy04"} + + ctx := context.Background() + authService, _ := authtestutil.SetupService(t, ctx, someSecret) + + // create initial data set and verify users groups and policies are created and related as expected + createInitialDataSet(t, ctx, authService, userNames, groupNames, policyNames) + policies, _, err := authService.ListPolicies(ctx, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames), len(policies)) + for _, policyName := range policyNames { + policy, err := authService.GetPolicy(ctx, policyName) + require.NoError(t, err) + require.NotNil(t, policy) + require.Equal(t, policyName, policy.DisplayName) + } + + for _, groupName := range groupNames { + policies, _, err := authService.ListGroupPolicies(ctx, groupName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)-len(policyNames)/2, len(policies)) + } + for _, userName := range userNames { + policies, _, err := authService.ListUserPolicies(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)/2, len(policies)) + + policies, _, err = authService.ListEffectivePolicies(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames), len(policies)) + } + + // delete a user policy (beginning of the name list) + err = authService.DeletePolicy(ctx, policyNames[0]) + require.NoError(t, err) + + // verify policy does not exist + policy, err := authService.GetPolicy(ctx, policyNames[0]) + require.Error(t, err) + require.Nil(t, policy) + + // verify policy is removed from all lists and relations + policies, _, err = authService.ListPolicies(ctx, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)-1, len(policies)) + + for _, userName := range userNames { + policies, _, err := authService.ListUserPolicies(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)/2-1, len(policies)) + for _, policy := range policies { + require.NotEqual(t, policyNames[0], policy.DisplayName) + } + + policies, _, err = authService.ListEffectivePolicies(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)-1, len(policies)) + for _, policy := range policies { + require.NotEqual(t, policyNames[0], policy.DisplayName) + } + } + + for _, groupName := range groupNames { + policies, _, err := authService.ListGroupPolicies(ctx, groupName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)-len(policyNames)/2, len(policies)) + for _, policy := range policies { + require.NotEqual(t, policyNames[0], policy.DisplayName) + } + } + + // delete a group policy (end of the names list) + err = authService.DeletePolicy(ctx, policyNames[len(policyNames)-1]) + require.NoError(t, err) + + // verify policy does not exist + policy, err = authService.GetPolicy(ctx, policyNames[len(policyNames)-1]) + require.Error(t, err) + require.Nil(t, policy) + + // verify policy is removed from all lists and relations + policies, _, err = authService.ListPolicies(ctx, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)-2, len(policies)) + + for _, userName := range userNames { + policies, _, err := authService.ListUserPolicies(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)/2-1, len(policies)) + for _, policy := range policies { + require.NotEqual(t, policyNames[len(policyNames)-1], policy.DisplayName) + } + + policies, _, err = authService.ListEffectivePolicies(ctx, userName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)-2, len(policies)) + for _, policy := range policies { + require.NotEqual(t, policyNames[len(policyNames)-1], policy.DisplayName) + } + } + + for _, groupName := range groupNames { + policies, _, err := authService.ListGroupPolicies(ctx, groupName, &model.PaginationParams{Amount: 100}) + require.NoError(t, err) + require.NotNil(t, policies) + require.Equal(t, len(policyNames)-len(policyNames)/2-1, len(policies)) + for _, policy := range policies { + require.NotEqual(t, policyNames[len(policyNames)-1], policy.DisplayName) + } + } +} + +func TestACL(t *testing.T) { + hierarchy := []model.ACLPermission{acl.ReadPermission, acl.WritePermission, acl.SuperPermission, acl.AdminPermission} + + type PermissionFrom map[model.ACLPermission][]permissions.Permission + type TestCase struct { + // Name is an identifier for this test case. + Name string + // ACL is the ACL to test. ACL.Permission will be tested + // with each of the hierarchies. + ACL model.ACL + // PermissionFrom holds permissions that must hold starting + // at the ACLPermission key in the hierarchy. + PermissionFrom PermissionFrom + } + + tests := []TestCase{ + { + Name: "all repos", + ACL: model.ACL{}, + PermissionFrom: PermissionFrom{ + acl.ReadPermission: []permissions.Permission{ + {Action: permissions.ReadObjectAction, Resource: permissions.ObjectArn("foo", "some/path")}, + {Action: permissions.ListObjectsAction, Resource: permissions.ObjectArn("foo", "some/path")}, + {Action: permissions.ListObjectsAction, Resource: permissions.ObjectArn("quux", "")}, + {Action: permissions.CreateCredentialsAction, Resource: permissions.UserArn("${user}")}, + }, + acl.WritePermission: []permissions.Permission{ + {Action: permissions.WriteObjectAction, Resource: permissions.ObjectArn("foo", "some/path")}, + {Action: permissions.DeleteObjectAction, Resource: permissions.ObjectArn("foo", "some/path")}, + {Action: permissions.CreateBranchAction, Resource: permissions.BranchArn("foo", "twig")}, + {Action: permissions.CreateCommitAction, Resource: permissions.BranchArn("foo", "twig")}, + {Action: permissions.CreateMetaRangeAction, Resource: permissions.RepoArn("foo")}, + }, + acl.SuperPermission: []permissions.Permission{ + {Action: permissions.AttachStorageNamespaceAction, Resource: permissions.StorageNamespace("storage://bucket/path")}, + {Action: permissions.ImportFromStorageAction, Resource: permissions.StorageNamespace("storage://bucket/path")}, + {Action: permissions.ImportCancelAction, Resource: permissions.BranchArn("foo", "twig")}, + }, + acl.AdminPermission: []permissions.Permission{ + {Action: permissions.CreateUserAction, Resource: permissions.UserArn("you")}, + }, + }, + }, + } + + ctx := context.Background() + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + s, _ := authtestutil.SetupService(t, ctx, someSecret) + userID := make(map[model.ACLPermission]string, len(hierarchy)) + for _, aclPermission := range hierarchy { + tt.ACL.Permission = aclPermission + userID[aclPermission] = userWithACLs(t, s, tt.ACL) + } + tt.ACL.Permission = "" + + for from, pp := range tt.PermissionFrom { + for _, p := range pp { + t.Run(fmt.Sprintf("%+v", p), func(t *testing.T) { + n := permissions.Node{Permission: p} + allow := false + for _, aclPermission := range hierarchy { + t.Run(string(aclPermission), func(t *testing.T) { + if aclPermission == from { + allow = true + } + origResource := n.Permission.Resource + defer func() { + n.Permission.Resource = origResource + }() + n.Permission.Resource = strings.ReplaceAll(n.Permission.Resource, "${user}", userID[aclPermission]) + + r, err := s.Authorize(ctx, &auth.AuthorizationRequest{ + Username: userID[aclPermission], + RequiredPermissions: n, + }) + if err != nil { + t.Errorf("Authorize failed: %v", err) + } + if (allow && r.Error != nil) || !allow && !errors.Is(r.Error, auth.ErrInsufficientPermissions) { + t.Errorf("Authorization response error: %v", err) + } + if r.Allowed != allow { + t.Errorf("%s but expected %s", describeAllowed(r.Allowed), describeAllowed(allow)) + } + }) + } + }) + } + } + }) + } +} diff --git a/pkg/auth/testutil/service.go b/contrib/auth/acl/testutil/service.go similarity index 72% rename from pkg/auth/testutil/service.go rename to contrib/auth/acl/testutil/service.go index 245645e7fb3..1cebd2630f0 100644 --- a/pkg/auth/testutil/service.go +++ b/contrib/auth/acl/testutil/service.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/treeverse/lakefs/pkg/auth" + "github.com/treeverse/lakefs/contrib/auth/acl" "github.com/treeverse/lakefs/pkg/auth/crypt" authparams "github.com/treeverse/lakefs/pkg/auth/params" "github.com/treeverse/lakefs/pkg/kv" @@ -12,10 +12,10 @@ import ( "github.com/treeverse/lakefs/pkg/logging" ) -func SetupService(t *testing.T, ctx context.Context, secret []byte) (*auth.AuthService, kv.Store) { +func SetupService(t *testing.T, ctx context.Context, secret []byte) (*acl.AuthService, kv.Store) { t.Helper() kvStore := kvtest.GetStore(ctx, t) - return auth.NewAuthService(kvStore, crypt.NewSecretStore(secret), authparams.ServiceCache{ + return acl.NewAuthService(kvStore, crypt.NewSecretStore(secret), authparams.ServiceCache{ Enabled: false, }, logging.ContextUnavailable()), kvStore } diff --git a/esti/delete_objects_test.go b/esti/delete_objects_test.go index 782483824d1..2943788b420 100644 --- a/esti/delete_objects_test.go +++ b/esti/delete_objects_test.go @@ -35,8 +35,8 @@ func TestDeleteObjects(t *testing.T) { Prefix: aws.String(mainBranch + "/"), }) - assert.NoError(t, err) - assert.Len(t, listOut.Contents, numOfObjects) + require.NoError(t, err) + require.Len(t, listOut.Contents, numOfObjects) deleteOut, err := svc.DeleteObjects(ctx, &s3.DeleteObjectsInput{ Bucket: aws.String(repo), diff --git a/pkg/api/serve_test.go b/pkg/api/serve_test.go index db53d7c7cdf..bb69e988c22 100644 --- a/pkg/api/serve_test.go +++ b/pkg/api/serve_test.go @@ -14,6 +14,7 @@ import ( "github.com/deepmap/oapi-codegen/pkg/securityprovider" "github.com/go-openapi/swag" "github.com/spf13/viper" + "github.com/treeverse/lakefs/contrib/auth/acl" "github.com/treeverse/lakefs/pkg/actions" "github.com/treeverse/lakefs/pkg/api" "github.com/treeverse/lakefs/pkg/api/apigen" @@ -146,7 +147,7 @@ func setupHandler(t testing.TB) (http.Handler, *dependencies) { factory := store.NewFactory(nil) actionsStore := actions.NewActionsKVStore(kvStore) idGen := &actions.DecreasingIDGenerator{} - authService := auth.NewAuthService(kvStore, crypt.NewSecretStore([]byte("some secret")), authparams.ServiceCache{ + authService := acl.NewAuthService(kvStore, crypt.NewSecretStore([]byte("some secret")), authparams.ServiceCache{ Enabled: false, }, logging.ContextUnavailable()) meta := auth.NewKVMetadataManager("serve_test", cfg.Installation.FixedID, cfg.Database.Type, kvStore) diff --git a/pkg/auth/cache.go b/pkg/auth/cache.go index 5204c6cabcc..92f6dada4ea 100644 --- a/pkg/auth/cache.go +++ b/pkg/auth/cache.go @@ -11,16 +11,16 @@ type CredentialSetFn func() (*model.Credential, error) type UserSetFn func() (*model.User, error) type UserPoliciesSetFn func() ([]*model.Policy, error) -type userKey struct { +type UserKey struct { id string - username string - externalID string - email string + Username string + ExternalID string + Email string } type Cache interface { GetCredential(accessKeyID string, setFn CredentialSetFn) (*model.Credential, error) - GetUser(key userKey, setFn UserSetFn) (*model.User, error) + GetUser(key UserKey, setFn UserSetFn) (*model.User, error) GetUserPolicies(userID string, setFn UserPoliciesSetFn) ([]*model.Policy, error) } @@ -47,7 +47,7 @@ func (c *LRUCache) GetCredential(accessKeyID string, setFn CredentialSetFn) (*mo return v.(*model.Credential), nil } -func (c *LRUCache) GetUser(key userKey, setFn UserSetFn) (*model.User, error) { +func (c *LRUCache) GetUser(key UserKey, setFn UserSetFn) (*model.User, error) { v, err := c.userCache.GetOrSet(key, func() (interface{}, error) { return setFn() }) if err != nil { return nil, err @@ -70,7 +70,7 @@ func (d *DummyCache) GetCredential(_ string, setFn CredentialSetFn) (*model.Cred return setFn() } -func (d *DummyCache) GetUser(_ userKey, setFn UserSetFn) (*model.User, error) { +func (d *DummyCache) GetUser(_ UserKey, setFn UserSetFn) (*model.User, error) { return setFn() } diff --git a/pkg/auth/service.go b/pkg/auth/service.go index ebfc0941f08..d04b7bf9d0e 100644 --- a/pkg/auth/service.go +++ b/pkg/auth/service.go @@ -12,10 +12,8 @@ package auth import ( "context" - "errors" "fmt" "net/http" - "sort" "strconv" "strings" "time" @@ -28,17 +26,12 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/rs/xid" "github.com/treeverse/lakefs/pkg/auth/crypt" - "github.com/treeverse/lakefs/pkg/auth/keys" "github.com/treeverse/lakefs/pkg/auth/model" "github.com/treeverse/lakefs/pkg/auth/params" "github.com/treeverse/lakefs/pkg/auth/wildcard" "github.com/treeverse/lakefs/pkg/httputil" - "github.com/treeverse/lakefs/pkg/kv" "github.com/treeverse/lakefs/pkg/logging" "github.com/treeverse/lakefs/pkg/permissions" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/known/timestamppb" ) type AuthorizationRequest struct { @@ -56,7 +49,7 @@ type CheckResult int const ( InvalidUserID = "" - maxPage = 1000 + MaxPage = 1000 // CheckAllow Permission allowed CheckAllow CheckResult = iota // CheckNeutral Permission neither allowed nor denied @@ -158,1051 +151,6 @@ type Service interface { ClaimTokenIDOnce(ctx context.Context, tokenID string, expiresAt int64) error } -func (s *AuthService) ListKVPaged(ctx context.Context, protoType protoreflect.MessageType, params *model.PaginationParams, prefix []byte, secondary bool) ([]proto.Message, *model.Paginator, error) { - var ( - it kv.MessageIterator - err error - after []byte - ) - if params.After != "" { - after = make([]byte, len(prefix)+len(params.After)) - - l := copy(after, prefix) - _ = copy(after[l:], params.After) - } - if secondary { - it, err = kv.NewSecondaryIterator(ctx, s.store, protoType, model.PartitionKey, prefix, after) - } else { - it, err = kv.NewPrimaryIterator(ctx, s.store, protoType, model.PartitionKey, prefix, kv.IteratorOptionsAfter(after)) - } - if err != nil { - return nil, nil, fmt.Errorf("scan prefix(%s): %w", prefix, err) - } - defer it.Close() - - amount := maxPage - if params.Amount >= 0 && params.Amount < maxPage { - amount = params.Amount - } - - entries := make([]proto.Message, 0) - p := &model.Paginator{} - for len(entries) < amount && it.Next() { - entry := it.Entry() - // skip nil entries (deleted), kv can hold nil values - if entry == nil { - continue - } - entries = append(entries, entry.Value) - if len(entries) == amount { - p.NextPageToken = strings.TrimPrefix(string(entry.Key), string(prefix)) - break - } - } - if err = it.Err(); err != nil { - return nil, nil, fmt.Errorf("list DB: %w", err) - } - p.Amount = len(entries) - return entries, p, nil -} - -type AuthService struct { - store kv.Store - secretStore crypt.SecretStore - cache Cache - log logging.Logger -} - -func NewAuthService(store kv.Store, secretStore crypt.SecretStore, cacheConf params.ServiceCache, logger logging.Logger) *AuthService { - logger.Info("initialized Auth service") - var cache Cache - if cacheConf.Enabled { - cache = NewLRUCache(cacheConf.Size, cacheConf.TTL, cacheConf.Jitter) - } else { - cache = &DummyCache{} - } - res := &AuthService{ - store: store, - secretStore: secretStore, - cache: cache, - log: logger, - } - return res -} - -func (s *AuthService) SecretStore() crypt.SecretStore { - return s.secretStore -} - -func (s *AuthService) Cache() Cache { - return s.cache -} - -func (s *AuthService) CreateUser(ctx context.Context, user *model.User) (string, error) { - if err := model.ValidateAuthEntityID(user.Username); err != nil { - return InvalidUserID, err - } - userKey := model.UserPath(user.Username) - - err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, userKey, model.ProtoFromUser(user), nil) - if err != nil { - if errors.Is(err, kv.ErrPredicateFailed) { - err = ErrAlreadyExists - } - return "", fmt.Errorf("save user (userKey %s): %w", userKey, err) - } - return user.Username, err -} - -func (s *AuthService) DeleteUser(ctx context.Context, username string) error { - if _, err := s.GetUser(ctx, username); err != nil { - return err - } - userPath := model.UserPath(username) - - // delete policy attached to user - policiesKey := model.UserPolicyPath(username, "") - it, err := kv.NewSecondaryIterator(ctx, s.store, (&model.PolicyData{}).ProtoReflect().Type(), model.PartitionKey, policiesKey, []byte("")) - if err != nil { - return err - } - defer it.Close() - for it.Next() { - entry := it.Entry() - policy := entry.Value.(*model.PolicyData) - if err = s.DetachPolicyFromUserNoValidation(ctx, policy.DisplayName, username); err != nil { - return err - } - } - if err = it.Err(); err != nil { - return err - } - - // delete user membership of group - groupKey := model.GroupPath("") - itr, err := kv.NewPrimaryIterator(ctx, s.store, (&model.GroupData{}).ProtoReflect().Type(), model.PartitionKey, groupKey, kv.IteratorOptionsAfter([]byte(""))) - if err != nil { - return err - } - defer itr.Close() - for itr.Next() { - entry := itr.Entry() - group := entry.Value.(*model.GroupData) - if err = s.removeUserFromGroupNoValidation(ctx, username, group.DisplayName); err != nil { - return err - } - } - if err = itr.Err(); err != nil { - return err - } - - // delete user - err = s.store.Delete(ctx, []byte(model.PartitionKey), userPath) - if err != nil { - return fmt.Errorf("delete user (userKey %s): %w", userPath, err) - } - return err -} - -type UserPredicate func(u *model.UserData) bool - -func (s *AuthService) getUserByPredicate(ctx context.Context, key userKey, predicate UserPredicate) (*model.User, error) { - return s.cache.GetUser(key, func() (*model.User, error) { - m := &model.UserData{} - itr, err := kv.NewPrimaryIterator(ctx, s.store, m.ProtoReflect().Type(), model.PartitionKey, model.UserPath(""), kv.IteratorOptionsAfter([]byte(""))) - if err != nil { - return nil, fmt.Errorf("scan users: %w", err) - } - defer itr.Close() - for itr.Next() { - entry := itr.Entry() - value, ok := entry.Value.(*model.UserData) - if !ok { - return nil, fmt.Errorf("failed to cast: %w", err) - } - if predicate(value) { - return model.UserFromProto(value), nil - } - } - if itr.Err() != nil { - return nil, itr.Err() - } - return nil, ErrNotFound - }) -} - -// GetUserByID TODO(niro): In KV ID == username, Remove this method when DB implementation is deleted -func (s *AuthService) GetUserByID(ctx context.Context, userID string) (*model.User, error) { - return s.GetUser(ctx, userID) -} - -func (s *AuthService) GetUser(ctx context.Context, username string) (*model.User, error) { - return s.cache.GetUser(userKey{username: username}, func() (*model.User, error) { - userKey := model.UserPath(username) - m := model.UserData{} - _, err := kv.GetMsg(ctx, s.store, model.PartitionKey, userKey, &m) - if err != nil { - if errors.Is(err, kv.ErrNotFound) { - err = ErrNotFound - } - return nil, fmt.Errorf("%s: %w", username, err) - } - return model.UserFromProto(&m), nil - }) -} - -func (s *AuthService) GetUserByEmail(ctx context.Context, email string) (*model.User, error) { - return s.getUserByPredicate(ctx, userKey{email: email}, func(value *model.UserData) bool { - return value.Email == email - }) -} - -func (s *AuthService) GetUserByExternalID(ctx context.Context, externalID string) (*model.User, error) { - return s.getUserByPredicate(ctx, userKey{externalID: externalID}, func(value *model.UserData) bool { - return value.ExternalId == externalID - }) -} - -func (s *AuthService) ListUsers(ctx context.Context, params *model.PaginationParams) ([]*model.User, *model.Paginator, error) { - var user model.UserData - usersKey := model.UserPath(params.Prefix) - - msgs, paginator, err := s.ListKVPaged(ctx, (&user).ProtoReflect().Type(), params, usersKey, false) - if msgs == nil { - return nil, paginator, err - } - return model.ConvertUsersDataList(msgs), paginator, err -} - -func (s *AuthService) UpdateUserFriendlyName(ctx context.Context, userID string, friendlyName string) error { - return ErrNotImplemented -} - -func (s *AuthService) ListUserCredentials(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Credential, *model.Paginator, error) { - var credential model.CredentialData - credentialsKey := model.CredentialPath(username, params.Prefix) - msgs, paginator, err := s.ListKVPaged(ctx, (&credential).ProtoReflect().Type(), params, credentialsKey, false) - if err != nil { - return nil, nil, err - } - creds, err := model.ConvertCredDataList(s.secretStore, msgs) - if err != nil { - return nil, nil, err - } - return creds, paginator, nil -} - -func (s *AuthService) AttachPolicyToUser(ctx context.Context, policyDisplayName string, username string) error { - if _, err := s.GetUser(ctx, username); err != nil { - return err - } - if _, err := s.GetPolicy(ctx, policyDisplayName); err != nil { - return err - } - - policyKey := model.PolicyPath(policyDisplayName) - pu := model.UserPolicyPath(username, policyDisplayName) - - err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, pu, &kv.SecondaryIndex{PrimaryKey: policyKey}, nil) - if err != nil { - if errors.Is(err, kv.ErrPredicateFailed) { - err = ErrAlreadyExists - } - return fmt.Errorf("policy attachment to user: (key %s): %w", pu, err) - } - return nil -} - -func (s *AuthService) DetachPolicyFromUserNoValidation(ctx context.Context, policyDisplayName, username string) error { - pu := model.UserPolicyPath(username, policyDisplayName) - err := s.store.Delete(ctx, []byte(model.PartitionKey), pu) - if err != nil { - return fmt.Errorf("detaching policy: (key %s): %w", pu, err) - } - return nil -} - -func (s *AuthService) DetachPolicyFromUser(ctx context.Context, policyDisplayName, username string) error { - if _, err := s.GetUser(ctx, username); err != nil { - return err - } - if _, err := s.GetPolicy(ctx, policyDisplayName); err != nil { - return err - } - return s.DetachPolicyFromUserNoValidation(ctx, policyDisplayName, username) -} - -func (s *AuthService) ListUserPolicies(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) { - var policy model.PolicyData - userPolicyKey := model.UserPolicyPath(username, params.Prefix) - - msgs, paginator, err := s.ListKVPaged(ctx, (&policy).ProtoReflect().Type(), params, userPolicyKey, true) - if msgs == nil { - return nil, paginator, err - } - return model.ConvertPolicyDataList(msgs), paginator, err -} - -func (s *AuthService) getEffectivePolicies(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) { - if _, err := s.GetUser(ctx, username); err != nil { - return nil, nil, err - } - - hasMoreUserPolicy := true - afterUserPolicy := "" - amount := maxPage - policiesSet := make(map[string]*model.Policy) - // get policies attracted to user - for hasMoreUserPolicy { - policies, userPaginator, err := s.ListUserPolicies(ctx, username, &model.PaginationParams{ - After: afterUserPolicy, - Amount: amount, - }) - if err != nil { - return nil, nil, fmt.Errorf("list user policies: %w", err) - } - for _, policy := range policies { - policiesSet[policy.DisplayName] = policy - } - afterUserPolicy = userPaginator.NextPageToken - hasMoreUserPolicy = userPaginator.NextPageToken != "" - } - - hasMoreGroup := true - afterGroup := "" - for hasMoreGroup { - // get membership groups to user - groups, groupPaginator, err := s.ListUserGroups(ctx, username, &model.PaginationParams{ - After: afterGroup, - Amount: amount, - }) - if err != nil { - return nil, nil, err - } - for _, group := range groups { - // get policies attracted to group - hasMoreGroupPolicy := true - afterGroupPolicy := "" - for hasMoreGroupPolicy { - groupPolicies, groupPoliciesPaginator, err := s.ListGroupPolicies(ctx, group.DisplayName, &model.PaginationParams{ - After: afterGroupPolicy, - Amount: amount, - }) - if err != nil { - return nil, nil, fmt.Errorf("list group policies: %w", err) - } - for _, policy := range groupPolicies { - policiesSet[policy.DisplayName] = policy - } - afterGroupPolicy = groupPoliciesPaginator.NextPageToken - hasMoreGroupPolicy = groupPoliciesPaginator.NextPageToken != "" - } - } - afterGroup = groupPaginator.NextPageToken - hasMoreGroup = groupPaginator.NextPageToken != "" - } - - if params.Amount < 0 || params.Amount > maxPage { - params.Amount = maxPage - } - - var policiesArr []string - for k := range policiesSet { - policiesArr = append(policiesArr, k) - } - sort.Strings(policiesArr) - - var resPolicies []*model.Policy - resPaginator := model.Paginator{Amount: 0, NextPageToken: ""} - for _, p := range policiesArr { - if p > params.After { - resPolicies = append(resPolicies, policiesSet[p]) - if len(resPolicies) == params.Amount { - resPaginator.NextPageToken = p - break - } - } - } - resPaginator.Amount = len(resPolicies) - return resPolicies, &resPaginator, nil -} - -func (s *AuthService) ListEffectivePolicies(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) { - return ListEffectivePolicies(ctx, username, params, s.getEffectivePolicies, s.cache) -} - -type effectivePoliciesGetter func(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) - -func ListEffectivePolicies(ctx context.Context, username string, params *model.PaginationParams, getEffectivePolicies effectivePoliciesGetter, cache Cache) ([]*model.Policy, *model.Paginator, error) { - if params.Amount == -1 { - // read through the cache when requesting the full list - policies, err := cache.GetUserPolicies(username, func() ([]*model.Policy, error) { - policies, _, err := getEffectivePolicies(ctx, username, params) - return policies, err - }) - if err != nil { - return nil, nil, err - } - return policies, &model.Paginator{Amount: len(policies)}, nil - } - - return getEffectivePolicies(ctx, username, params) -} - -func (s *AuthService) ListGroupPolicies(ctx context.Context, groupDisplayName string, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) { - var policy model.PolicyData - groupPolicyKey := model.GroupPolicyPath(groupDisplayName, params.Prefix) - - msgs, paginator, err := s.ListKVPaged(ctx, (&policy).ProtoReflect().Type(), params, groupPolicyKey, true) - if msgs == nil { - return nil, paginator, err - } - return model.ConvertPolicyDataList(msgs), paginator, err -} - -func (s *AuthService) CreateGroup(ctx context.Context, group *model.Group) (*model.Group, error) { - if err := model.ValidateAuthEntityID(group.DisplayName); err != nil { - return nil, err - } - - groupKey := model.GroupPath(group.DisplayName) - err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, groupKey, model.ProtoFromGroup(group), nil) - if err != nil { - if errors.Is(err, kv.ErrPredicateFailed) { - err = ErrAlreadyExists - } - return nil, fmt.Errorf("save group (groupKey %s): %w", groupKey, err) - } - retGroup := &model.Group{ - DisplayName: group.DisplayName, - ID: group.DisplayName, - CreatedAt: group.CreatedAt, - } - return retGroup, nil -} - -func (s *AuthService) DeleteGroup(ctx context.Context, groupID string) error { - if _, err := s.GetGroup(ctx, groupID); err != nil { - return err - } - - // delete user membership to group - usersKey := model.GroupUserPath(groupID, "") - it, err := kv.NewSecondaryIterator(ctx, s.store, (&model.UserData{}).ProtoReflect().Type(), model.PartitionKey, usersKey, []byte("")) - if err != nil { - return err - } - defer it.Close() - for it.Next() { - entry := it.Entry() - user := entry.Value.(*model.UserData) - if err = s.removeUserFromGroupNoValidation(ctx, user.Username, groupID); err != nil { - return err - } - } - if err = it.Err(); err != nil { - return err - } - - // delete policy attachment to group - policiesKey := model.GroupPolicyPath(groupID, "") - itr, err := kv.NewSecondaryIterator(ctx, s.store, (&model.PolicyData{}).ProtoReflect().Type(), model.PartitionKey, policiesKey, []byte("")) - if err != nil { - return err - } - defer it.Close() - for itr.Next() { - entry := itr.Entry() - policy := entry.Value.(*model.PolicyData) - if err = s.DetachPolicyFromGroupNoValidation(ctx, policy.DisplayName, groupID); err != nil { - return err - } - } - if err = itr.Err(); err != nil { - return err - } - - // delete group - groupPath := model.GroupPath(groupID) - err = s.store.Delete(ctx, []byte(model.PartitionKey), groupPath) - if err != nil { - return fmt.Errorf("delete user (userKey %s): %w", groupPath, err) - } - return nil -} - -func (s *AuthService) GetGroup(ctx context.Context, groupID string) (*model.Group, error) { - groupKey := model.GroupPath(groupID) - m := model.GroupData{} - _, err := kv.GetMsg(ctx, s.store, model.PartitionKey, groupKey, &m) - if err != nil { - if errors.Is(err, kv.ErrNotFound) { - err = ErrNotFound - } - return nil, fmt.Errorf("%s: %w", groupID, err) - } - return model.GroupFromProto(&m), nil -} - -func (s *AuthService) ListGroups(ctx context.Context, params *model.PaginationParams) ([]*model.Group, *model.Paginator, error) { - var group model.GroupData - groupKey := model.GroupPath(params.Prefix) - - msgs, paginator, err := s.ListKVPaged(ctx, (&group).ProtoReflect().Type(), params, groupKey, false) - if msgs == nil { - return nil, paginator, err - } - return model.ConvertGroupDataList(msgs), paginator, err -} - -func (s *AuthService) AddUserToGroup(ctx context.Context, username, groupDisplayName string) error { - if _, err := s.GetUser(ctx, username); err != nil { - return err - } - if _, err := s.GetGroup(ctx, groupDisplayName); err != nil { - return err - } - - userKey := model.UserPath(username) - gu := model.GroupUserPath(groupDisplayName, username) - err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, gu, &kv.SecondaryIndex{PrimaryKey: userKey}, nil) - if err != nil { - if errors.Is(err, kv.ErrPredicateFailed) { - err = ErrAlreadyExists - } - return fmt.Errorf("add user to group: (key %s): %w", gu, err) - } - return nil -} - -func (s *AuthService) removeUserFromGroupNoValidation(ctx context.Context, username, groupID string) error { - gu := model.GroupUserPath(groupID, username) - err := s.store.Delete(ctx, []byte(model.PartitionKey), gu) - if err != nil { - return fmt.Errorf("remove user from group: (key %s): %w", gu, err) - } - return nil -} - -func (s *AuthService) RemoveUserFromGroup(ctx context.Context, username, groupID string) error { - if _, err := s.GetUser(ctx, username); err != nil { - return err - } - if _, err := s.GetGroup(ctx, groupID); err != nil { - return err - } - return s.removeUserFromGroupNoValidation(ctx, username, groupID) -} - -func (s *AuthService) ListUserGroups(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Group, *model.Paginator, error) { - if _, err := s.GetUser(ctx, username); err != nil { - return nil, nil, err - } - if params.Amount < 0 || params.Amount > maxPage { - params.Amount = maxPage - } - - hasMoreGroups := true - afterGroup := params.After - var userGroups []*model.Group - resPaginator := model.Paginator{Amount: 0, NextPageToken: ""} - for hasMoreGroups && len(userGroups) <= params.Amount { - groups, paginator, err := s.ListGroups(ctx, &model.PaginationParams{Prefix: params.Prefix, After: afterGroup, Amount: maxPage}) - if err != nil { - return nil, nil, err - } - for _, group := range groups { - path := model.GroupUserPath(group.DisplayName, username) - m := kv.SecondaryIndex{} - _, err := kv.GetMsg(ctx, s.store, model.PartitionKey, path, &m) - if err != nil && !errors.Is(err, kv.ErrNotFound) { - return nil, nil, err - } - if err == nil { - appendGroup := &model.Group{ - DisplayName: group.DisplayName, - ID: group.DisplayName, - CreatedAt: group.CreatedAt, - } - userGroups = append(userGroups, appendGroup) - } - if len(userGroups) == params.Amount { - resPaginator.NextPageToken = group.DisplayName - resPaginator.Amount = len(userGroups) - return userGroups, &resPaginator, nil - } - } - hasMoreGroups = paginator.NextPageToken != "" - afterGroup = paginator.NextPageToken - } - resPaginator.Amount = len(userGroups) - return userGroups, &resPaginator, nil -} - -func (s *AuthService) ListGroupUsers(ctx context.Context, groupID string, params *model.PaginationParams) ([]*model.User, *model.Paginator, error) { - if _, err := s.GetGroup(ctx, groupID); err != nil { - return nil, nil, err - } - var policy model.UserData - userGroupKey := model.GroupUserPath(groupID, params.Prefix) - - msgs, paginator, err := s.ListKVPaged(ctx, (&policy).ProtoReflect().Type(), params, userGroupKey, true) - if msgs == nil { - return nil, paginator, err - } - return model.ConvertUsersDataList(msgs), paginator, err -} - -func ValidatePolicy(policy *model.Policy) error { - if err := model.ValidateAuthEntityID(policy.DisplayName); err != nil { - return err - } - for _, stmt := range policy.Statement { - for _, action := range stmt.Action { - if err := model.ValidateActionName(action); err != nil { - return err - } - } - if err := model.ValidateArn(stmt.Resource); err != nil { - return err - } - if err := model.ValidateStatementEffect(stmt.Effect); err != nil { - return err - } - } - return nil -} - -func (s *AuthService) WritePolicy(ctx context.Context, policy *model.Policy, update bool) error { - if err := ValidatePolicy(policy); err != nil { - return err - } - policyKey := model.PolicyPath(policy.DisplayName) - m := model.ProtoFromPolicy(policy) - - if update { // update policy only if it already exists - err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, policyKey, m, kv.PrecondConditionalExists) - if err != nil { - if errors.Is(err, kv.ErrPredicateFailed) { - err = ErrNotFound - } - return err - } - return nil - } - - // create policy only if it does not exist - err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, policyKey, m, nil) - if err != nil { - if errors.Is(err, kv.ErrPredicateFailed) { - err = ErrAlreadyExists - } - return err - } - return nil -} - -func (s *AuthService) GetPolicy(ctx context.Context, policyDisplayName string) (*model.Policy, error) { - policyKey := model.PolicyPath(policyDisplayName) - p := model.PolicyData{} - _, err := kv.GetMsg(ctx, s.store, model.PartitionKey, policyKey, &p) - if err != nil { - if errors.Is(err, kv.ErrNotFound) { - err = ErrNotFound - } - return nil, fmt.Errorf("%s: %w", policyDisplayName, err) - } - return model.PolicyFromProto(&p), nil -} - -func (s *AuthService) DeletePolicy(ctx context.Context, policyDisplayName string) error { - if _, err := s.GetPolicy(ctx, policyDisplayName); err != nil { - return err - } - policyPath := model.PolicyPath(policyDisplayName) - - // delete policy attachment to user - usersKey := model.UserPath("") - it, err := kv.NewPrimaryIterator(ctx, s.store, (&model.UserData{}).ProtoReflect().Type(), model.PartitionKey, usersKey, kv.IteratorOptionsAfter([]byte(""))) - if err != nil { - return err - } - defer it.Close() - for it.Next() { - entry := it.Entry() - user := entry.Value.(*model.UserData) - if err = s.DetachPolicyFromUserNoValidation(ctx, policyDisplayName, user.Username); err != nil { - return err - } - } - - // delete policy attachment to group - groupKey := model.GroupPath("") - it, err = kv.NewPrimaryIterator(ctx, s.store, (&model.GroupData{}).ProtoReflect().Type(), model.PartitionKey, groupKey, kv.IteratorOptionsAfter([]byte(""))) - if err != nil { - return err - } - defer it.Close() - for it.Next() { - entry := it.Entry() - group := entry.Value.(*model.GroupData) - if err = s.DetachPolicyFromGroupNoValidation(ctx, policyDisplayName, group.DisplayName); err != nil { - return err - } - } - - // delete policy - err = s.store.Delete(ctx, []byte(model.PartitionKey), policyPath) - if err != nil { - return fmt.Errorf("delete policy (policyKey %s): %w", policyPath, err) - } - return nil -} - -func (s *AuthService) ListPolicies(ctx context.Context, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) { - var policy model.PolicyData - policyKey := model.PolicyPath(params.Prefix) - - msgs, paginator, err := s.ListKVPaged(ctx, (&policy).ProtoReflect().Type(), params, policyKey, false) - if msgs == nil { - return nil, paginator, err - } - return model.ConvertPolicyDataList(msgs), paginator, err -} - -func (s *AuthService) CreateCredentials(ctx context.Context, username string) (*model.Credential, error) { - accessKeyID := keys.GenAccessKeyID() - secretAccessKey := keys.GenSecretAccessKey() - return s.AddCredentials(ctx, username, accessKeyID, secretAccessKey) -} - -func (s *AuthService) AddCredentials(ctx context.Context, username, accessKeyID, secretAccessKey string) (*model.Credential, error) { - if !IsValidAccessKeyID(accessKeyID) { - return nil, ErrInvalidAccessKeyID - } - if len(secretAccessKey) == 0 { - return nil, ErrInvalidSecretAccessKey - } - now := time.Now() - encryptedKey, err := model.EncryptSecret(s.secretStore, secretAccessKey) - if err != nil { - return nil, err - } - user, err := s.GetUser(ctx, username) - if err != nil { - return nil, err - } - - c := &model.Credential{ - BaseCredential: model.BaseCredential{ - AccessKeyID: accessKeyID, - SecretAccessKey: secretAccessKey, - SecretAccessKeyEncryptedBytes: encryptedKey, - IssuedDate: now, - }, - Username: user.Username, - } - credentialsKey := model.CredentialPath(user.Username, c.AccessKeyID) - err = kv.SetMsgIf(ctx, s.store, model.PartitionKey, credentialsKey, model.ProtoFromCredential(c), nil) - if err != nil { - if errors.Is(err, kv.ErrPredicateFailed) { - err = ErrAlreadyExists - } - return nil, fmt.Errorf("save credentials (credentialsKey %s): %w", credentialsKey, err) - } - - return c, nil -} - -func IsValidAccessKeyID(key string) bool { - l := len(key) - return l >= 3 && l <= 20 -} - -func (s *AuthService) DeleteCredentials(ctx context.Context, username, accessKeyID string) error { - if _, err := s.GetUser(ctx, username); err != nil { - return err - } - if _, err := s.GetCredentials(ctx, accessKeyID); err != nil { - return err - } - - credPath := model.CredentialPath(username, accessKeyID) - err := s.store.Delete(ctx, []byte(model.PartitionKey), credPath) - if err != nil { - return fmt.Errorf("delete credentials (credentialsKey %s): %w", credPath, err) - } - return nil -} - -func (s *AuthService) AttachPolicyToGroup(ctx context.Context, policyDisplayName, groupDisplayName string) error { - if _, err := s.GetGroup(ctx, groupDisplayName); err != nil { - return err - } - if _, err := s.GetPolicy(ctx, policyDisplayName); err != nil { - return err - } - - policyKey := model.PolicyPath(policyDisplayName) - pg := model.GroupPolicyPath(groupDisplayName, policyDisplayName) - - err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, pg, &kv.SecondaryIndex{PrimaryKey: policyKey}, nil) - if err != nil { - if errors.Is(err, kv.ErrPredicateFailed) { - err = ErrAlreadyExists - } - return fmt.Errorf("policy attachment to group: (key %s): %w", pg, err) - } - return nil -} - -func (s *AuthService) DetachPolicyFromGroupNoValidation(ctx context.Context, policyDisplayName, groupDisplayName string) error { - pg := model.GroupPolicyPath(groupDisplayName, policyDisplayName) - err := s.store.Delete(ctx, []byte(model.PartitionKey), pg) - if err != nil { - return fmt.Errorf("policy detachment to group: (key %s): %w", pg, err) - } - return nil -} - -func (s *AuthService) DetachPolicyFromGroup(ctx context.Context, policyDisplayName, groupDisplayName string) error { - if _, err := s.GetGroup(ctx, groupDisplayName); err != nil { - return err - } - if _, err := s.GetPolicy(ctx, policyDisplayName); err != nil { - return err - } - return s.DetachPolicyFromGroupNoValidation(ctx, policyDisplayName, groupDisplayName) -} - -func (s *AuthService) GetCredentialsForUser(ctx context.Context, username, accessKeyID string) (*model.Credential, error) { - if _, err := s.GetUser(ctx, username); err != nil { - return nil, err - } - credentialsKey := model.CredentialPath(username, accessKeyID) - m := model.CredentialData{} - _, err := kv.GetMsg(ctx, s.store, model.PartitionKey, credentialsKey, &m) - if err != nil { - if errors.Is(err, kv.ErrNotFound) { - err = ErrNotFound - } - return nil, err - } - - c, err := model.CredentialFromProto(s.secretStore, &m) - if err != nil { - return nil, err - } - c.SecretAccessKey = "" - return c, nil -} - -func (s *AuthService) GetCredentials(ctx context.Context, accessKeyID string) (*model.Credential, error) { - return s.cache.GetCredential(accessKeyID, func() (*model.Credential, error) { - m := &model.UserData{} - itr, err := kv.NewPrimaryIterator(ctx, s.store, m.ProtoReflect().Type(), model.PartitionKey, model.UserPath(""), kv.IteratorOptionsAfter([]byte(""))) - if err != nil { - return nil, fmt.Errorf("scan users: %w", err) - } - defer itr.Close() - - for itr.Next() { - entry := itr.Entry() - user, ok := entry.Value.(*model.UserData) - if !ok { - return nil, fmt.Errorf("failed to cast: %w", err) - } - c := model.CredentialData{} - credentialsKey := model.CredentialPath(user.Username, accessKeyID) - _, err := kv.GetMsg(ctx, s.store, model.PartitionKey, credentialsKey, &c) - if err != nil && !errors.Is(err, kv.ErrNotFound) { - return nil, err - } - if err == nil { - return model.CredentialFromProto(s.secretStore, &c) - } - } - if err = itr.Err(); err != nil { - return nil, err - } - return nil, fmt.Errorf("credentials %w", ErrNotFound) - }) -} - -func interpolateUser(resource string, username string) string { - return strings.ReplaceAll(resource, "${user}", username) -} - -func checkPermissions(ctx context.Context, node permissions.Node, username string, policies []*model.Policy) CheckResult { - allowed := CheckNeutral - switch node.Type { - case permissions.NodeTypeNode: - // check whether the permission is allowed, denied or natural (not allowed and not denied) - for _, policy := range policies { - for _, stmt := range policy.Statement { - resource := interpolateUser(stmt.Resource, username) - if !ArnMatch(resource, node.Permission.Resource) { - continue - } - for _, action := range stmt.Action { - if !wildcard.Match(action, node.Permission.Action) { - continue // not a matching action - } - - if stmt.Effect == model.StatementEffectDeny { - // this is a "Deny" and it takes precedence - return CheckDeny - } - - allowed = CheckAllow - } - } - } - - case permissions.NodeTypeOr: - // returns: - // Allowed - at least one of the permissions is allowed and no one is denied - // Denied - one of the permissions is Deny - // Natural - otherwise - for _, node := range node.Nodes { - result := checkPermissions(ctx, node, username, policies) - if result == CheckDeny { - return CheckDeny - } - if allowed != CheckAllow { - allowed = result - } - } - - case permissions.NodeTypeAnd: - // returns: - // Allowed - all the permissions are allowed - // Denied - one of the permissions is Deny - // Natural - otherwise - for _, node := range node.Nodes { - result := checkPermissions(ctx, node, username, policies) - if result == CheckNeutral || result == CheckDeny { - return result - } - } - return CheckAllow - - default: - logging.FromContext(ctx).Error("unknown permission node type") - return CheckDeny - } - return allowed -} - -func (s *AuthService) Authorize(ctx context.Context, req *AuthorizationRequest) (*AuthorizationResponse, error) { - policies, _, err := s.ListEffectivePolicies(ctx, req.Username, &model.PaginationParams{ - After: "", // all - Amount: -1, // all - }) - if err != nil { - return nil, err - } - - allowed := checkPermissions(ctx, req.RequiredPermissions, req.Username, policies) - - if allowed != CheckAllow { - return &AuthorizationResponse{ - Allowed: false, - Error: ErrInsufficientPermissions, - }, nil - } - - // we're allowed! - return &AuthorizationResponse{Allowed: true}, nil -} - -func (s *AuthService) ClaimTokenIDOnce(ctx context.Context, tokenID string, expiresAt int64) error { - return claimTokenIDOnce(ctx, tokenID, expiresAt, s.markTokenSingleUse) -} - -func claimTokenIDOnce(ctx context.Context, tokenID string, expiresAt int64, markTokenSingleUse func(context.Context, string, time.Time) (bool, error)) error { - tokenExpiresAt := time.Unix(expiresAt, 0) - canUseToken, err := markTokenSingleUse(ctx, tokenID, tokenExpiresAt) - if err != nil { - return err - } - if !canUseToken { - return ErrInvalidToken - } - return nil -} - -func (s *AuthService) IsExternalPrincipalsEnabled(ctx context.Context) bool { - return false -} - -func (s *AuthService) CreateUserExternalPrincipal(ctx context.Context, userID, principalID string) error { - return ErrNotImplemented -} - -func (s *AuthService) DeleteUserExternalPrincipal(ctx context.Context, userID, principalID string) error { - return ErrNotImplemented -} - -func (s *AuthService) GetExternalPrincipal(ctx context.Context, principalID string) (*model.ExternalPrincipal, error) { - return nil, ErrNotImplemented -} - -func (s *AuthService) ListUserExternalPrincipals(ctx context.Context, userID string, params *model.PaginationParams) ([]*model.ExternalPrincipal, *model.Paginator, error) { - return nil, nil, ErrNotImplemented -} - -// markTokenSingleUse returns true if token is valid for single use -func (s *AuthService) markTokenSingleUse(ctx context.Context, tokenID string, tokenExpiresAt time.Time) (bool, error) { - tokenPath := model.ExpiredTokenPath(tokenID) - m := model.TokenData{TokenId: tokenID, ExpiredAt: timestamppb.New(tokenExpiresAt)} - err := kv.SetMsgIf(ctx, s.store, model.PartitionKey, tokenPath, &m, nil) - if err != nil { - if errors.Is(err, kv.ErrPredicateFailed) { - return false, nil - } - return false, err - } - - if err := s.deleteTokens(ctx); err != nil { - s.log.WithError(err).Error("Failed to delete expired tokens") - } - return true, nil -} - -func (s *AuthService) deleteTokens(ctx context.Context) error { - it, err := kv.NewPrimaryIterator(ctx, s.store, (&model.TokenData{}).ProtoReflect().Type(), model.PartitionKey, model.ExpiredTokensPath(), kv.IteratorOptionsFrom([]byte(""))) - if err != nil { - return err - } - defer it.Close() - - deletionCutoff := time.Now() - for it.Next() { - msg := it.Entry() - if msg == nil { - return fmt.Errorf("nil token: %w", ErrInvalidToken) - } - token, ok := msg.Value.(*model.TokenData) - if token == nil || !ok { - return fmt.Errorf("wrong token type: %w", ErrInvalidToken) - } - - if token.ExpiredAt.AsTime().After(deletionCutoff) { - // reached a token with expiry greater than the cutoff, - // tokens are k-ordered (xid) hence we'll not find more expired tokens - return nil - } - - tokenPath := model.ExpiredTokenPath(token.TokenId) - if err := s.store.Delete(ctx, []byte(model.PartitionKey), tokenPath); err != nil { - return fmt.Errorf("deleting token: %w", err) - } - } - - return it.Err() -} - const ( healthCheckMaxInterval = 5 * time.Second healthCheckInitialInterval = 1 * time.Second @@ -1273,7 +221,7 @@ func userIDToInt(userID string) (int64, error) { return strconv.ParseInt(userID, base, bitSize) } -func (a *APIAuthService) getFirstUser(ctx context.Context, userKey userKey, params *ListUsersParams) (*model.User, error) { +func (a *APIAuthService) getFirstUser(ctx context.Context, userKey UserKey, params *ListUsersParams) (*model.User, error) { return a.cache.GetUser(userKey, func() (*model.User, error) { // fetch at least two users to make sure we don't have duplicates if params.Amount == nil { @@ -1317,12 +265,12 @@ func (a *APIAuthService) GetUserByID(ctx context.Context, userID string) (*model if err != nil { return nil, fmt.Errorf("userID as int64: %w", err) } - return a.getFirstUser(ctx, userKey{id: userID}, &ListUsersParams{Id: &intID}) + return a.getFirstUser(ctx, UserKey{id: userID}, &ListUsersParams{Id: &intID}) } func (a *APIAuthService) GetUser(ctx context.Context, username string) (*model.User, error) { ctx = httputil.SetClientTrace(ctx, "api_auth") - return a.cache.GetUser(userKey{username: username}, func() (*model.User, error) { + return a.cache.GetUser(UserKey{Username: username}, func() (*model.User, error) { resp, err := a.apiClient.GetUserWithResponse(ctx, username) if err != nil { a.logger.WithError(err).WithField("username", username).Error("failed to get user") @@ -1345,12 +293,12 @@ func (a *APIAuthService) GetUser(ctx context.Context, username string) (*model.U func (a *APIAuthService) GetUserByEmail(ctx context.Context, email string) (*model.User, error) { ctx = httputil.SetClientTrace(ctx, "api_auth") - return a.getFirstUser(ctx, userKey{email: email}, &ListUsersParams{Email: swag.String(email)}) + return a.getFirstUser(ctx, UserKey{Email: email}, &ListUsersParams{Email: swag.String(email)}) } func (a *APIAuthService) GetUserByExternalID(ctx context.Context, externalID string) (*model.User, error) { ctx = httputil.SetClientTrace(ctx, "api_auth") - return a.getFirstUser(ctx, userKey{externalID: externalID}, &ListUsersParams{ExternalId: swag.String(externalID)}) + return a.getFirstUser(ctx, UserKey{ExternalID: externalID}, &ListUsersParams{ExternalId: swag.String(externalID)}) } func toPagination(paginator Pagination) *model.Paginator { @@ -1882,7 +830,7 @@ func (a *APIAuthService) ListUserPolicies(ctx context.Context, username string, func (a *APIAuthService) listAllEffectivePolicies(ctx context.Context, username string) ([]*model.Policy, error) { hasMore := true after := "" - amount := maxPage + amount := MaxPage policies := make([]*model.Policy, 0) for hasMore { p, paginator, err := a.ListEffectivePolicies(ctx, username, &model.PaginationParams{ @@ -1970,7 +918,7 @@ func (a *APIAuthService) Authorize(ctx context.Context, req *AuthorizationReques return nil, err } - allowed := checkPermissions(ctx, req.RequiredPermissions, req.Username, policies) + allowed := CheckPermissions(ctx, req.RequiredPermissions, req.Username, policies) if allowed != CheckAllow { return &AuthorizationResponse{ @@ -2032,7 +980,7 @@ func (a *APIAuthService) CheckHealth(ctx context.Context, logger logging.Logger, return nil } -func (a *APIAuthService) IsExternalPrincipalsEnabled(ctx context.Context) bool { +func (a *APIAuthService) IsExternalPrincipalsEnabled(_ context.Context) bool { return a.externalPrincipalsEnabled } @@ -2197,3 +1145,68 @@ func NewAPIAuthServiceWithClient(client ClientWithResponsesInterface, externalPr externalPrincipalsEnabled: externalPrincipalseEnabled, }, nil } + +func CheckPermissions(ctx context.Context, node permissions.Node, username string, policies []*model.Policy) CheckResult { + allowed := CheckNeutral + switch node.Type { + case permissions.NodeTypeNode: + // check whether the permission is allowed, denied or natural (not allowed and not denied) + for _, policy := range policies { + for _, stmt := range policy.Statement { + resource := interpolateUser(stmt.Resource, username) + if !ArnMatch(resource, node.Permission.Resource) { + continue + } + for _, action := range stmt.Action { + if !wildcard.Match(action, node.Permission.Action) { + continue // not a matching action + } + + if stmt.Effect == model.StatementEffectDeny { + // this is a "Deny" and it takes precedence + return CheckDeny + } + + allowed = CheckAllow + } + } + } + + case permissions.NodeTypeOr: + // returns: + // Allowed - at least one of the permissions is allowed and no one is denied + // Denied - one of the permissions is Deny + // Natural - otherwise + for _, node := range node.Nodes { + result := CheckPermissions(ctx, node, username, policies) + if result == CheckDeny { + return CheckDeny + } + if allowed != CheckAllow { + allowed = result + } + } + + case permissions.NodeTypeAnd: + // returns: + // Allowed - all the permissions are allowed + // Denied - one of the permissions is Deny + // Natural - otherwise + for _, node := range node.Nodes { + result := CheckPermissions(ctx, node, username, policies) + if result == CheckNeutral || result == CheckDeny { + return result + } + } + return CheckAllow + + default: + logging.FromContext(ctx).Error("unknown permission node type") + return CheckDeny + } + return allowed +} + +func interpolateUser(resource string, username string) string { + return strings.ReplaceAll(resource, "${user}", username) +} diff --git a/pkg/auth/service_test.go b/pkg/auth/service_test.go index 20441144c88..8854405cd24 100644 --- a/pkg/auth/service_test.go +++ b/pkg/auth/service_test.go @@ -8,679 +8,31 @@ import ( "net/http" "net/http/httptest" "os" - "strings" "testing" "time" "github.com/go-openapi/swag" "github.com/go-test/deep" "github.com/golang/mock/gomock" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/treeverse/lakefs/pkg/auth" - "github.com/treeverse/lakefs/pkg/auth/acl" "github.com/treeverse/lakefs/pkg/auth/crypt" "github.com/treeverse/lakefs/pkg/auth/mock" "github.com/treeverse/lakefs/pkg/auth/model" authparams "github.com/treeverse/lakefs/pkg/auth/params" - authtestutil "github.com/treeverse/lakefs/pkg/auth/testutil" "github.com/treeverse/lakefs/pkg/httputil" - "github.com/treeverse/lakefs/pkg/kv/kvtest" "github.com/treeverse/lakefs/pkg/logging" - "github.com/treeverse/lakefs/pkg/permissions" "github.com/treeverse/lakefs/pkg/testutil" ) const creationDate = 12345678 -var ( - someSecret = []byte("some secret") - - userPoliciesForTesting = []*model.Policy{ - { - Statement: model.Statements{ - { - Action: []string{"auth:DeleteUser"}, - Resource: "arn:lakefs:auth:::user/foobar", - Effect: model.StatementEffectAllow, - }, - { - Action: []string{"auth:*"}, - Resource: "*", - Effect: model.StatementEffectDeny, - }, - }, - }, - } -) - func TestMain(m *testing.M) { logging.SetLevel("panic") code := m.Run() os.Exit(code) } -func userWithPolicies(t testing.TB, s auth.Service, policies []*model.Policy) string { - t.Helper() - ctx := context.Background() - userName := uuid.New().String() - _, err := s.CreateUser(ctx, &model.User{ - Username: userName, - }) - if err != nil { - t.Fatal(err) - } - for _, policy := range policies { - if policy.DisplayName == "" { - policy.DisplayName = model.CreateID() - } - err := s.WritePolicy(ctx, policy, false) - if err != nil { - t.Fatal(err) - } - err = s.AttachPolicyToUser(ctx, policy.DisplayName, userName) - if err != nil { - t.Fatal(err) - } - } - - return userName -} - -func userWithACLs(t testing.TB, s auth.Service, a model.ACL) string { - t.Helper() - statements, err := acl.ACLToStatement(a) - if err != nil { - t.Fatal("ACLToStatement: ", err) - } - creationTime := time.Unix(creationDate, 0) - - policy := &model.Policy{ - CreatedAt: creationTime, - DisplayName: model.CreateID(), - Statement: statements, - ACL: a, - } - return userWithPolicies(t, s, []*model.Policy{policy}) -} - -func TestAuthService_ListUsers_PagedWithPrefix(t *testing.T) { - ctx := context.Background() - kvStore := kvtest.GetStore(ctx, t) - s := auth.NewAuthService(kvStore, crypt.NewSecretStore(someSecret), authparams.ServiceCache{ - Enabled: false, - }, logging.ContextUnavailable()) - - users := []string{"bar", "barn", "baz", "foo", "foobar", "foobaz"} - for _, u := range users { - user := model.User{Username: u} - if _, err := s.CreateUser(ctx, &user); err != nil { - t.Fatalf("create user: %s", err) - } - } - - sizes := []int{10, 3, 2} - prefixes := []string{"b", "ba", "bar", "f", "foo", "foob", "foobar"} - for _, size := range sizes { - for _, p := range prefixes { - t.Run(fmt.Sprintf("Size:%d;Prefix:%s", size, p), func(t *testing.T) { - // Only count the correct number of entries were - // returned; values are tested below. - got := 0 - after := "" - for { - value, paginator, err := s.ListUsers(ctx, &model.PaginationParams{Amount: size, Prefix: p, After: after}) - if err != nil { - t.Fatal(err) - } - got += len(value) - after = paginator.NextPageToken - if after == "" { - break - } - } - // Verify got the right number of users - count := 0 - for _, u := range users { - if strings.HasPrefix(u, p) { - count++ - } - } - if got != count { - t.Errorf("Got %d users when expecting %d", got, count) - } - }) - } - } -} - -func TestAuthService_ListPaged(t *testing.T) { - ctx := context.Background() - kvStore := kvtest.GetStore(ctx, t) - s := auth.NewAuthService(kvStore, crypt.NewSecretStore(someSecret), authparams.ServiceCache{ - Enabled: false, - }, logging.ContextUnavailable()) - - const chars = "abcdefghijklmnopqrstuvwxyz" - for _, c := range chars { - user := model.User{Username: string(c)} - if _, err := s.CreateUser(ctx, &user); err != nil { - t.Fatalf("create user: %s", err) - } - } - var userData model.UserData - - for size := 0; size <= len(chars)+1; size++ { - t.Run(fmt.Sprintf("PageSize%d", size), func(t *testing.T) { - pagination := &model.PaginationParams{Amount: size} - if size == 0 { // Overload to mean "don't paginate" - pagination.Amount = -1 - } - got := "" - for { - values, paginator, err := s.ListKVPaged(ctx, (&userData).ProtoReflect().Type(), pagination, model.UserPath(""), false) - if err != nil { - t.Errorf("ListPaged: %s", err) - break - } - if values == nil { - t.Fatalf("expected values for pagination %+v but got just paginator %+v", pagination, paginator) - } - letters := model.ConvertUsersDataList(values) - for _, c := range letters { - got = got + c.Username - } - if paginator.NextPageToken == "" { - if size > 0 && len(letters) > size { - t.Errorf("expected at most %d entries in last page but got %d", size, len(letters)) - } - break - } - if len(letters) != size { - t.Errorf("expected %d entries in page but got %d", size, len(letters)) - } - pagination.After = paginator.NextPageToken - } - if got != chars { - t.Errorf("Expected to read back \"%s\" but got \"%s\"", chars, got) - } - }) - } -} - -func TestAuthService_DeleteUserWithRelations(t *testing.T) { - userNames := []string{"first", "second"} - groupNames := []string{"groupA", "groupB"} - policyNames := []string{"policy01", "policy02", "policy03", "policy04"} - - ctx := context.Background() - authService, _ := authtestutil.SetupService(t, ctx, someSecret) - - // create initial data set and verify users groups and policies are created and related as expected - createInitialDataSet(t, ctx, authService, userNames, groupNames, policyNames) - users, _, err := authService.ListUsers(ctx, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, users) - require.Equal(t, len(userNames), len(users)) - for _, userName := range userNames { - user, err := authService.GetUser(ctx, userName) - require.NoError(t, err) - require.NotNil(t, user) - require.Equal(t, userName, user.Username) - - groups, _, err := authService.ListUserGroups(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, groups) - require.Equal(t, len(groupNames), len(groups)) - - policies, _, err := authService.ListUserPolicies(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)/2, len(policies)) - - policies, _, err = authService.ListEffectivePolicies(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames), len(policies)) - } - for _, groupName := range groupNames { - users, _, err := authService.ListGroupUsers(ctx, groupName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, users) - require.Equal(t, len(userNames), len(users)) - } - - // delete a user - err = authService.DeleteUser(ctx, userNames[0]) - require.NoError(t, err) - - // verify user does not exist - user, err := authService.GetUser(ctx, userNames[0]) - require.Error(t, err) - require.Nil(t, user) - - // verify user is removed from all lists and relations - users, _, err = authService.ListUsers(ctx, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, users) - require.Equal(t, len(userNames)-1, len(users)) - - for _, groupName := range groupNames { - users, _, err := authService.ListGroupUsers(ctx, groupName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, users) - require.Equal(t, len(userNames)-1, len(users)) - for _, user := range users { - require.NotEqual(t, userNames[0], user.Username) - } - } -} - -func TestAuthService_DeleteGroupWithRelations(t *testing.T) { - userNames := []string{"first", "second", "third"} - groupNames := []string{"groupA", "groupB", "groupC"} - policyNames := []string{"policy01", "policy02", "policy03", "policy04"} - - ctx := context.Background() - authService, _ := authtestutil.SetupService(t, ctx, someSecret) - - // create initial data set and verify users groups and policies are created and related as expected - createInitialDataSet(t, ctx, authService, userNames, groupNames, policyNames) - groups, _, err := authService.ListGroups(ctx, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, groups) - require.Equal(t, len(groupNames), len(groups)) - for _, userName := range userNames { - user, err := authService.GetUser(ctx, userName) - require.NoError(t, err) - require.NotNil(t, user) - require.Equal(t, userName, user.Username) - - groups, _, err := authService.ListUserGroups(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, groups) - require.Equal(t, len(groupNames), len(groups)) - - policies, _, err := authService.ListUserPolicies(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)/2, len(policies)) - - policies, _, err = authService.ListEffectivePolicies(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames), len(policies)) - } - for _, groupName := range groupNames { - group, err := authService.GetGroup(ctx, groupName) - require.NoError(t, err) - require.NotNil(t, group) - require.Equal(t, groupName, group.DisplayName) - - users, _, err := authService.ListGroupUsers(ctx, groupName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, users) - require.Equal(t, len(userNames), len(users)) - - policies, _, err := authService.ListGroupPolicies(ctx, groupName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)-len(policyNames)/2, len(policies)) - } - for _, userName := range userNames { - groups, _, err := authService.ListUserGroups(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, groups) - require.Equal(t, len(groupNames), len(groups)) - } - - // delete a group - err = authService.DeleteGroup(ctx, groupNames[1]) - require.NoError(t, err) - - // verify group does not exist - group, err := authService.GetGroup(ctx, groupNames[1]) - require.Error(t, err) - require.Nil(t, group) - - // verify group is removed from all lists and relations - groups, _, err = authService.ListGroups(ctx, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, groups) - require.Equal(t, len(groupNames)-1, len(groups)) - - for _, userName := range userNames { - groups, _, err := authService.ListUserGroups(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, groups) - require.Equal(t, len(userNames)-1, len(groups)) - for _, group := range groups { - require.NotEqual(t, groupNames[1], group.DisplayName) - } - } -} - -func TestAuthService_DeletePoliciesWithRelations(t *testing.T) { - userNames := []string{"first", "second", "third"} - groupNames := []string{"groupA", "groupB", "groupC"} - policyNames := []string{"policy01", "policy02", "policy03", "policy04"} - - ctx := context.Background() - authService, _ := authtestutil.SetupService(t, ctx, someSecret) - - // create initial data set and verify users groups and policies are created and related as expected - createInitialDataSet(t, ctx, authService, userNames, groupNames, policyNames) - policies, _, err := authService.ListPolicies(ctx, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames), len(policies)) - for _, policyName := range policyNames { - policy, err := authService.GetPolicy(ctx, policyName) - require.NoError(t, err) - require.NotNil(t, policy) - require.Equal(t, policyName, policy.DisplayName) - } - - for _, groupName := range groupNames { - policies, _, err := authService.ListGroupPolicies(ctx, groupName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)-len(policyNames)/2, len(policies)) - } - for _, userName := range userNames { - policies, _, err := authService.ListUserPolicies(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)/2, len(policies)) - - policies, _, err = authService.ListEffectivePolicies(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames), len(policies)) - } - - // delete a user policy (beginning of the name list) - err = authService.DeletePolicy(ctx, policyNames[0]) - require.NoError(t, err) - - // verify policy does not exist - policy, err := authService.GetPolicy(ctx, policyNames[0]) - require.Error(t, err) - require.Nil(t, policy) - - // verify policy is removed from all lists and relations - policies, _, err = authService.ListPolicies(ctx, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)-1, len(policies)) - - for _, userName := range userNames { - policies, _, err := authService.ListUserPolicies(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)/2-1, len(policies)) - for _, policy := range policies { - require.NotEqual(t, policyNames[0], policy.DisplayName) - } - - policies, _, err = authService.ListEffectivePolicies(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)-1, len(policies)) - for _, policy := range policies { - require.NotEqual(t, policyNames[0], policy.DisplayName) - } - } - - for _, groupName := range groupNames { - policies, _, err := authService.ListGroupPolicies(ctx, groupName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)-len(policyNames)/2, len(policies)) - for _, policy := range policies { - require.NotEqual(t, policyNames[0], policy.DisplayName) - } - } - - // delete a group policy (end of the names list) - err = authService.DeletePolicy(ctx, policyNames[len(policyNames)-1]) - require.NoError(t, err) - - // verify policy does not exist - policy, err = authService.GetPolicy(ctx, policyNames[len(policyNames)-1]) - require.Error(t, err) - require.Nil(t, policy) - - // verify policy is removed from all lists and relations - policies, _, err = authService.ListPolicies(ctx, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)-2, len(policies)) - - for _, userName := range userNames { - policies, _, err := authService.ListUserPolicies(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)/2-1, len(policies)) - for _, policy := range policies { - require.NotEqual(t, policyNames[len(policyNames)-1], policy.DisplayName) - } - - policies, _, err = authService.ListEffectivePolicies(ctx, userName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)-2, len(policies)) - for _, policy := range policies { - require.NotEqual(t, policyNames[len(policyNames)-1], policy.DisplayName) - } - } - - for _, groupName := range groupNames { - policies, _, err := authService.ListGroupPolicies(ctx, groupName, &model.PaginationParams{Amount: 100}) - require.NoError(t, err) - require.NotNil(t, policies) - require.Equal(t, len(policyNames)-len(policyNames)/2-1, len(policies)) - for _, policy := range policies { - require.NotEqual(t, policyNames[len(policyNames)-1], policy.DisplayName) - } - } -} - -// createInitialDataSet - -// Creates K users with 2 credentials each, L groups and M policies -// Add all users to all groups -// Attach M/2 of the policies to all K users and the other M-M/2 policies to all L groups -func createInitialDataSet(t *testing.T, ctx context.Context, svc auth.Service, userNames, groupNames, policyNames []string) { - for _, userName := range userNames { - if _, err := svc.CreateUser(ctx, &model.User{Username: userName}); err != nil { - t.Fatalf("CreateUser(%s): %s", userName, err) - } - for i := 0; i < 2; i++ { - _, err := svc.CreateCredentials(ctx, userName) - if err != nil { - t.Errorf("CreateCredentials(%s): %s", userName, err) - } - } - } - - for _, groupName := range groupNames { - if _, err := svc.CreateGroup(ctx, &model.Group{DisplayName: groupName}); err != nil { - t.Fatalf("CreateGroup(%s): %s", groupName, err) - } - for _, userName := range userNames { - if err := svc.AddUserToGroup(ctx, userName, groupName); err != nil { - t.Fatalf("AddUserToGroup(%s, %s): %s", userName, groupName, err) - } - } - } - - numPolicies := len(policyNames) - for i, policyName := range policyNames { - if err := svc.WritePolicy(ctx, &model.Policy{DisplayName: policyName, Statement: userPoliciesForTesting[0].Statement}, false); err != nil { - t.Fatalf("WritePolicy(%s): %s", policyName, err) - } - if i < numPolicies/2 { - for _, userName := range userNames { - if err := svc.AttachPolicyToUser(ctx, policyName, userName); err != nil { - t.Fatalf("AttachPolicyToUser(%s, %s): %s", policyName, userName, err) - } - } - } else { - for _, groupName := range groupNames { - if err := svc.AttachPolicyToGroup(ctx, policyName, groupName); err != nil { - t.Fatalf("AttachPolicyToGroup(%s, %s): %s", policyName, groupName, err) - } - } - } - } -} - -func BenchmarkKVAuthService_ListEffectivePolicies(b *testing.B) { - // setup user with policies for benchmark - ctx := context.Background() - kvStore := kvtest.GetStore(ctx, b) - - serviceWithoutCache := auth.NewAuthService(kvStore, crypt.NewSecretStore(someSecret), authparams.ServiceCache{ - Enabled: false, - }, logging.ContextUnavailable()) - serviceWithCache := auth.NewAuthService(kvStore, crypt.NewSecretStore(someSecret), authparams.ServiceCache{ - Enabled: true, - Size: 1024, - TTL: 20 * time.Second, - Jitter: 3 * time.Second, - }, logging.ContextUnavailable()) - serviceWithCacheLowTTL := auth.NewAuthService(kvStore, crypt.NewSecretStore(someSecret), authparams.ServiceCache{ - Enabled: true, - Size: 1024, - TTL: 1 * time.Millisecond, - Jitter: 1 * time.Millisecond, - }, logging.ContextUnavailable()) - userName := userWithPolicies(b, serviceWithoutCache, userPoliciesForTesting) - - b.Run("without_cache", func(b *testing.B) { - benchmarkKVListEffectivePolicies(b, serviceWithoutCache, userName) - }) - b.Run("with_cache", func(b *testing.B) { - benchmarkKVListEffectivePolicies(b, serviceWithCache, userName) - }) - b.Run("without_cache_low_ttl", func(b *testing.B) { - benchmarkKVListEffectivePolicies(b, serviceWithCacheLowTTL, userName) - }) -} - -func benchmarkKVListEffectivePolicies(b *testing.B, s *auth.AuthService, userName string) { - b.ResetTimer() - ctx := context.Background() - for n := 0; n < b.N; n++ { - _, _, err := s.ListEffectivePolicies(ctx, userName, &model.PaginationParams{Amount: -1}) - if err != nil { - b.Fatal("Failed to list effective policies", err) - } - } -} - -func describeAllowed(allowed bool) string { - if allowed { - return "allowed" - } - return "forbidden" -} - -func TestACL(t *testing.T) { - hierarchy := []model.ACLPermission{acl.ReadPermission, acl.WritePermission, acl.SuperPermission, acl.AdminPermission} - - type PermissionFrom map[model.ACLPermission][]permissions.Permission - type TestCase struct { - // Name is an identifier for this test case. - Name string - // ACL is the ACL to test. ACL.Permission will be tested - // with each of the hierarchies. - ACL model.ACL - // PermissionFrom holds permissions that must hold starting - // at the ACLPermission key in the hierarchy. - PermissionFrom PermissionFrom - } - - tests := []TestCase{ - { - Name: "all repos", - ACL: model.ACL{}, - PermissionFrom: PermissionFrom{ - acl.ReadPermission: []permissions.Permission{ - {Action: permissions.ReadObjectAction, Resource: permissions.ObjectArn("foo", "some/path")}, - {Action: permissions.ListObjectsAction, Resource: permissions.ObjectArn("foo", "some/path")}, - {Action: permissions.ListObjectsAction, Resource: permissions.ObjectArn("quux", "")}, - {Action: permissions.CreateCredentialsAction, Resource: permissions.UserArn("${user}")}, - }, - acl.WritePermission: []permissions.Permission{ - {Action: permissions.WriteObjectAction, Resource: permissions.ObjectArn("foo", "some/path")}, - {Action: permissions.DeleteObjectAction, Resource: permissions.ObjectArn("foo", "some/path")}, - {Action: permissions.CreateBranchAction, Resource: permissions.BranchArn("foo", "twig")}, - {Action: permissions.CreateCommitAction, Resource: permissions.BranchArn("foo", "twig")}, - {Action: permissions.CreateMetaRangeAction, Resource: permissions.RepoArn("foo")}, - }, - acl.SuperPermission: []permissions.Permission{ - {Action: permissions.AttachStorageNamespaceAction, Resource: permissions.StorageNamespace("storage://bucket/path")}, - {Action: permissions.ImportFromStorageAction, Resource: permissions.StorageNamespace("storage://bucket/path")}, - {Action: permissions.ImportCancelAction, Resource: permissions.BranchArn("foo", "twig")}, - }, - acl.AdminPermission: []permissions.Permission{ - {Action: permissions.CreateUserAction, Resource: permissions.UserArn("you")}, - }, - }, - }, - } - - ctx := context.Background() - - for _, tt := range tests { - t.Run(tt.Name, func(t *testing.T) { - s, _ := authtestutil.SetupService(t, ctx, someSecret) - userID := make(map[model.ACLPermission]string, len(hierarchy)) - for _, aclPermission := range hierarchy { - tt.ACL.Permission = aclPermission - userID[aclPermission] = userWithACLs(t, s, tt.ACL) - } - tt.ACL.Permission = "" - - for from, pp := range tt.PermissionFrom { - for _, p := range pp { - t.Run(fmt.Sprintf("%+v", p), func(t *testing.T) { - n := permissions.Node{Permission: p} - allow := false - for _, aclPermission := range hierarchy { - t.Run(string(aclPermission), func(t *testing.T) { - if aclPermission == from { - allow = true - } - origResource := n.Permission.Resource - defer func() { - n.Permission.Resource = origResource - }() - n.Permission.Resource = strings.ReplaceAll(n.Permission.Resource, "${user}", userID[aclPermission]) - - r, err := s.Authorize(ctx, &auth.AuthorizationRequest{ - Username: userID[aclPermission], - RequiredPermissions: n, - }) - if err != nil { - t.Errorf("Authorize failed: %v", err) - } - if (allow && r.Error != nil) || !allow && !errors.Is(r.Error, auth.ErrInsufficientPermissions) { - t.Errorf("Authorization response error: %v", err) - } - if r.Allowed != allow { - t.Errorf("%s but expected %s", describeAllowed(r.Allowed), describeAllowed(allow)) - } - }) - } - }) - } - } - }) - } -} - func TestAPIAuthService_GetUserById(t *testing.T) { mockClient, s := NewTestApiService(t, false) tests := []struct { @@ -2767,7 +2119,7 @@ func TestAPIService_RequestIDPropagation(t *testing.T) { ctx := context.WithValue(context.Background(), httputil.RequestIDContextKey, requestID) - service.DeleteUser(ctx, "foo") + require.NoError(t, service.DeleteUser(ctx, "foo")) if !called { t.Error("Expected inner server to be called but it wasn't") } diff --git a/pkg/kv/migrations/migrations_test.go b/pkg/kv/migrations/migrations_test.go index 5cccb6529b5..9a99111021b 100644 --- a/pkg/kv/migrations/migrations_test.go +++ b/pkg/kv/migrations/migrations_test.go @@ -10,11 +10,12 @@ import ( "github.com/go-test/deep" "github.com/stretchr/testify/require" + authacl "github.com/treeverse/lakefs/contrib/auth/acl" + authtestutil "github.com/treeverse/lakefs/contrib/auth/acl/testutil" "github.com/treeverse/lakefs/pkg/auth" "github.com/treeverse/lakefs/pkg/auth/acl" "github.com/treeverse/lakefs/pkg/auth/model" "github.com/treeverse/lakefs/pkg/auth/setup" - authtestutil "github.com/treeverse/lakefs/pkg/auth/testutil" "github.com/treeverse/lakefs/pkg/config" "github.com/treeverse/lakefs/pkg/kv/migrations" "github.com/treeverse/lakefs/pkg/permissions" @@ -350,7 +351,7 @@ func createARN(name string) string { return fmt.Sprintf("arn:%s:this:is:an:arn", name) } -func verifyMigration(t *testing.T, ctx context.Context, authService *auth.AuthService, policies []model.Policy, cfg config.Config) { +func verifyMigration(t *testing.T, ctx context.Context, authService *authacl.AuthService, policies []model.Policy, cfg config.Config) { for _, prev := range policies { policy, err := authService.GetPolicy(ctx, prev.DisplayName) testutil.MustDo(t, "get policy", err) diff --git a/pkg/kv/migrations/rbac_to_acl.go b/pkg/kv/migrations/rbac_to_acl.go index 410639ea1d3..fa579a87f2f 100644 --- a/pkg/kv/migrations/rbac_to_acl.go +++ b/pkg/kv/migrations/rbac_to_acl.go @@ -8,6 +8,7 @@ import ( "time" "github.com/hashicorp/go-multierror" + authacl "github.com/treeverse/lakefs/contrib/auth/acl" "github.com/treeverse/lakefs/pkg/auth" "github.com/treeverse/lakefs/pkg/auth/acl" "github.com/treeverse/lakefs/pkg/auth/crypt" @@ -70,7 +71,7 @@ func MigrateToACL(ctx context.Context, kvStore kv.Store, cfg *config.Config, log usersWithPolicies []string ) updateTime := time.Now() - authService := auth.NewAuthService( + authService := authacl.NewAuthService( kvStore, crypt.NewSecretStore([]byte(cfg.Auth.Encrypt.SecretKey)), authparams.ServiceCache(cfg.Auth.Cache), diff --git a/pkg/loadtest/local_load_test.go b/pkg/loadtest/local_load_test.go index 4c78735a40b..2b34a5c21a1 100644 --- a/pkg/loadtest/local_load_test.go +++ b/pkg/loadtest/local_load_test.go @@ -8,9 +8,8 @@ import ( "testing" "time" - "github.com/treeverse/lakefs/pkg/authentication" - "github.com/spf13/viper" + "github.com/treeverse/lakefs/contrib/auth/acl" "github.com/treeverse/lakefs/pkg/actions" "github.com/treeverse/lakefs/pkg/api" "github.com/treeverse/lakefs/pkg/auth" @@ -18,6 +17,7 @@ import ( authmodel "github.com/treeverse/lakefs/pkg/auth/model" authparams "github.com/treeverse/lakefs/pkg/auth/params" "github.com/treeverse/lakefs/pkg/auth/setup" + "github.com/treeverse/lakefs/pkg/authentication" "github.com/treeverse/lakefs/pkg/block" "github.com/treeverse/lakefs/pkg/catalog" "github.com/treeverse/lakefs/pkg/config" @@ -51,7 +51,7 @@ func TestLocalLoad(t *testing.T) { } kvStore := kvtest.GetStore(ctx, t) - authService := auth.NewAuthService(kvStore, crypt.NewSecretStore([]byte("some secret")), authparams.ServiceCache{}, logging.ContextUnavailable().WithField("service", "auth")) + authService := acl.NewAuthService(kvStore, crypt.NewSecretStore([]byte("some secret")), authparams.ServiceCache{}, logging.ContextUnavailable().WithField("service", "auth")) meta := auth.NewKVMetadataManager("local_load_test", conf.Installation.FixedID, conf.Database.Type, kvStore) blockstoreType := os.Getenv(testutil.EnvKeyUseBlockAdapter)