diff --git a/filehandlers/imgs/handler.go b/filehandlers/imgs/handler.go index b8354c35..669f479c 100644 --- a/filehandlers/imgs/handler.go +++ b/filehandlers/imgs/handler.go @@ -48,7 +48,7 @@ func NewUploadImgHandler(dbpool db.DB, cfg *shared.ConfigSite, storage storage.S } func (h *UploadImgHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReaderAtCloser, error) { - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return nil, nil, err } @@ -88,7 +88,7 @@ func (h *UploadImgHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.F func (h *UploadImgHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) { logger := h.Cfg.Logger - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { logger.Error("could not get user from ctx", "err", err.Error()) return "", err @@ -145,10 +145,7 @@ func (h *UploadImgHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (str logger.Info("unable to find image, continuing", "filename", nextPost.Filename, "err", err.Error()) } - featureFlag, err := shared.GetFeatureFlag(s.Context()) - if err != nil { - return "", err - } + featureFlag := shared.FindPlusFF(h.DBPool, h.Cfg, user.ID) metadata := PostMetaData{ OrigText: text, Post: &nextPost, @@ -192,7 +189,7 @@ func (h *UploadImgHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (str } func (h *UploadImgHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error { - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return err } diff --git a/filehandlers/imgs/img.go b/filehandlers/imgs/img.go index 45c44a86..23cc0aaa 100644 --- a/filehandlers/imgs/img.go +++ b/filehandlers/imgs/img.go @@ -80,7 +80,7 @@ func (h *UploadImgHandler) writeImg(s ssh.Session, data *PostMetaData) error { if !valid { return err } - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return err } diff --git a/filehandlers/post_handler.go b/filehandlers/post_handler.go index 63f57c05..63dd8618 100644 --- a/filehandlers/post_handler.go +++ b/filehandlers/post_handler.go @@ -47,7 +47,7 @@ func NewScpPostHandler(dbpool db.DB, cfg *shared.ConfigSite, hooks ScpFileHooks, } func (h *ScpUploadHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReaderAtCloser, error) { - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return nil, nil, err } @@ -76,7 +76,7 @@ func (h *ScpUploadHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.F func (h *ScpUploadHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) { logger := h.Cfg.Logger - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { logger.Error("error getting user from ctx", "err", err.Error()) return "", err @@ -263,7 +263,7 @@ func (h *ScpUploadHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (str func (h *ScpUploadHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error { logger := h.Cfg.Logger - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { logger.Error("could not get user from ctx", "err", err.Error()) return err diff --git a/filehandlers/router_handler.go b/filehandlers/router_handler.go index 9b7be132..241500b0 100644 --- a/filehandlers/router_handler.go +++ b/filehandlers/router_handler.go @@ -82,7 +82,7 @@ func (r *FileHandlerRouter) Read(s ssh.Session, entry *utils.FileEntry) (os.File func BaseList(s ssh.Session, fpath string, isDir bool, recursive bool, spaces []string, dbpool db.DB) ([]os.FileInfo, error) { var fileList []os.FileInfo - user, err := shared.GetUser(s.Context()) + user, err := dbpool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return fileList, err } @@ -153,7 +153,7 @@ func (r *FileHandlerRouter) GetLogger() *slog.Logger { } func (r *FileHandlerRouter) Validate(s ssh.Session) error { - user, err := shared.GetUser(s.Context()) + user, err := r.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return err } diff --git a/go.mod b/go.mod index 9df38dd5..9b880a8b 100644 --- a/go.mod +++ b/go.mod @@ -53,7 +53,7 @@ require ( go.abhg.dev/goldmark/anchor v0.1.1 go.abhg.dev/goldmark/hashtag v0.3.1 go.abhg.dev/goldmark/toc v0.10.0 - golang.org/x/crypto v0.29.0 + golang.org/x/crypto v0.31.0 gopkg.in/yaml.v2 v2.4.0 ) @@ -179,10 +179,10 @@ require ( github.com/yusufpapurcu/wmi v1.2.4 // indirect golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f // indirect golang.org/x/net v0.31.0 // indirect - golang.org/x/sync v0.9.0 // indirect - golang.org/x/sys v0.27.0 // indirect - golang.org/x/term v0.26.0 // indirect - golang.org/x/text v0.20.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/sys v0.28.0 // indirect + golang.org/x/term v0.27.0 // indirect + golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.8.0 // indirect google.golang.org/protobuf v1.35.2 // indirect mvdan.cc/xurls/v2 v2.5.0 // indirect diff --git a/go.sum b/go.sum index 9a046016..af1da562 100644 --- a/go.sum +++ b/go.sum @@ -391,8 +391,8 @@ golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= -golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo= golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -417,8 +417,8 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= -golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -438,8 +438,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= -golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -447,16 +447,16 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= -golang.org/x/term v0.26.0 h1:WEQa6V3Gja/BhNxg540hBip/kkaYtRg3cxg4oXSw4AU= -golang.org/x/term v0.26.0/go.mod h1:Si5m1o57C5nBNQo5z1iq+XDijt21BDBDp2bK0QI8e3E= +golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= -golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.4.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= diff --git a/pgs/tunnel.go b/pgs/tunnel.go index e4e59bb8..547a4b3b 100644 --- a/pgs/tunnel.go +++ b/pgs/tunnel.go @@ -7,7 +7,6 @@ import ( "github.com/charmbracelet/ssh" "github.com/picosh/pico/db" "github.com/picosh/pico/shared" - "github.com/picosh/utils" ) type TunnelWebRouter struct { @@ -39,16 +38,14 @@ func createHttpHandler(apiConfig *shared.ApiConfig) CtxHttpBridge { "impersonating", asUser, ) - pubkey, err := shared.GetPublicKey(ctx) - if err != nil { - log.Error(err.Error(), "subdomain", subdomain) + pubkey := ctx.Permissions().Extensions["pubkey"] + if pubkey == "" { + log.Error("pubkey not found in extensions", "subdomain", subdomain) return http.HandlerFunc(shared.UnauthorizedHandler) } - pubkeyStr := utils.KeyForKeyText(pubkey) - log = log.With( - "pubkey", pubkeyStr, + "pubkey", pubkey, ) props, err := shared.GetProjectFromSubdomain(subdomain) @@ -72,7 +69,7 @@ func createHttpHandler(apiConfig *shared.ApiConfig) CtxHttpBridge { return http.HandlerFunc(shared.UnauthorizedHandler) } - requester, _ := dbh.FindUserForKey("", pubkeyStr) + requester, _ := dbh.FindUserForKey("", pubkey) if requester != nil { log = log.With( "requester", requester.Name, @@ -89,33 +86,18 @@ func createHttpHandler(apiConfig *shared.ApiConfig) CtxHttpBridge { requester, _ = dbh.FindUserForName(asUser) } - shared.SetUser(ctx, requester) - - if !HasProjectAccess(project, owner, requester, pubkey) { + ctx.Permissions().Extensions["user_id"] = requester.ID + publicKey, err := ssh.ParsePublicKey([]byte(pubkey)) + if err != nil { + return http.HandlerFunc(shared.UnauthorizedHandler) + } + if !HasProjectAccess(project, owner, requester, publicKey) { log.Error("no access") return http.HandlerFunc(shared.UnauthorizedHandler) } log.Info("user has access to site") - /* routes := []shared.Route{ - // special API endpoint for tunnel users accessing site - shared.NewCorsRoute("GET", "/api/current_user", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - user, err := shared.GetUser(ctx) - if err != nil { - logger.Error("could not find user", "err", err.Error()) - shared.JSONError(w, err.Error(), http.StatusNotFound) - return - } - pico := shared.NewUserApi(user, pubkey) - err = json.NewEncoder(w).Encode(pico) - if err != nil { - log.Error(err.Error()) - } - }), - } */ - routes := NewWebRouter( apiConfig.Cfg, logger, diff --git a/pgs/uploader.go b/pgs/uploader.go index 44f9aa46..f32e5234 100644 --- a/pgs/uploader.go +++ b/pgs/uploader.go @@ -121,7 +121,7 @@ func (h *UploadAssetHandler) GetLogger() *slog.Logger { } func (h *UploadAssetHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReaderAtCloser, error) { - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return nil, nil, err } @@ -155,7 +155,7 @@ func (h *UploadAssetHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os func (h *UploadAssetHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) { var fileList []os.FileInfo - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return fileList, err } @@ -197,7 +197,7 @@ func (h *UploadAssetHandler) List(s ssh.Session, fpath string, isDir bool, recur } func (h *UploadAssetHandler) Validate(s ssh.Session) error { - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return err } @@ -248,7 +248,7 @@ func (h *UploadAssetHandler) findDenylist(bucket sst.Bucket, project *db.Project } func (h *UploadAssetHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) { - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if user == nil || err != nil { h.Cfg.Logger.Error("user not found in ctx", "err", err.Error()) return "", err @@ -314,11 +314,7 @@ func (h *UploadAssetHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (s return "", err } - featureFlag, err := shared.GetFeatureFlag(s.Context()) - if err != nil { - return "", err - } - + featureFlag := shared.FindPlusFF(h.DBPool, h.Cfg, user.ID) // calculate the filsize difference between the same file already // stored and the updated file being uploaded assetFilename := shared.GetAssetFileName(entry) @@ -424,7 +420,7 @@ func isSpecialFile(entry *sendutils.FileEntry) bool { } func (h *UploadAssetHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error { - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { h.Cfg.Logger.Error("user not found in ctx", "err", err.Error()) return err diff --git a/pico/file_handler.go b/pico/file_handler.go index 15a1d65c..1e61d901 100644 --- a/pico/file_handler.go +++ b/pico/file_handler.go @@ -56,7 +56,7 @@ func (h *UploadHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error } func (h *UploadHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReaderAtCloser, error) { - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return nil, nil, err } @@ -80,7 +80,7 @@ func (h *UploadHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.File func (h *UploadHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) { var fileList []os.FileInfo - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { return fileList, err } @@ -135,7 +135,7 @@ func (h *UploadHandler) Validate(s ssh.Session) error { return fmt.Errorf("must have username set") } - shared.SetUser(s.Context(), user) + s.Permissions().Extensions["user_id"] = user.ID return nil } @@ -276,7 +276,7 @@ func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger, func (h *UploadHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) { logger := h.Cfg.Logger - user, err := shared.GetUser(s.Context()) + user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"]) if err != nil { logger.Error(err.Error()) return "", err diff --git a/pico/ssh.go b/pico/ssh.go index 6288832c..8f24e815 100644 --- a/pico/ssh.go +++ b/pico/ssh.go @@ -28,7 +28,6 @@ import ( ) func authHandler(ctx ssh.Context, key ssh.PublicKey) bool { - shared.SetPublicKey(ctx, key) return true } diff --git a/pipe/cli.go b/pipe/cli.go index 94a6f724..b71a2510 100644 --- a/pipe/cli.go +++ b/pipe/cli.go @@ -141,7 +141,7 @@ func WishMiddleware(handler *CliHandler) wish.Middleware { logger := handler.Logger ctx := sesh.Context() - user, err := shared.GetUser(ctx) + user, err := handler.DBPool.FindUser(sesh.Permissions().Extensions["user_id"]) if err != nil { logger.Info("user not found", "err", err) } diff --git a/shared/ssh.go b/shared/ssh.go index b30a2ef5..f0270e75 100644 --- a/shared/ssh.go +++ b/shared/ssh.go @@ -1,7 +1,6 @@ package shared import ( - "fmt" "log/slog" "github.com/charmbracelet/ssh" @@ -9,47 +8,6 @@ import ( "github.com/picosh/utils" ) -type ctxUserKey struct{} -type ctxFeatureFlagKey struct{} - -func GetUser(ctx ssh.Context) (*db.User, error) { - user, ok := ctx.Value(ctxUserKey{}).(*db.User) - if !ok { - return user, fmt.Errorf("user not set on `ssh.Context()` for connection") - } - return user, nil -} - -func SetUser(ctx ssh.Context, user *db.User) { - ctx.SetValue(ctxUserKey{}, user) -} - -func GetFeatureFlag(ctx ssh.Context) (*db.FeatureFlag, error) { - ff, ok := ctx.Value(ctxFeatureFlagKey{}).(*db.FeatureFlag) - if !ok || ff.Name == "" { - return ff, fmt.Errorf("feature flag not set on `ssh.Context()` for connection") - } - return ff, nil -} - -func SetFeatureFlag(ctx ssh.Context, ff *db.FeatureFlag) { - ctx.SetValue(ctxFeatureFlagKey{}, ff) -} - -type ctxPublicKey struct{} - -func GetPublicKey(ctx ssh.Context) (ssh.PublicKey, error) { - pk, ok := ctx.Value(ctxPublicKey{}).(ssh.PublicKey) - if !ok { - return nil, fmt.Errorf("public key not set on `ssh.Context()` for connection") - } - return pk, nil -} - -func SetPublicKey(ctx ssh.Context, pk ssh.PublicKey) { - ctx.SetValue(ctxPublicKey{}, pk) -} - type SshAuthHandler struct { DBPool db.DB Logger *slog.Logger @@ -64,11 +22,28 @@ func NewSshAuthHandler(dbpool db.DB, logger *slog.Logger, cfg *ConfigSite) *SshA } } -func (r *SshAuthHandler) PubkeyAuthHandler(ctx ssh.Context, key ssh.PublicKey) bool { - SetPublicKey(ctx, key) +func FindPlusFF(dbpool db.DB, cfg *ConfigSite, userID string) *db.FeatureFlag { + ff, _ := dbpool.FindFeatureForUser(userID, "plus") + // we have free tiers so users might not have a feature flag + // in which case we set sane defaults + if ff == nil { + ff = db.NewFeatureFlag( + userID, + "plus", + cfg.MaxSize, + cfg.MaxAssetSize, + cfg.MaxSpecialFileSize, + ) + } + // this is jank + ff.Data.StorageMax = ff.FindStorageMax(cfg.MaxSize) + ff.Data.FileMax = ff.FindFileMax(cfg.MaxAssetSize) + ff.Data.SpecialFileMax = ff.FindSpecialFileMax(cfg.MaxSpecialFileSize) + return ff +} +func (r *SshAuthHandler) PubkeyAuthHandler(ctx ssh.Context, key ssh.PublicKey) bool { pubkey := utils.KeyForKeyText(key) - user, err := r.DBPool.FindUserForKey(ctx.User(), pubkey) if err != nil { r.Logger.Error( @@ -84,24 +59,10 @@ func (r *SshAuthHandler) PubkeyAuthHandler(ctx ssh.Context, key ssh.PublicKey) b return false } - ff, _ := r.DBPool.FindFeatureForUser(user.ID, "plus") - // we have free tiers so users might not have a feature flag - // in which case we set sane defaults - if ff == nil { - ff = db.NewFeatureFlag( - user.ID, - "plus", - r.Cfg.MaxSize, - r.Cfg.MaxAssetSize, - r.Cfg.MaxSpecialFileSize, - ) + if ctx.Permissions().Extensions == nil { + ctx.Permissions().Extensions = map[string]string{} } - // this is jank - ff.Data.StorageMax = ff.FindStorageMax(r.Cfg.MaxSize) - ff.Data.FileMax = ff.FindFileMax(r.Cfg.MaxAssetSize) - ff.Data.SpecialFileMax = ff.FindSpecialFileMax(r.Cfg.MaxSpecialFileSize) - - SetUser(ctx, user) - SetFeatureFlag(ctx, ff) + ctx.Permissions().Extensions["user_id"] = user.ID + ctx.Permissions().Extensions["pubkey"] = pubkey return true }