diff --git a/main.go b/main.go index 0b46606..9854de8 100644 --- a/main.go +++ b/main.go @@ -16,9 +16,11 @@ package main import ( + "context" "crypto/tls" "crypto/x509" "fmt" + "math/rand" "net" "net/http" "net/http/httputil" @@ -37,6 +39,7 @@ import ( type Backend struct { endpoint string proxy *httputil.ReverseProxy + httpClient *http.Client up bool healthCheckPath string healthCheckDuration int @@ -57,18 +60,30 @@ func (b *Backend) ErrorHandler(w http.ResponseWriter, r *http.Request, err error func (b *Backend) healthCheck() { healthCheckURL := b.endpoint + b.healthCheckPath for { - resp, err := http.Get(healthCheckURL) + req, err := http.NewRequest(http.MethodGet, healthCheckURL, nil) + if err != nil { + if b.logging { + fmt.Printf("%s %s fails\n", b.endpoint, err) + } + b.up = false + time.Sleep(time.Duration(b.healthCheckDuration) * time.Second) + continue + } + + resp, err := b.httpClient.Do(req) switch { case err == nil && b.healthCheckPath == "": + resp.Body.Close() fallthrough case err == nil && resp.StatusCode == http.StatusOK: + resp.Body.Close() if b.logging { fmt.Printf("%s is up\n", b.endpoint) } b.up = true default: if b.logging { - fmt.Printf("%s is down : %s\n", b.endpoint, err.Error()) + fmt.Printf("%s is down : %s\n", b.endpoint, err) } b.up = false } @@ -76,14 +91,14 @@ func (b *Backend) healthCheck() { } } -type LoadBalancer struct { +type loadBalancer struct { backends []*Backend next int // next backend the request should go to. sync.RWMutex } // Returns the next backend the request should go to. -func (lb *LoadBalancer) nextProxy() *httputil.ReverseProxy { +func (lb *loadBalancer) nextProxy() *httputil.ReverseProxy { lb.Lock() defer lb.Unlock() @@ -109,7 +124,7 @@ func (lb *LoadBalancer) nextProxy() *httputil.ReverseProxy { } // ServeHTTP - LoadBalancer implements http.Handler -func (lb *LoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (lb *loadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) { proxy := lb.nextProxy() if proxy == nil { w.WriteHeader(http.StatusBadGateway) @@ -127,13 +142,39 @@ func mustGetSystemCertPool() *x509.CertPool { return pool } +var rng = rand.New(rand.NewSource(time.Now().UTC().UnixNano())) + +type dialContext func(ctx context.Context, network, address string) (net.Conn, error) + +func newCustomDialContext(dialTimeout, dialKeepAlive time.Duration) dialContext { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &net.Dialer{ + Timeout: dialTimeout, + KeepAlive: dialKeepAlive, + } + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + addrs, err := net.LookupHost(host) + if err != nil { + addrs = []string{host} + } + + for i := range addrs { + addrs[i] = net.JoinHostPort(addrs[i], port) + } + + return dialer.DialContext(ctx, network, addrs[rng.Intn(len(addrs))]) + } +} + func clientTransport(ctx *cli.Context, enableTLS bool) http.RoundTripper { tr := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 5 * time.Second, - KeepAlive: 5 * time.Second, - }).DialContext, + Proxy: http.ProxyFromEnvironment, + DialContext: newCustomDialContext(5*time.Second, 5*time.Second), MaxIdleConnsPerHost: 256, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, @@ -201,21 +242,6 @@ func sidekickMain(ctx *cli.Context) { } } else { endpoints = ctx.Args() - if len(endpoints) == 1 { - target, err := url.Parse(endpoints[0]) - if err != nil { - console.Fatalln(fmt.Errorf("Unable to parse input arg %s: %s", endpoints[0], err)) - } - // Single endpoint do lookup address to get all IPs - addrs, err := net.LookupHost(target.Hostname()) - if err != nil { - console.Fatalln(fmt.Errorf("Unable to lookup host %s: %s", endpoints[0], err)) - } - endpoints = make([]string, len(addrs)) - for i, addr := range addrs { - endpoints[i] = target.Scheme + "://" + net.JoinHostPort(addr, target.Port()) - } - } } var backends []*Backend @@ -238,13 +264,15 @@ func sidekickMain(ctx *cli.Context) { } proxy := httputil.NewSingleHostReverseProxy(target) proxy.Transport = clientTransport(ctx, target.Scheme == "https") - backend := &Backend{endpoint, proxy, false, healthCheckPath, healthCheckDuration, logging} + backend := &Backend{endpoint, proxy, &http.Client{ + Transport: proxy.Transport, + }, false, healthCheckPath, healthCheckDuration, logging} go backend.healthCheck() proxy.ErrorHandler = backend.ErrorHandler backends = append(backends, backend) } console.Infoln("Listening on", addr) - if err := http.ListenAndServe(addr, &LoadBalancer{ + if err := http.ListenAndServe(addr, &loadBalancer{ backends: backends, }); err != nil { console.Fatalln(err)