Skip to content

Commit

Permalink
Update request synchronization
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniomika committed Oct 26, 2020
1 parent 804e8be commit 4f7ef48
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 69 deletions.
2 changes: 1 addition & 1 deletion sshmuxer/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func handleRequests(reqs <-chan *ssh.Request, sshConn *utils.SSHConnection, stat
if viper.GetBool("debug") {
log.Println("Main Request Info", req.Type, req.WantReply, string(req.Payload))
}
go handleRequest(req, sshConn, state)
handleRequest(req, sshConn, state)
}
}

Expand Down
130 changes: 62 additions & 68 deletions sshmuxer/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ type forwardedTCPPayload struct {
// handleRemoteForward will handle a remote forward request
// and stand up the relevant listeners.
func handleRemoteForward(newRequest *ssh.Request, sshConn *utils.SSHConnection, state *utils.State) {
cleanedUp := sync.Once{}

sshConn.SetupLock.Lock()
defer cleanedUp.Do(sshConn.SetupLock.Unlock)

cleanupOnce := &sync.Once{}
check := &channelForwardMsg{}

err := ssh.Unmarshal(newRequest.Payload, check)
Expand Down Expand Up @@ -97,18 +93,19 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *utils.SSHConnection,
state.Listeners.Store(listenAddr, listenerHolder)
sshConn.Listeners.Store(listenAddr, listenerHolder)

deferHandler := func() {}

cleanupChanListener := func() {
listenerHolder.Close()
state.Listeners.Delete(listenAddr)
sshConn.Listeners.Delete(listenAddr)
os.Remove(listenAddr)
deferHandler()
}

defer cleanupChanListener()

go func() {
<-sshConn.Close
cleanupChanListener()
cleanupOnce.Do(cleanupChanListener)
}()

connType := "tcp"
Expand All @@ -135,10 +132,10 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *utils.SSHConnection,

mainRequestMessages = requestMessages

defer func() {
deferHandler = func() {
err := pH.Balancer.RemoveServer(serverURL)
if err != nil {
log.Println("Unable to add server to balancer")
log.Println("Unable to remove server from balancer:", err)
}

pH.SSHConnections.Delete(listenerHolder.Addr().String())
Expand All @@ -150,7 +147,7 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *utils.SSHConnection,
state.Console.RemoveRoute(host)
}
}
}()
}
case utils.AliasListener:
aH, serverURL, validAlias, requestMessages, err := handleAliasListener(check, stringPort, mainRequestMessages, listenerHolder, state, sshConn)
if err != nil {
Expand All @@ -163,18 +160,18 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *utils.SSHConnection,

mainRequestMessages = requestMessages

defer func() {
deferHandler = func() {
err := aH.Balancer.RemoveServer(serverURL)
if err != nil {
log.Println("Unable to add server to balancer")
log.Println("Unable to remove server from balancer:", err)
}

aH.SSHConnections.Delete(listenerHolder.Addr().String())

if len(aH.Balancer.Servers()) == 0 {
state.AliasListeners.Delete(validAlias)
}
}()
}
case utils.TCPListener:
tH, serverURL, tcpAddr, requestMessages, err := handleTCPListener(check, bindPort, mainRequestMessages, listenerHolder, state, sshConn)
if err != nil {
Expand All @@ -191,10 +188,10 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *utils.SSHConnection,

go tH.Handle(state)

defer func() {
deferHandler = func() {
err := tH.Balancer.RemoveServer(serverURL)
if err != nil {
log.Println("Unable to add server to balancer")
log.Println("Unable to remove server from balancer:", err)
}

tH.SSHConnections.Delete(listenerHolder.Addr().String())
Expand All @@ -204,75 +201,72 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *utils.SSHConnection,
state.Listeners.Delete(tcpAddr)
state.TCPListeners.Delete(tcpAddr)
}
}()
}
}

var replyMessage []byte

if check.Rport == 0 {
replyMessage = ssh.Marshal(portChannelForwardReplyPayload)
} else {
if check.Rport != 0 {
portChannelForwardReplyPayload.Rport = check.Rport
}

err = newRequest.Reply(true, replyMessage)
err = newRequest.Reply(true, ssh.Marshal(portChannelForwardReplyPayload))
if err != nil {
log.Println("Error replying to port forwarding request:", err)
return
}

sshConn.SendMessage(mainRequestMessages, false)
sshConn.SendMessage(mainRequestMessages, true)

cleanedUp.Do(sshConn.SetupLock.Unlock)
go func() {
defer cleanupOnce.Do(cleanupChanListener)
for {
cl, err := listenerHolder.Accept()
if err != nil {
break
}

for {
cl, err := listenerHolder.Accept()
if err != nil {
break
}
resp := &forwardedTCPPayload{
Addr: check.Addr,
Port: portChannelForwardReplyPayload.Rport,
OriginAddr: check.Addr,
OriginPort: portChannelForwardReplyPayload.Rport,
}

resp := &forwardedTCPPayload{
Addr: check.Addr,
Port: portChannelForwardReplyPayload.Rport,
OriginAddr: check.Addr,
OriginPort: portChannelForwardReplyPayload.Rport,
}
newChan, newReqs, err := sshConn.SSHConn.OpenChannel("forwarded-tcpip", ssh.Marshal(resp))
if err != nil {
sshConn.SendMessage(err.Error(), true)
cl.Close()
continue
}

newChan, newReqs, err := sshConn.SSHConn.OpenChannel("forwarded-tcpip", ssh.Marshal(resp))
if err != nil {
sshConn.SendMessage(err.Error(), true)
cl.Close()
continue
}
if sshConn.ProxyProto != 0 && listenerType == utils.TCPListener {
var sourceInfo *net.TCPAddr
var destInfo *net.TCPAddr
if _, ok := cl.RemoteAddr().(*net.TCPAddr); !ok {
sourceInfo = sshConn.SSHConn.RemoteAddr().(*net.TCPAddr)
destInfo = sshConn.SSHConn.LocalAddr().(*net.TCPAddr)
} else {
sourceInfo = cl.RemoteAddr().(*net.TCPAddr)
destInfo = cl.LocalAddr().(*net.TCPAddr)
}

if sshConn.ProxyProto != 0 && listenerType == utils.TCPListener {
var sourceInfo *net.TCPAddr
var destInfo *net.TCPAddr
if _, ok := cl.RemoteAddr().(*net.TCPAddr); !ok {
sourceInfo = sshConn.SSHConn.RemoteAddr().(*net.TCPAddr)
destInfo = sshConn.SSHConn.LocalAddr().(*net.TCPAddr)
} else {
sourceInfo = cl.RemoteAddr().(*net.TCPAddr)
destInfo = cl.LocalAddr().(*net.TCPAddr)
}
proxyProtoHeader := proxyproto.Header{
Version: sshConn.ProxyProto,
Command: proxyproto.ProtocolVersionAndCommand(proxyproto.PROXY),
TransportProtocol: proxyproto.AddressFamilyAndProtocol(proxyproto.TCPv4),
SourceAddress: sourceInfo.IP,
DestinationAddress: destInfo.IP,
SourcePort: uint16(sourceInfo.Port),
DestinationPort: uint16(destInfo.Port),
}

proxyProtoHeader := proxyproto.Header{
Version: sshConn.ProxyProto,
Command: proxyproto.ProtocolVersionAndCommand(proxyproto.PROXY),
TransportProtocol: proxyproto.AddressFamilyAndProtocol(proxyproto.TCPv4),
SourceAddress: sourceInfo.IP,
DestinationAddress: destInfo.IP,
SourcePort: uint16(sourceInfo.Port),
DestinationPort: uint16(destInfo.Port),
_, err := proxyProtoHeader.WriteTo(newChan)
if err != nil && viper.GetBool("debug") {
log.Println("Error writing to channel:", err)
}
}

_, err := proxyProtoHeader.WriteTo(newChan)
if err != nil && viper.GetBool("debug") {
log.Println("Error writing to channel:", err)
}
go utils.CopyBoth(cl, newChan)
go ssh.DiscardRequests(newReqs)
}

go utils.CopyBoth(cl, newChan)
go ssh.DiscardRequests(newReqs)
}
}()
}

0 comments on commit 4f7ef48

Please sign in to comment.