Skip to content

Commit

Permalink
Merge pull request ghostunnel#170 from square/cs/split-proxy
Browse files Browse the repository at this point in the history
Split out proxy logic into separate package
  • Loading branch information
csstaub authored Jun 22, 2018
2 parents a9db402 + 2be3126 commit 3d7b634
Show file tree
Hide file tree
Showing 10 changed files with 465 additions and 265 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ghostunnel: $(SOURCE_FILES)

# Test binary with coverage instrumentation
ghostunnel.test: $(SOURCE_FILES)
go test -c -covermode=count -coverpkg .,./auth,./certloader
go test -c -covermode=count -coverpkg .,./auth,./certloader,./proxy

# Clean build output
clean:
Expand All @@ -26,6 +26,7 @@ unit:
go test -v -covermode=count -coverprofile=coverage-unit-test-base.out .
go test -v -covermode=count -coverprofile=coverage-unit-test-auth.out ./auth
go test -v -covermode=count -coverprofile=coverage-unit-test-certloader.out ./certloader
go test -v -covermode=count -coverprofile=coverage-unit-test-proxy.out ./proxy
.PHONY: unit

# Run integration tests
Expand Down
70 changes: 46 additions & 24 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,13 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/pprof"
"os"
"runtime"
"strings"
"sync"
"time"

"github.com/cyberdelia/go-metrics-graphite"
Expand All @@ -38,6 +36,7 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/square/ghostunnel/auth"
"github.com/square/ghostunnel/certloader"
"github.com/square/ghostunnel/proxy"
"github.com/square/go-sq-metrics"
"gopkg.in/alecthomas/kingpin.v2"
)
Expand Down Expand Up @@ -409,13 +408,12 @@ func serverListen(context *Context) error {
return err
}

proxy := &proxy{
quit: 0,
listener: tls.NewListener(listener, config),
handlers: &sync.WaitGroup{},
connectTimeout: *timeoutDuration,
dial: context.dial,
}
p := proxy.New(
tls.NewListener(listener, config),
*timeoutDuration,
context.dial,
logger,
)

if *statusAddress != "" {
err := context.serveStatus()
Expand All @@ -427,11 +425,11 @@ func serverListen(context *Context) error {

logger.Printf("listening for connections on %s", (*serverListenAddress).String())

go proxy.accept()
go p.Accept()

context.status.Listening()
context.signalHandler(proxy, []io.Closer{listener})
proxy.handlers.Wait()
context.signalHandler(p)
p.Wait()

return nil
}
Expand All @@ -456,13 +454,12 @@ func clientListen(context *Context) error {
ul.SetUnlinkOnClose(true)
}

proxy := &proxy{
quit: 0,
listener: listener,
handlers: &sync.WaitGroup{},
connectTimeout: *timeoutDuration,
dial: context.dial,
}
p := proxy.New(
listener,
*timeoutDuration,
context.dial,
logger,
)

if *statusAddress != "" {
err := context.serveStatus()
Expand All @@ -474,11 +471,11 @@ func clientListen(context *Context) error {

logger.Printf("listening for connections on %s", *clientListenAddress)

go proxy.accept()
go p.Accept()

context.status.Listening()
context.signalHandler(proxy, []io.Closer{listener})
proxy.handlers.Wait()
context.signalHandler(p)
p.Wait()

return nil
}
Expand Down Expand Up @@ -511,7 +508,7 @@ func (context *Context) serveStatus() error {
}

var listener net.Listener
if network == unixSocket {
if network == "unix" {
listener, err = net.Listen(network, address)
listener.(*net.UnixListener).SetUnlinkOnClose(true)
} else {
Expand All @@ -523,7 +520,7 @@ func (context *Context) serveStatus() error {
return err
}

if network != unixSocket {
if network != "unix" {
listener = tls.NewListener(listener, config)
}

Expand Down Expand Up @@ -599,3 +596,28 @@ func clientBackendDialer(cert certloader.Certificate, network, address, host str
d := certloader.DialerWithCertificate(cert, config, *timeoutDuration, dialer)
return func() (net.Conn, error) { return d.Dial(network, address) }, nil
}

// Parse a string representing a TCP address or UNIX socket for our backend
// target. The input can be or the form "HOST:PORT" for TCP or "unix:PATH"
// for a UNIX socket.
func parseUnixOrTCPAddress(input string) (network, address, host string, err error) {
if strings.HasPrefix(input, "unix:") {
network = "unix"
address = input[5:]
return
}

host, _, err = net.SplitHostPort(input)
if err != nil {
return
}

// Make sure target address resolves
_, err = net.ResolveTCPAddr("tcp", input)
if err != nil {
return
}

network, address = "tcp", input
return
}
30 changes: 30 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,33 @@ func TestInvalidCABundle(t *testing.T) {
})
assert.NotNil(t, err, "invalid CA bundle should exit with error")
}

func TestParseUnixOrTcpAddress(t *testing.T) {
network, address, host, _ := parseUnixOrTCPAddress("unix:/tmp/foo")
if network != "unix" {
t.Errorf("unexpected network: %s", network)
}
if address != "/tmp/foo" {
t.Errorf("unexpected address: %s", address)
}
if host != "" {
t.Errorf("unexpected host: %s", host)
}

network, address, host, _ = parseUnixOrTCPAddress("localhost:8080")
if network != "tcp" {
t.Errorf("unexpected network: %s", network)
}
if address != "localhost:8080" {
t.Errorf("unexpected address: %s", address)
}
if host != "localhost" {
t.Errorf("unexpected host: %s", host)
}

_, _, _, err := parseUnixOrTCPAddress("localhost")
assert.NotNil(t, err, "was able to parse invalid host/port")

_, _, _, err = parseUnixOrTCPAddress("256.256.256.256:99999")
assert.NotNil(t, err, "was able to parse invalid host/port")
}
179 changes: 0 additions & 179 deletions net.go

This file was deleted.

Loading

0 comments on commit 3d7b634

Please sign in to comment.