From 0d3c451b3661fdfe2d43502ecdbb21c380d440fc Mon Sep 17 00:00:00 2001 From: vimiix Date: Fri, 21 Jun 2024 12:30:40 +0800 Subject: [PATCH] refactor: merge file module to utils --- cmd/ssx/cmd/upgrade.go | 13 +++++-------- internal/file/file.go | 34 ---------------------------------- internal/utils/utils.go | 24 +++++++++++++++++++++--- ssx/entry/entry.go | 31 +++++++++++++++++-------------- 4 files changed, 43 insertions(+), 59 deletions(-) delete mode 100644 internal/file/file.go diff --git a/cmd/ssx/cmd/upgrade.go b/cmd/ssx/cmd/upgrade.go index 1e23a2a..fb46efe 100644 --- a/cmd/ssx/cmd/upgrade.go +++ b/cmd/ssx/cmd/upgrade.go @@ -3,7 +3,6 @@ package cmd import ( "context" "fmt" - "github.com/vimiix/ssx/ssx/version" "io" "net/http" "os" @@ -11,10 +10,11 @@ import ( "runtime" "strings" + "github.com/vimiix/ssx/ssx/version" + "github.com/pkg/errors" "github.com/spf13/cobra" "github.com/tidwall/gjson" - "github.com/vimiix/ssx/internal/file" "github.com/vimiix/ssx/internal/lg" "github.com/vimiix/ssx/internal/utils" ) @@ -84,7 +84,7 @@ func upgrade(ctx context.Context, opt *upgradeOpt) error { return err } } else { - if !file.IsExist(opt.PkgPath) { + if !utils.FileExists(opt.PkgPath) { return errors.Errorf("file not found: %s", opt.PkgPath) } localPkg = opt.PkgPath @@ -131,7 +131,7 @@ func upgrade(ctx context.Context, opt *upgradeOpt) error { return err } newBin := filepath.Join(tempDir, "ssx") - if !file.IsExist(newBin) { + if !utils.FileExists(newBin) { return errors.New("not found ssx binary after extracting package") } execPath, err := os.Executable() @@ -183,9 +183,6 @@ func replaceBinary(newBin string, oldBin string) error { if err := os.Link(oldBin, bakBin); err != nil { return err } - // if err := file.CopyFile(oldBin, bakName, 0700); err != nil { - // return err - // } lg.Debug("remove old binary") if err := os.RemoveAll(oldBin); err != nil { @@ -193,7 +190,7 @@ func replaceBinary(newBin string, oldBin string) error { } lg.Debug("make the new binary effective") - if err := file.CopyFile(newBin, oldBin, 0700); err != nil { + if err := utils.CopyFile(newBin, oldBin, 0700); err != nil { _ = os.RemoveAll(oldBin) renameErr := os.Rename(bakBin, oldBin) if renameErr != nil { diff --git a/internal/file/file.go b/internal/file/file.go deleted file mode 100644 index 36c9517..0000000 --- a/internal/file/file.go +++ /dev/null @@ -1,34 +0,0 @@ -package file - -import ( - "io" - "os" -) - -// CopyFile copies the contents of src to dst -func CopyFile(src, dst string, perm os.FileMode) error { - sf, err := os.Open(src) - if err != nil { - return err - } - defer sf.Close() - tf, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, perm) - if err != nil { - return err - } - defer tf.Close() - _, err = io.Copy(tf, sf) - return err -} - -// IsExist check given path if exists -func IsExist(path string) bool { - if path == "" { - return false - } - _, err := os.Stat(path) - if err != nil { - return os.IsExist(err) - } - return true -} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 6e8edeb..e94038e 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -16,14 +16,16 @@ import ( "github.com/denisbrodbeck/machineid" "github.com/pkg/errors" - "github.com/vimiix/ssx/internal/file" "github.com/vimiix/ssx/internal/lg" "github.com/vimiix/ssx/ssx/env" ) // FileExists check given filename if exists func FileExists(filename string) bool { - _, err := os.Stat(filename) + if filename == "" { + return false + } + _, err := os.Stat(ExpandHomeDir(filename)) return !os.IsNotExist(err) } @@ -235,7 +237,7 @@ func Untar(tarPath string, targetDir string, filenames ...string) error { } dirpath := path.Dir(target) - if !file.IsExist(dirpath) { + if !FileExists(dirpath) { if err := os.MkdirAll(dirpath, 0700); err != nil { return err } @@ -255,3 +257,19 @@ func Untar(tarPath string, targetDir string, filenames ...string) error { } } } + +// CopyFile copies the contents of src to dst +func CopyFile(src, dst string, perm os.FileMode) error { + sf, err := os.Open(src) + if err != nil { + return err + } + defer sf.Close() + tf, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, perm) + if err != nil { + return err + } + defer tf.Close() + _, err = io.Copy(tf, sf) + return err +} diff --git a/ssx/entry/entry.go b/ssx/entry/entry.go index f7125bc..d226bfa 100644 --- a/ssx/entry/entry.go +++ b/ssx/entry/entry.go @@ -15,7 +15,6 @@ import ( "github.com/skeema/knownhosts" "golang.org/x/crypto/ssh" - "github.com/vimiix/ssx/internal/file" "github.com/vimiix/ssx/internal/lg" "github.com/vimiix/ssx/internal/terminal" "github.com/vimiix/ssx/internal/utils" @@ -28,7 +27,7 @@ const ( ) var ( - defaultIdentityFile = utils.ExpandHomeDir("~/.ssh/id_rsa") + defaultIdentityFile = "~/.ssh/id_rsa" defaultUser = "root" defaultPort = "22" ) @@ -102,6 +101,10 @@ func (e *Entry) ClearPassword() { } } +func (e *Entry) KeyFileAbsPath() string { + return utils.ExpandHomeDir(e.KeyPath) +} + func getConnectTimeout() time.Duration { var defaultTimeout = time.Second * 10 val := os.Getenv(env.SSXConnectTimeout) @@ -181,7 +184,7 @@ func (e *Entry) Tidy() error { if len(e.Port) <= 0 { e.Port = defaultPort } - if e.KeyPath == "" && file.IsExist(defaultIdentityFile) { + if e.KeyPath == "" && utils.FileExists(defaultIdentityFile) { e.KeyPath = defaultIdentityFile } if e.Proxy != nil { @@ -216,14 +219,14 @@ func passwordCallback(ctx context.Context, user, host string, storePassFunc func fmt.Printf("%s@%s's password:", user, host) bs, readErr := terminal.ReadPassword(ctx) fmt.Println() - if readErr == nil { - p := string(bs) - if storePassFunc != nil { - storePassFunc(p) - } - return p, nil + if readErr != nil { + return "", readErr } - return "", readErr + p := string(bs) + if storePassFunc != nil { + storePassFunc(p) + } + return p, nil } return ssh.PasswordCallback(prompt) } @@ -263,7 +266,7 @@ func (e *Entry) privateKeyAuthMethods(ctx context.Context) ([]ssh.AuthMethod, er } var methods []ssh.AuthMethod for _, f := range keyfiles { - if !file.IsExist(f) { + if !utils.FileExists(f) { lg.Debug("keyfile %s not found, skip", f) continue } @@ -290,9 +293,9 @@ func (e *Entry) keyfileAuth(ctx context.Context, keypath string) (ssh.AuthMethod signer, err = ssh.ParsePrivateKey(pemBytes) passphraseMissingError := &ssh.PassphraseMissingError{} if err != nil { - if keypath != e.KeyPath { + if keypath != e.KeyFileAbsPath() { lg.Debug("parse failed, ignore keyfile %q", keypath) - return nil, nil + return nil, err } if errors.As(err, &passphraseMissingError) { if e.Passphrase != "" { @@ -327,7 +330,7 @@ var defaultRSAKeyFiles = []string{ func (e *Entry) collectKeyfiles() []string { var keypaths []string if e.KeyPath != "" && utils.FileExists(e.KeyPath) { - keypaths = append(keypaths, e.KeyPath) + keypaths = append(keypaths, e.KeyFileAbsPath()) } u, err := user.Current() if err != nil {