diff --git a/pkg/smokescreen/config.go b/pkg/smokescreen/config.go index cd3642d5..41d0c7ea 100644 --- a/pkg/smokescreen/config.go +++ b/pkg/smokescreen/config.go @@ -83,7 +83,7 @@ type Config struct { RejectResponseHandler func(*http.Response) // Custom handler to allow clients to modify successful CONNECT responses - AcceptResponseHandler func(*smokescreenContext, *http.Response) error + AcceptResponseHandler func(*SmokescreenContext, *http.Response) error // UnsafeAllowPrivateRanges inverts the default behavior, telling smokescreen to allow private IP // ranges by default (exempting loopback and unicast ranges) diff --git a/pkg/smokescreen/smokescreen.go b/pkg/smokescreen/smokescreen.go index 520d1728..b20239f2 100644 --- a/pkg/smokescreen/smokescreen.go +++ b/pkg/smokescreen/smokescreen.go @@ -63,17 +63,17 @@ const ( type ipType int -type aclDecision struct { +type ACLDecision struct { reason, role, project, outboundHost string - resolvedAddr *net.TCPAddr + ResolvedAddr *net.TCPAddr allow bool enforceWouldDeny bool } -type smokescreenContext struct { +type SmokescreenContext struct { cfg *Config start time.Time - decision *aclDecision + Decision *ACLDecision proxyType string logger *logrus.Entry requestedHost string @@ -246,17 +246,17 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { return nil, fmt.Errorf("dialContext missing required *goproxy.ProxyCtx") } - sctx, ok := pctx.UserData.(*smokescreenContext) + sctx, ok := pctx.UserData.(*SmokescreenContext) if !ok { - return nil, fmt.Errorf("dialContext missing required *smokescreenContext") + return nil, fmt.Errorf("dialContext missing required *SmokescreenContext") } - d := sctx.decision + d := sctx.Decision // If an address hasn't been resolved, does not match the original outboundHost, // or is not tcp we must re-resolve it before establishing the connection. - if d.resolvedAddr == nil || d.outboundHost != addr || network != "tcp" { + if d.ResolvedAddr == nil || d.outboundHost != addr || network != "tcp" { var err error - d.resolvedAddr, d.reason, err = safeResolve(sctx.cfg, network, addr) + d.ResolvedAddr, d.reason, err = safeResolve(sctx.cfg, network, addr) if err != nil { if _, ok := err.(denyError); ok { sctx.cfg.Log.WithFields( @@ -279,9 +279,9 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { start := time.Now() if sctx.cfg.ProxyDialTimeout == nil { - conn, err = net.DialTimeout(network, d.resolvedAddr.String(), sctx.cfg.ConnectTimeout) + conn, err = net.DialTimeout(network, d.ResolvedAddr.String(), sctx.cfg.ConnectTimeout) } else { - conn, err = sctx.cfg.ProxyDialTimeout(ctx, network, d.resolvedAddr.String(), sctx.cfg.ConnectTimeout) + conn, err = sctx.cfg.ProxyDialTimeout(ctx, network, d.ResolvedAddr.String(), sctx.cfg.ConnectTimeout) } connTime := time.Since(start) @@ -332,7 +332,7 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { // HTTPErrorHandler allows returning a custom error response when smokescreen // fails to connect to the proxy target. func HTTPErrorHandler(w io.WriteCloser, pctx *goproxy.ProxyCtx, err error) { - sctx := pctx.UserData.(*smokescreenContext) + sctx := pctx.UserData.(*SmokescreenContext) resp := rejectResponse(pctx, err) if err := resp.Write(w); err != nil { @@ -345,7 +345,7 @@ func HTTPErrorHandler(w io.WriteCloser, pctx *goproxy.ProxyCtx, err error) { } func rejectResponse(pctx *goproxy.ProxyCtx, err error) *http.Response { - sctx := pctx.UserData.(*smokescreenContext) + sctx := pctx.UserData.(*SmokescreenContext) var msg, status string var code int @@ -411,7 +411,7 @@ func configureTransport(tr *http.Transport, cfg *Config) { } } -func newContext(cfg *Config, proxyType string, req *http.Request) *smokescreenContext { +func newContext(cfg *Config, proxyType string, req *http.Request) *SmokescreenContext { start := time.Now() logger := cfg.Log.WithFields(logrus.Fields{ @@ -423,7 +423,7 @@ func newContext(cfg *Config, proxyType string, req *http.Request) *smokescreenCo LogFieldTraceID: req.Header.Get(traceHeader), }) - return &smokescreenContext{ + return &SmokescreenContext{ cfg: cfg, logger: logger, proxyType: proxyType, @@ -462,7 +462,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { // proxy requests we are able to specify the request during the call to OnResponse(). sctx := newContext(config, httpProxy, req) - // Attach smokescreenContext to goproxy.ProxyCtx + // Attach SmokescreenContext to goproxy.ProxyCtx pctx.UserData = sctx // Delete Smokescreen specific headers before goproxy forwards the request @@ -482,7 +482,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { } sctx.logger.WithField("url", req.RequestURI).Debug("received HTTP proxy request") - sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, req, destination) + sctx.Decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, req, destination) // Returning any kind of response in this handler is goproxy's way of short circuiting // the request. The original request will never be sent, and goproxy will invoke our @@ -490,8 +490,8 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { if pctx.Error != nil { return req, rejectResponse(pctx, pctx.Error) } - if !sctx.decision.allow { - return req, rejectResponse(pctx, denyError{errors.New(sctx.decision.reason)}) + if !sctx.Decision.allow { + return req, rejectResponse(pctx, denyError{errors.New(sctx.Decision.reason)}) } // Call the custom request handler if it exists @@ -539,9 +539,9 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { // function will be called again with the previously returned response, which will // simply trigger the logHTTP function and return. proxy.OnResponse().DoFunc(func(resp *http.Response, pctx *goproxy.ProxyCtx) *http.Response { - sctx := pctx.UserData.(*smokescreenContext) + sctx := pctx.UserData.(*SmokescreenContext) - if resp != nil && pctx.Error == nil && sctx.decision.allow { + if resp != nil && pctx.Error == nil && sctx.Decision.allow { if resp.Header.Get(errorHeader) != "" { resp.Header.Del(errorHeader) } @@ -564,9 +564,9 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { // The goproxy OnResponse() function above is only called for non-https responses. if config.AcceptResponseHandler != nil { proxy.ConnectRespHandler = func(pctx *goproxy.ProxyCtx, resp *http.Response) error { - sctx, ok := pctx.UserData.(*smokescreenContext) + sctx, ok := pctx.UserData.(*SmokescreenContext) if !ok { - return fmt.Errorf("goproxy ProxyContext missing required UserData *smokescreenContext") + return fmt.Errorf("goproxy ProxyContext missing required UserData *SmokescreenContext") } return config.AcceptResponseHandler(sctx, resp) } @@ -576,7 +576,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { } func logProxy(config *Config, pctx *goproxy.ProxyCtx) { - sctx := pctx.UserData.(*smokescreenContext) + sctx := pctx.UserData.(*SmokescreenContext) fields := logrus.Fields{} @@ -589,8 +589,8 @@ func logProxy(config *Config, pctx *goproxy.ProxyCtx) { } } - decision := sctx.decision - if sctx.decision != nil { + decision := sctx.Decision + if sctx.Decision != nil { fields[LogFieldRole] = decision.role fields[LogFieldProject] = decision.project } @@ -609,7 +609,7 @@ func logProxy(config *Config, pctx *goproxy.ProxyCtx) { fields[LogFieldContentLength] = pctx.Resp.ContentLength } - if sctx.decision != nil { + if sctx.Decision != nil { fields[LogFieldDecisionReason] = decision.reason fields[LogFieldEnforceWouldDeny] = decision.enforceWouldDeny fields[LogFieldAllow] = decision.allow @@ -633,7 +633,7 @@ func logProxy(config *Config, pctx *goproxy.ProxyCtx) { } func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) { - sctx := pctx.UserData.(*smokescreenContext) + sctx := pctx.UserData.(*SmokescreenContext) // Check if requesting role is allowed to talk to remote destination, err := hostport.New(pctx.Req.Host, false) @@ -644,13 +644,13 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) { // checkIfRequestShouldBeProxied can return an error if either the resolved address is disallowed, // or if there is a DNS resolution failure. - sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, pctx.Req, destination) + sctx.Decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, pctx.Req, destination) if pctx.Error != nil { // DNS resolution failure return "", pctx.Error } - if !sctx.decision.allow { - return "", denyError{errors.New(sctx.decision.reason)} + if !sctx.Decision.allow { + return "", denyError{errors.New(sctx.Decision.reason)} } // Call the custom request handler if it exists @@ -881,7 +881,7 @@ func getRole(config *Config, req *http.Request) (string, error) { } } -func checkIfRequestShouldBeProxied(config *Config, req *http.Request, destination hostport.HostPort) (*aclDecision, time.Duration, error) { +func checkIfRequestShouldBeProxied(config *Config, req *http.Request, destination hostport.HostPort) (*ACLDecision, time.Duration, error) { decision := checkACLsForRequest(config, req, destination) var lookupTime time.Duration @@ -898,15 +898,15 @@ func checkIfRequestShouldBeProxied(config *Config, req *http.Request, destinatio decision.allow = false decision.enforceWouldDeny = true } else { - decision.resolvedAddr = resolved + decision.ResolvedAddr = resolved } } return decision, lookupTime, nil } -func checkACLsForRequest(config *Config, req *http.Request, destination hostport.HostPort) *aclDecision { - decision := &aclDecision{ +func checkACLsForRequest(config *Config, req *http.Request, destination hostport.HostPort) *ACLDecision { + decision := &ACLDecision{ outboundHost: destination.String(), } @@ -932,9 +932,9 @@ func checkACLsForRequest(config *Config, req *http.Request, destination hostport return decision } - aclDecision, err := config.EgressACL.Decide(role, destination.Host) - decision.project = aclDecision.Project - decision.reason = aclDecision.Reason + ACLDecision, err := config.EgressACL.Decide(role, destination.Host) + decision.project = ACLDecision.Project + decision.reason = ACLDecision.Reason if err != nil { config.Log.WithFields(logrus.Fields{ "error": err, @@ -947,11 +947,11 @@ func checkACLsForRequest(config *Config, req *http.Request, destination hostport tags := map[string]string{ "role": decision.role, - "def_rule": fmt.Sprintf("%t", aclDecision.Default), - "project": aclDecision.Project, + "def_rule": fmt.Sprintf("%t", ACLDecision.Default), + "project": ACLDecision.Project, } - switch aclDecision.Result { + switch ACLDecision.Result { case acl.Deny: decision.enforceWouldDeny = true config.MetricsClient.IncrWithTags("acl.deny", tags, 1) @@ -970,7 +970,7 @@ func checkACLsForRequest(config *Config, req *http.Request, destination hostport config.Log.WithFields(logrus.Fields{ "role": role, "destination": destination.Host, - "action": aclDecision.Result.String(), + "action": ACLDecision.Result.String(), }).Warn("Unknown ACL action") decision.reason = "Internal error" config.MetricsClient.IncrWithTags("acl.unknown_error", tags, 1) diff --git a/pkg/smokescreen/smokescreen_test.go b/pkg/smokescreen/smokescreen_test.go index 9ba44114..1ccc9a61 100644 --- a/pkg/smokescreen/smokescreen_test.go +++ b/pkg/smokescreen/smokescreen_test.go @@ -1073,7 +1073,7 @@ func TestAcceptResponseHandler(t *testing.T) { cfg, err := testConfig("test-local-srv") // set a custom AcceptResponseHandler that will set a header on every reject response - cfg.AcceptResponseHandler = func(_ *smokescreenContext, resp *http.Response) error { + cfg.AcceptResponseHandler = func(_ *SmokescreenContext, resp *http.Response) error { resp.Header.Set(testHeader, "This header is added by the AcceptResponseHandler") return nil }