Skip to content

Commit

Permalink
refactor: merge file module to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
vimiix committed Jun 21, 2024
1 parent 405ec5f commit 0d3c451
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 59 deletions.
13 changes: 5 additions & 8 deletions cmd/ssx/cmd/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@ package cmd
import (
"context"
"fmt"
"github.com/vimiix/ssx/ssx/version"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"

"github.com/vimiix/ssx/ssx/version"

"github.com/pkg/errors"
"github.com/spf13/cobra"
"github.com/tidwall/gjson"
"github.com/vimiix/ssx/internal/file"
"github.com/vimiix/ssx/internal/lg"
"github.com/vimiix/ssx/internal/utils"
)
Expand Down Expand Up @@ -84,7 +84,7 @@ func upgrade(ctx context.Context, opt *upgradeOpt) error {
return err
}
} else {
if !file.IsExist(opt.PkgPath) {
if !utils.FileExists(opt.PkgPath) {
return errors.Errorf("file not found: %s", opt.PkgPath)
}
localPkg = opt.PkgPath
Expand Down Expand Up @@ -131,7 +131,7 @@ func upgrade(ctx context.Context, opt *upgradeOpt) error {
return err
}
newBin := filepath.Join(tempDir, "ssx")
if !file.IsExist(newBin) {
if !utils.FileExists(newBin) {
return errors.New("not found ssx binary after extracting package")
}
execPath, err := os.Executable()
Expand Down Expand Up @@ -183,17 +183,14 @@ func replaceBinary(newBin string, oldBin string) error {
if err := os.Link(oldBin, bakBin); err != nil {
return err
}
// if err := file.CopyFile(oldBin, bakName, 0700); err != nil {
// return err
// }

lg.Debug("remove old binary")
if err := os.RemoveAll(oldBin); err != nil {
return err
}

lg.Debug("make the new binary effective")
if err := file.CopyFile(newBin, oldBin, 0700); err != nil {
if err := utils.CopyFile(newBin, oldBin, 0700); err != nil {
_ = os.RemoveAll(oldBin)
renameErr := os.Rename(bakBin, oldBin)
if renameErr != nil {
Expand Down
34 changes: 0 additions & 34 deletions internal/file/file.go

This file was deleted.

24 changes: 21 additions & 3 deletions internal/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@ import (

"github.com/denisbrodbeck/machineid"
"github.com/pkg/errors"
"github.com/vimiix/ssx/internal/file"
"github.com/vimiix/ssx/internal/lg"
"github.com/vimiix/ssx/ssx/env"
)

// FileExists check given filename if exists
func FileExists(filename string) bool {
_, err := os.Stat(filename)
if filename == "" {
return false
}
_, err := os.Stat(ExpandHomeDir(filename))
return !os.IsNotExist(err)
}

Expand Down Expand Up @@ -235,7 +237,7 @@ func Untar(tarPath string, targetDir string, filenames ...string) error {
}

dirpath := path.Dir(target)
if !file.IsExist(dirpath) {
if !FileExists(dirpath) {
if err := os.MkdirAll(dirpath, 0700); err != nil {
return err
}
Expand All @@ -255,3 +257,19 @@ func Untar(tarPath string, targetDir string, filenames ...string) error {
}
}
}

// CopyFile copies the contents of src to dst
func CopyFile(src, dst string, perm os.FileMode) error {
sf, err := os.Open(src)
if err != nil {
return err
}
defer sf.Close()
tf, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, perm)
if err != nil {
return err
}
defer tf.Close()
_, err = io.Copy(tf, sf)
return err
}
31 changes: 17 additions & 14 deletions ssx/entry/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/skeema/knownhosts"
"golang.org/x/crypto/ssh"

"github.com/vimiix/ssx/internal/file"
"github.com/vimiix/ssx/internal/lg"
"github.com/vimiix/ssx/internal/terminal"
"github.com/vimiix/ssx/internal/utils"
Expand All @@ -28,7 +27,7 @@ const (
)

var (
defaultIdentityFile = utils.ExpandHomeDir("~/.ssh/id_rsa")
defaultIdentityFile = "~/.ssh/id_rsa"
defaultUser = "root"
defaultPort = "22"
)
Expand Down Expand Up @@ -102,6 +101,10 @@ func (e *Entry) ClearPassword() {
}
}

func (e *Entry) KeyFileAbsPath() string {
return utils.ExpandHomeDir(e.KeyPath)
}

func getConnectTimeout() time.Duration {
var defaultTimeout = time.Second * 10
val := os.Getenv(env.SSXConnectTimeout)
Expand Down Expand Up @@ -181,7 +184,7 @@ func (e *Entry) Tidy() error {
if len(e.Port) <= 0 {
e.Port = defaultPort
}
if e.KeyPath == "" && file.IsExist(defaultIdentityFile) {
if e.KeyPath == "" && utils.FileExists(defaultIdentityFile) {
e.KeyPath = defaultIdentityFile
}
if e.Proxy != nil {
Expand Down Expand Up @@ -216,14 +219,14 @@ func passwordCallback(ctx context.Context, user, host string, storePassFunc func
fmt.Printf("%s@%s's password:", user, host)
bs, readErr := terminal.ReadPassword(ctx)
fmt.Println()
if readErr == nil {
p := string(bs)
if storePassFunc != nil {
storePassFunc(p)
}
return p, nil
if readErr != nil {
return "", readErr
}
return "", readErr
p := string(bs)
if storePassFunc != nil {
storePassFunc(p)
}
return p, nil
}
return ssh.PasswordCallback(prompt)
}
Expand Down Expand Up @@ -263,7 +266,7 @@ func (e *Entry) privateKeyAuthMethods(ctx context.Context) ([]ssh.AuthMethod, er
}
var methods []ssh.AuthMethod
for _, f := range keyfiles {
if !file.IsExist(f) {
if !utils.FileExists(f) {
lg.Debug("keyfile %s not found, skip", f)
continue
}
Expand All @@ -290,9 +293,9 @@ func (e *Entry) keyfileAuth(ctx context.Context, keypath string) (ssh.AuthMethod
signer, err = ssh.ParsePrivateKey(pemBytes)
passphraseMissingError := &ssh.PassphraseMissingError{}
if err != nil {
if keypath != e.KeyPath {
if keypath != e.KeyFileAbsPath() {
lg.Debug("parse failed, ignore keyfile %q", keypath)
return nil, nil
return nil, err
}
if errors.As(err, &passphraseMissingError) {
if e.Passphrase != "" {
Expand Down Expand Up @@ -327,7 +330,7 @@ var defaultRSAKeyFiles = []string{
func (e *Entry) collectKeyfiles() []string {
var keypaths []string
if e.KeyPath != "" && utils.FileExists(e.KeyPath) {
keypaths = append(keypaths, e.KeyPath)
keypaths = append(keypaths, e.KeyFileAbsPath())
}
u, err := user.Current()
if err != nil {
Expand Down

0 comments on commit 0d3c451

Please sign in to comment.