diff --git a/internal/terminal/terminal.go b/internal/terminal/terminal.go new file mode 100644 index 0000000..30cb421 --- /dev/null +++ b/internal/terminal/terminal.go @@ -0,0 +1,35 @@ +package terminal + +import ( + "context" + + "github.com/containerd/console" +) + +func ReadPassword(ctx context.Context) ([]byte, error) { + c := console.Current() + defer func() { + _ = c.Reset() + }() + + var ( + errch = make(chan error, 1) + password []byte + ) + + go func() { + bs, readErr := readPassword() + if readErr != nil { + errch <- readErr + } + password = bs + errch <- nil + }() + + select { + case err := <-errch: + return password, err + case <-ctx.Done(): + return nil, ctx.Err() + } +} diff --git a/internal/terminal/terminal_unix.go b/internal/terminal/terminal_unix.go index 47ce809..dd7e187 100644 --- a/internal/terminal/terminal_unix.go +++ b/internal/terminal/terminal_unix.go @@ -14,7 +14,7 @@ import ( "github.com/vimiix/ssx/internal/lg" ) -func ReadPassword() ([]byte, error) { +func readPassword() ([]byte, error) { return term.ReadPassword(syscall.Stdin) } diff --git a/internal/terminal/terminal_windows.go b/internal/terminal/terminal_windows.go index 4be6a44..d8fa488 100644 --- a/internal/terminal/terminal_windows.go +++ b/internal/terminal/terminal_windows.go @@ -13,7 +13,7 @@ import ( "github.com/vimiix/ssx/internal/lg" ) -func ReadPassword() ([]byte, error) { +func readPassword() ([]byte, error) { return term.ReadPassword(int(windows.Stdin)) } diff --git a/ssx/client.go b/ssx/client.go index 4d0970e..410cf0a 100644 --- a/ssx/client.go +++ b/ssx/client.go @@ -13,6 +13,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/containerd/console" + "github.com/vimiix/ssx/internal/lg" "github.com/vimiix/ssx/internal/terminal" "github.com/vimiix/ssx/ssx/entry" @@ -153,7 +154,7 @@ func dialContext(ctx context.Context, network, addr string, config *ssh.ClientCo func (c *Client) login(ctx context.Context) error { network := "tcp" addr := net.JoinHostPort(c.entry.Host, c.entry.Port) - clientConfig, err := c.entry.GenSSHConfig() + clientConfig, err := c.entry.GenSSHConfig(ctx) if err != nil { return err } @@ -166,13 +167,13 @@ func (c *Client) login(ctx context.Context) error { if strings.Contains(err.Error(), "no supported methods remain") { fmt.Printf("%s@%s's password:", c.entry.User, c.entry.Host) - bs, readErr := terminal.ReadPassword() + bs, readErr := terminal.ReadPassword(ctx) + fmt.Println() if readErr == nil { p := string(bs) if p != "" { clientConfig.Auth = []ssh.AuthMethod{ssh.Password(p)} } - fmt.Println() cli, err = ssh.Dial(network, addr, clientConfig) if err == nil { c.entry.Password = p diff --git a/ssx/entry/entry.go b/ssx/entry/entry.go index 31b7200..8d3ff24 100644 --- a/ssx/entry/entry.go +++ b/ssx/entry/entry.go @@ -2,6 +2,7 @@ package entry import ( "bufio" + "context" "fmt" "net" "os" @@ -61,14 +62,14 @@ func getConnectTimeout() time.Duration { return d } -func (e *Entry) GenSSHConfig() (*ssh.ClientConfig, error) { +func (e *Entry) GenSSHConfig(ctx context.Context) (*ssh.ClientConfig, error) { cb, err := e.sshHostKeyCallback() if err != nil { return nil, err } cfg := &ssh.ClientConfig{ User: e.User, - Auth: e.AuthMethods(), + Auth: e.AuthMethods(ctx), HostKeyCallback: cb, Timeout: getConnectTimeout(), } @@ -133,7 +134,7 @@ func (e *Entry) Tidy() error { } // AuthMethods all possible auth methods -func (e *Entry) AuthMethods() []ssh.AuthMethod { +func (e *Entry) AuthMethods(ctx context.Context) []ssh.AuthMethod { var authMethods []ssh.AuthMethod // password auth if e.Password != "" { @@ -146,11 +147,11 @@ func (e *Entry) AuthMethods() []ssh.AuthMethod { authMethods = append(authMethods, keyfileAuths...) } - authMethods = append(authMethods, e.interactAuth()) + authMethods = append(authMethods, e.interactAuth(ctx)) return authMethods } -func (e *Entry) interactAuth() ssh.AuthMethod { +func (e *Entry) interactAuth(ctx context.Context) ssh.AuthMethod { return ssh.KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) (answers []string, err error) { answers = make([]string, 0, len(questions)) for i, q := range questions { @@ -164,7 +165,7 @@ func (e *Entry) interactAuth() ssh.AuthMethod { return nil, err } } else { - b, err := terminal.ReadPassword() + b, err := terminal.ReadPassword(ctx) if err != nil { return nil, err }