diff --git a/Makefile b/Makefile index 49f0bc5..8153bd3 100644 --- a/Makefile +++ b/Makefile @@ -40,7 +40,7 @@ test: ## run all unit tests $(GO) test -gcflags=all=-l $(TEST_FILES) -coverprofile dist/cov.out -covermode count .PHONY: ssx -ssx: lint ## build ssx binary +ssx: ## build ssx binary $(GO) build -ldflags '$(LDFLAGS)' -gcflags '-N -l' -o dist/ssx ./cmd/ssx/main.go .PHONY: tag diff --git a/cmd/ssx/cmd/root.go b/cmd/ssx/cmd/root.go index f27640e..0e5725c 100644 --- a/cmd/ssx/cmd/root.go +++ b/cmd/ssx/cmd/root.go @@ -89,6 +89,7 @@ ssx 100 pwd`, root.AddCommand(newDeleteCmd()) root.AddCommand(newTagCmd()) root.AddCommand(newInfoCmd()) + root.AddCommand(newUpgradeCmd()) root.CompletionOptions.HiddenDefaultCmd = true root.SetHelpCommand(&cobra.Command{Hidden: true}) diff --git a/cmd/ssx/cmd/upgrade.go b/cmd/ssx/cmd/upgrade.go new file mode 100644 index 0000000..8f049ca --- /dev/null +++ b/cmd/ssx/cmd/upgrade.go @@ -0,0 +1,188 @@ +package cmd + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + + "github.com/pkg/errors" + "github.com/spf13/cobra" + "github.com/tidwall/gjson" + "github.com/vimiix/ssx/internal/file" + "github.com/vimiix/ssx/internal/lg" + "github.com/vimiix/ssx/internal/utils" +) + +const ( + GITHUB_LATEST_API = "https://api.github.com/repos/vimiix/ssx/releases/latest" + GITHUB_PKG_FMT = "https://github.com/vimiix/ssx/releases/download/v{VERSION}/ssx_v{VERSION}_{OS}_{ARCH}.tar.gz" +) + +type upgradeOpt struct { + PkgPath string + Version string +} + +func newUpgradeCmd() *cobra.Command { + opt := &upgradeOpt{} + cmd := &cobra.Command{ + Use: "upgrade", + Short: "upgrade ssx version", + Example: `# Upgrade online +ssx upgrade [] + +# Upgrade with local filepath or specify new package URL path +ssx upgrade -p + +# If both version and package path are specified, +# ssx prefer to use package path.`, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + opt.Version = args[0] + } + return upgrade(cmd.Context(), opt) + }} + cmd.Flags().StringVarP(&opt.PkgPath, "package", "p", "", "new package file or URL path") + return cmd +} + +func unifyArch() (string, error) { + switch runtime.GOARCH { + case "amd64", "x86_64": + return "x86_64", nil + case "arm64", "aarch64": + return "arm64", nil + default: + return "", errors.Errorf("not supported architecture: %s", runtime.GOARCH) + } +} + +func upgrade(ctx context.Context, opt *upgradeOpt) error { + tempDir, err := os.MkdirTemp("", "*") + if err != nil { + return err + } + lg.Debug("make temp dir: %s", tempDir) + defer os.RemoveAll(tempDir) + var localPkg string + if opt.PkgPath != "" { + if strings.Contains(opt.PkgPath, "://") { + localPkg = filepath.Join(tempDir, "ssx.tar.gz") + lg.Info("downloading package from %s", opt.PkgPath) + if err := utils.DownloadFile(ctx, opt.PkgPath, localPkg); err != nil { + return err + } + } else { + if !file.IsExist(opt.PkgPath) { + return errors.Errorf("file not found: %s", opt.PkgPath) + } + localPkg = opt.PkgPath + } + } else if opt.Version != "" { + semVer := strings.TrimPrefix(opt.Version, "v") + if len(strings.Split(semVer, ".")) != 3 { + return errors.Errorf("bad version: %s", opt.Version) + } + arch, err := unifyArch() + if err != nil { + return err + } + replacer := strings.NewReplacer("{VERSION}", semVer, "{OS}", runtime.GOOS, "{ARCH}", arch) + urlStr := replacer.Replace(GITHUB_PKG_FMT) + localPkg = filepath.Join(tempDir, "ssx.tar.gz") + lg.Info("downloading package from %s", urlStr) + if err := utils.DownloadFile(ctx, urlStr, localPkg); err != nil { + return err + } + } else { + lg.Info("detecting latest package url") + urlStr, err := getLatestPkgURL() + if err != nil { + return err + } + if urlStr == "" { + return errors.New("failed to get latest package url") + } + localPkg = filepath.Join(tempDir, "ssx.tar.gz") + lg.Info("downloading latest package from %s", urlStr) + if err := utils.DownloadFile(ctx, urlStr, localPkg); err != nil { + return err + } + } + lg.Info("extracting package") + if err := utils.Untar(localPkg, tempDir); err != nil { + return err + } + newBin := filepath.Join(tempDir, "ssx") + if !file.IsExist(newBin) { + return errors.New("not found ssx binary after extracting package") + } + execPath, err := os.Executable() + if err != nil { + return err + } + execAbsPath, err := filepath.Abs(execPath) + if err != nil { + return err + } + lg.Info("replacing old binary with new binary") + if err := replaceBinary(newBin, execAbsPath); err != nil { + return err + } + lg.Info("upgrade success") + return nil +} + +func getLatestPkgURL() (string, error) { + arch, err := unifyArch() + if err != nil { + return "", err + } + r, err := http.Get(GITHUB_LATEST_API) + if err != nil { + return "", err + } + defer r.Body.Close() + jsonBody, err := io.ReadAll(r.Body) + if err != nil { + return "", err + } + + res := gjson.Get(string(jsonBody), + fmt.Sprintf(`assets.#(name%%"*%s_%s.tar.gz").browser_download_url`, runtime.GOOS, arch)) + return res.String(), nil +} + +func replaceBinary(newBin string, oldBin string) error { + bakBin := oldBin + ".bak" + lg.Debug("backup old binary from %s to %s", oldBin, bakBin) + if err := os.Link(oldBin, bakBin); err != nil { + return err + } + // if err := file.CopyFile(oldBin, bakName, 0700); err != nil { + // return err + // } + + lg.Debug("remove old binary") + if err := os.RemoveAll(oldBin); err != nil { + return err + } + + lg.Debug("make the new binary effective") + if err := file.CopyFile(newBin, oldBin, 0700); err != nil { + _ = os.RemoveAll(oldBin) + renameErr := os.Rename(bakBin, oldBin) + if renameErr != nil { + lg.Warn("restore old binary failed, please rename it manually\n"+ + " mv %s %s", bakBin, oldBin) + } + return err + } + _ = os.RemoveAll(bakBin) + return nil +} diff --git a/go.mod b/go.mod index e5aa321..904bae7 100644 --- a/go.mod +++ b/go.mod @@ -31,5 +31,8 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.4 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/tidwall/gjson v1.17.1 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e4b5372..d813abc 100644 --- a/go.sum +++ b/go.sum @@ -52,6 +52,13 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= +github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/vimiix/tablewriter v0.0.0-20231207073205-aad9e2006284 h1:7o3B9eLdW6tgtWcgP0TVnq9QAUFu5Ii/RoaFOkEMYbc= github.com/vimiix/tablewriter v0.0.0-20231207073205-aad9e2006284/go.mod h1:uQpPcEuo28DE69kbtdWpMfeB+el/Kaeh2hCEdrz1iKI= go.etcd.io/bbolt v1.3.10 h1:+BqfJTcCzTItrop8mq/lbzL8wSGtj94UO/3U31shqG0= diff --git a/internal/encrypt/encrypt.go b/internal/encrypt/encrypt.go index 788a429..4dcf48b 100644 --- a/internal/encrypt/encrypt.go +++ b/internal/encrypt/encrypt.go @@ -1,5 +1,3 @@ -// Copyright 2022 Enmotech Inc. All rights reserved. - package encrypt import ( diff --git a/internal/file/file.go b/internal/file/file.go new file mode 100644 index 0000000..36c9517 --- /dev/null +++ b/internal/file/file.go @@ -0,0 +1,34 @@ +package file + +import ( + "io" + "os" +) + +// CopyFile copies the contents of src to dst +func CopyFile(src, dst string, perm os.FileMode) error { + sf, err := os.Open(src) + if err != nil { + return err + } + defer sf.Close() + tf, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, perm) + if err != nil { + return err + } + defer tf.Close() + _, err = io.Copy(tf, sf) + return err +} + +// IsExist check given path if exists +func IsExist(path string) bool { + if path == "" { + return false + } + _, err := os.Stat(path) + if err != nil { + return os.IsExist(err) + } + return true +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index acf075e..a481713 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -1,14 +1,22 @@ package utils import ( + "archive/tar" + "compress/gzip" + "context" + "io" + "net/http" + "net/url" "os" "os/user" + "path" "path/filepath" "regexp" "strings" "github.com/denisbrodbeck/machineid" "github.com/pkg/errors" + "github.com/vimiix/ssx/internal/file" "github.com/vimiix/ssx/ssx/env" ) @@ -107,3 +115,137 @@ func GetSecretKey() (string, error) { } return to16chars(machineID), nil } + +func DownloadFile(ctx context.Context, urlStr string, saveFile string) error { + _, err := url.Parse(urlStr) + if err != nil { + return err + } + fp, err := os.Create(saveFile) + if err != nil { + return err + } + defer fp.Close() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlStr, nil) + if err != nil { + return err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return errors.Errorf("request failed:\n- url: %s\n- response: %s", urlStr, resp.Status) + } + _, err = io.Copy(fp, resp.Body) + if err != nil { + return err + } + return nil +} + +type closeFunc func() + +func openTarball(tarball string) (*tar.Reader, closeFunc, error) { + if tarball == "" { + return nil, nil, errors.Errorf("no tarball specified") + } + + f, err := os.Open(tarball) + if err != nil { + return nil, nil, err + } + closers := []io.Closer{f} + var tr *tar.Reader + gr, err := gzip.NewReader(f) + if err != nil { + f.Close() + return nil, nil, err + } + closers = append(closers, gr) + tr = tar.NewReader(gr) + closeFunc := func() { + for i := len(closers) - 1; i > -1; i-- { + closers[i].Close() + } + } + return tr, closeFunc, nil +} + +func Untar(tarPath string, targetDir string, filenames ...string) error { + specifiedUntar := false + if len(filenames) > 0 { + specifiedUntar = true + } + + tr, closefunc, err := openTarball(tarPath) + if err != nil { + return err + } + defer closefunc() + + for { + header, err := tr.Next() + switch { + // if no more files are found return + case err == io.EOF: + return nil + // return any other error + case err != nil: + return err + // if the header is nil, just skip it (not sure how this happens) + case header == nil: + continue + } + + // the target location where the dir/file should be created + target := filepath.Join(targetDir, filepath.FromSlash(header.Name)) + switch header.Typeflag { + // if it's a dir, and it doesn't exist create it + case tar.TypeDir: + if _, err := os.Stat(target); err != nil { + if err := os.MkdirAll(target, 0700); err != nil { + return err + } + } + // if it's a file create it + case tar.TypeReg: + if specifiedUntar { + if len(filenames) == 0 { + // 指定要解压的文件都已经找到,应立即返回 + return nil + } + targetIdx := -1 + for idx, fn := range filenames { + if strings.TrimPrefix(fn, "./") == strings.TrimPrefix(header.Name, "./") { + targetIdx = idx + } + } + if targetIdx == -1 { + continue + } + filenames = append(filenames[:targetIdx], filenames[targetIdx+1:]...) + } + + dirpath := path.Dir(target) + if !file.IsExist(dirpath) { + if err := os.MkdirAll(dirpath, 0700); err != nil { + return err + } + } + + targetFile, err := os.OpenFile(target, os.O_CREATE|os.O_RDWR|os.O_TRUNC, os.FileMode(header.Mode)) + if err != nil { + return err + } + // copy over contents + if _, err := io.Copy(targetFile, tr); err != nil { + return err + } + // manually close here after each file operation; defering would cause each file close + // to wait until all operations have completed. + targetFile.Close() + } + } +} diff --git a/ssx/client.go b/ssx/client.go index 7813f96..2d00db8 100644 --- a/ssx/client.go +++ b/ssx/client.go @@ -82,7 +82,6 @@ func (c *Client) Interact(ctx context.Context) error { lg.Info("connected server %s, version: %s", c.entry.String(), string(c.cli.ServerVersion())) - session, err := c.cli.NewSession() if err != nil { return err