Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support connect through jump server(#38) #39

Merged
merged 1 commit into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/ssx/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,10 @@ ssx 100 pwd`,
}
root.Flags().StringVarP(&opt.DBFile, "file", "f", "", "filepath to store auth data")
root.Flags().Uint64VarP(&opt.EntryID, "id", "i", 0, "entry id")
root.Flags().StringVarP(&opt.Addr, "server", "s", "", "target server address\nsupport formats: [user@]host[:port]")
root.Flags().StringVarP(&opt.Addr, "server", "s", "", "target server address\nsupport format: [user@]host[:port]")
root.Flags().StringVarP(&opt.Tag, "tag", "t", "", "search entry by tag")
root.Flags().StringVarP(&opt.IdentityFile, "keyfile", "k", "", "identity_file path")
root.Flags().StringVarP(&opt.JumpServers, "jump-server", "J", "", "jump servers, multiple jump hops may be specified separated by comma characters\nformat: [user1@]host1[:port1][,[user2@]host2[:port2]...]")
root.Flags().StringVarP(&opt.Command, "cmd", "c", "", "the command to execute\nssh connection will exit after the execution complete")
root.Flags().DurationVar(&opt.Timeout, "timeout", 0, "timeout for connecting and executing command")

Expand Down
3 changes: 2 additions & 1 deletion ssx/bbolt/bbolt.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ func (r *Repo) TouchEntry(e *entry.Entry) error {
}
defer r.close()

lg.Debug("bbolt repo: touch entry: %d", e.ID)
return r.db.Update(func(tx *bbolt.Tx) error {
b := tx.Bucket(r.entryBucket)
var bs []byte
Expand All @@ -68,6 +67,7 @@ func (r *Repo) TouchEntry(e *entry.Entry) error {
if len(bs) == 0 {
// insert
e.ID, _ = b.NextSequence()
lg.Debug("bbolt repo: touch new entry: %d", e.ID)
now := time.Now()
e.VisitCount = 1
e.CreateAt = now
Expand All @@ -78,6 +78,7 @@ func (r *Repo) TouchEntry(e *entry.Entry) error {
return err
}
e.ID = rawEntry.ID
lg.Debug("bbolt repo: update entry: %d", e.ID)
e.VisitCount = rawEntry.VisitCount + 1
e.CreateAt = rawEntry.CreateAt
e.UpdateAt = time.Now()
Expand Down
106 changes: 70 additions & 36 deletions ssx/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ package ssx

import (
"context"
"fmt"
"io"
"net"
"os"
"strings"
"sync"
"time"

Expand All @@ -18,6 +16,10 @@ import (
"github.com/vimiix/ssx/ssx/entry"
)

const (
NETWORK = "tcp"
)

type Client struct {
repo Repo
entry *entry.Entry
Expand Down Expand Up @@ -166,9 +168,9 @@ func (c *Client) keepalive(ctx context.Context) {
}

// code source: https://github.com/golang/go/issues/20288#issuecomment-832033017
func dialContext(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
func dialContext(ctx context.Context, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
d := net.Dialer{Timeout: config.Timeout}
conn, err := d.DialContext(ctx, network, addr)
conn, err := d.DialContext(ctx, NETWORK, addr)
if err != nil {
return nil, err
}
Expand All @@ -179,50 +181,82 @@ func dialContext(ctx context.Context, network, addr string, config *ssh.ClientCo
return ssh.NewClient(c, chans, reqs), nil
}

func dialThroughProxy(ctx context.Context, proxy *entry.Proxy, parentProxyCli *ssh.Client, targetEntry *entry.Entry) (*ssh.Client, error) {
var err error
if parentProxyCli == nil {
proxyConfig, err := proxy.GenSSHConfig(ctx)
if err != nil {
return nil, err
}
lg.Debug("dialing proxy: %s", proxy.String())
parentProxyCli, err = dialContext(ctx, proxy.Address(), proxyConfig)
if err != nil {
lg.Debug("dial proxy %s failed: %v", proxy.String(), err)
return nil, err
}
lg.Debug("proxy client establised")
}

var (
tmpTargetAddr string
tmpTargetConfig *ssh.ClientConfig
tmpHostString string
)
if proxy.Proxy != nil {
tmpHostString = proxy.Proxy.String()
tmpTargetAddr = proxy.Proxy.Address()
tmpTargetConfig, err = proxy.Proxy.GenSSHConfig(ctx)
if err != nil {
return nil, err
}
} else {
tmpHostString = targetEntry.String()
tmpTargetAddr = targetEntry.Address()
tmpTargetConfig, err = targetEntry.GenSSHConfig(ctx)
if err != nil {
return nil, err
}
}
lg.Debug("dialing to %s", tmpHostString)
conn, err := parentProxyCli.DialContext(ctx, NETWORK, tmpTargetAddr)
if err != nil {
return nil, err
}
nc, chans, reqs, err := ssh.NewClientConn(conn, tmpTargetAddr, tmpTargetConfig)
if err != nil {
return nil, err
}
targetCli := ssh.NewClient(nc, chans, reqs)
if proxy.Proxy == nil {
return targetCli, nil
}
return dialThroughProxy(ctx, proxy.Proxy, parentProxyCli, targetEntry)
}

// Login connect remote server and touch enrty in storage
func (c *Client) Login(ctx context.Context) error {
if err := c.connect(ctx); err != nil {
lg.Debug("connecting to %s", c.entry.String())
cli, err := c.dial(ctx)
if err != nil {
return err
}
c.cli = cli
if err := c.touchEntry(c.entry); err != nil {
lg.Error("failed to touch entry: %s", err)
}
return nil
}

func (c *Client) connect(ctx context.Context) error {
network := "tcp"
addr := net.JoinHostPort(c.entry.Host, c.entry.Port)
clientConfig, err := c.entry.GenSSHConfig(ctx)
if err != nil {
return err
}
lg.Debug("connecting to %s", c.entry.String())
cli, err := dialContext(ctx, network, addr, clientConfig)
if err == nil {
c.cli = cli
return nil
func (c *Client) dial(ctx context.Context) (*ssh.Client, error) {
if c.entry.Proxy != nil {
return dialThroughProxy(ctx, c.entry.Proxy, nil, c.entry)
}

if strings.Contains(err.Error(), "no supported methods remain") {
lg.Debug("failed connect by default auth methods, try password again")
fmt.Printf("%s@%s's password:", c.entry.User, c.entry.Host)
bs, readErr := terminal.ReadPassword(ctx)
fmt.Println()
if readErr == nil {
p := string(bs)
if p != "" {
clientConfig.Auth = []ssh.AuthMethod{ssh.Password(p)}
}
cli, err = ssh.Dial(network, addr, clientConfig)
if err == nil {
c.entry.Password = p
c.cli = cli
return nil
}
}
// connect directly
sshConfig, err := c.entry.GenSSHConfig(ctx)
if err != nil {
return nil, err
}
return err
return dialContext(ctx, c.entry.Address(), sshConfig)
}

func (c *Client) close() {
Expand Down
87 changes: 58 additions & 29 deletions ssx/entry/entry.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package entry

import (
"bufio"
"context"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -30,6 +29,7 @@ const (
const (
defaultIdentityFile = "~/.ssh/id_rsa"
defaultUser = "root"
defaultPort = "22"
)

// Entry represent a target server
Expand All @@ -46,13 +46,17 @@ type Entry struct {
Source string `json:"source"` // Data source, used to distinguish that it is from ssx stored or local ssh configuration
CreateAt time.Time `json:"create_at"`
UpdateAt time.Time `json:"update_at"`
// TODO support jump server
Proxy *Proxy `json:"proxy"`
}

func (e *Entry) String() string {
return fmt.Sprintf("%s@%s:%s", e.User, e.Host, e.Port)
}

func (e *Entry) Address() string {
return net.JoinHostPort(e.Host, e.Port)
}

func (e *Entry) JSON() ([]byte, error) {
entryCopy, err := e.Copy()
if err != nil {
Expand All @@ -74,6 +78,9 @@ func (e *Entry) Copy() (*Entry, error) {
func (e *Entry) Mask() {
e.Password = utils.MaskString(e.Password)
e.Passphrase = utils.MaskString(e.Passphrase)
if e.Proxy != nil {
e.Proxy.Mask()
}
}

func getConnectTimeout() time.Duration {
Expand All @@ -91,7 +98,7 @@ func getConnectTimeout() time.Duration {
}

func (e *Entry) GenSSHConfig(ctx context.Context) (*ssh.ClientConfig, error) {
cb, err := e.sshHostKeyCallback()
cb, err := sshHostKeyCallback()
if err != nil {
return nil, err
}
Expand All @@ -109,7 +116,7 @@ func (e *Entry) GenSSHConfig(ctx context.Context) (*ssh.ClientConfig, error) {
return cfg, nil
}

func (e *Entry) sshHostKeyCallback() (ssh.HostKeyCallback, error) {
func sshHostKeyCallback() (ssh.HostKeyCallback, error) {
khPath := utils.ExpandHomeDir("~/.ssh/known_hosts")
if !utils.FileExists(khPath) {
f, err := os.OpenFile(khPath, os.O_RDWR|os.O_CREATE, 0600)
Expand Down Expand Up @@ -153,11 +160,14 @@ func (e *Entry) Tidy() error {
e.User = defaultUser
}
if len(e.Port) <= 0 {
e.Port = "22"
e.Port = defaultPort
}
if e.KeyPath == "" {
e.KeyPath = utils.ExpandHomeDir(defaultIdentityFile)
}
if e.Proxy != nil {
e.Proxy.tidy()
}
return nil
}

Expand All @@ -177,37 +187,56 @@ func (e *Entry) AuthMethods(ctx context.Context) ([]ssh.AuthMethod, error) {
if len(keyfileAuths) > 0 {
authMethods = append(authMethods, keyfileAuths...)
}

authMethods = append(authMethods, e.interactAuth(ctx))
authMethods = append(authMethods, passwordCallback(ctx, e.User, e.Host, func(password string) { e.Password = password }))
return authMethods, nil
}

func (e *Entry) interactAuth(ctx context.Context) ssh.AuthMethod {
return ssh.KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
answers = make([]string, 0, len(questions))
for i, q := range questions {
fmt.Print(q)
if echos[i] {
scan := bufio.NewScanner(os.Stdin)
if scan.Scan() {
answers = append(answers, scan.Text())
}
if err := scan.Err(); err != nil {
return nil, err
}
} else {
b, err := terminal.ReadPassword(ctx)
if err != nil {
return nil, err
}
fmt.Println()
answers = append(answers, string(b))
func passwordCallback(ctx context.Context, user, host string, storePassFunc func(password string)) ssh.AuthMethod {
prompt := func() (string, error) {
lg.Debug("login through password callback")
fmt.Printf("%s@%s's password:", user, host)
bs, readErr := terminal.ReadPassword(ctx)
fmt.Println()
if readErr == nil {
p := string(bs)
if storePassFunc != nil {
storePassFunc(p)
}
return p, nil
}
return answers, nil
})
return "", readErr
}
return ssh.PasswordCallback(prompt)
}

// At present, I do not know how to correctly capture password information,
// so I need to write promt by myself through passwordCallback to achieve it
// func interactAuth(ctx context.Context, who string) ssh.AuthMethod {
// return ssh.KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
// answers = make([]string, 0, len(questions))
// for i, q := range questions {
// fmt.Printf("[%s] %s", who, q)
// if echos[i] {
// scan := bufio.NewScanner(os.Stdin)
// if scan.Scan() {
// answers = append(answers, scan.Text())
// }
// if err := scan.Err(); err != nil {
// return nil, err
// }
// } else {
// b, err := terminal.ReadPassword(ctx)
// if err != nil {
// return nil, err
// }
// fmt.Println()
// answers = append(answers, string(b))
// }
// }
// return answers, nil
// })
// }

func (e *Entry) privateKeyAuthMethods(ctx context.Context) ([]ssh.AuthMethod, error) {
keyfiles := e.collectKeyfiles()
if len(keyfiles) == 0 {
Expand Down
Loading