diff --git a/ddlambda.go b/ddlambda.go index 0ac49c62..87388e5c 100644 --- a/ddlambda.go +++ b/ddlambda.go @@ -304,6 +304,10 @@ func (cfg *Config) toMetricsConfig(isExtensionRunning bool) metrics.Config { mc.EnhancedMetrics = strings.EqualFold(enhancedMetrics, "true") } + if localTest := os.Getenv("DD_LOCAL_TEST"); localTest == "1" || strings.ToLower(localTest) == "true" { + mc.LocalTest = true + } + return mc } diff --git a/ddlambda_test.go b/ddlambda_test.go index 0b3d4480..052528e8 100644 --- a/ddlambda_test.go +++ b/ddlambda_test.go @@ -9,8 +9,10 @@ package ddlambda import ( "context" + "fmt" "net/http" "net/http/httptest" + "os" "testing" "github.com/stretchr/testify/assert" @@ -51,3 +53,54 @@ func TestMetricsSubmitWithWrapper(t *testing.T) { assert.NoError(t, err) assert.True(t, called) } + +func TestToMetricConfigLocalTest(t *testing.T) { + testcases := []struct { + envs map[string]string + cval bool + }{ + { + envs: map[string]string{"DD_LOCAL_TEST": "True"}, + cval: true, + }, + { + envs: map[string]string{"DD_LOCAL_TEST": "true"}, + cval: true, + }, + { + envs: map[string]string{"DD_LOCAL_TEST": "1"}, + cval: true, + }, + { + envs: map[string]string{"DD_LOCAL_TEST": "False"}, + cval: false, + }, + { + envs: map[string]string{"DD_LOCAL_TEST": "false"}, + cval: false, + }, + { + envs: map[string]string{"DD_LOCAL_TEST": "0"}, + cval: false, + }, + { + envs: map[string]string{"DD_LOCAL_TEST": ""}, + cval: false, + }, + { + envs: map[string]string{}, + cval: false, + }, + } + + cfg := Config{} + for _, tc := range testcases { + t.Run(fmt.Sprintf("%#v", tc.envs), func(t *testing.T) { + for k, v := range tc.envs { + os.Setenv(k, v) + } + mc := cfg.toMetricsConfig(true) + assert.Equal(t, tc.cval, mc.LocalTest) + }) + } +} diff --git a/internal/metrics/listener.go b/internal/metrics/listener.go index 73f6e878..dcb309f5 100644 --- a/internal/metrics/listener.go +++ b/internal/metrics/listener.go @@ -49,6 +49,7 @@ type ( CircuitBreakerInterval time.Duration CircuitBreakerTimeout time.Duration CircuitBreakerTotalFailures uint32 + LocalTest bool } logMetric struct { @@ -143,8 +144,10 @@ func (l *Listener) HandlerFinished(ctx context.Context, err error) { } } // send a message to the Agent to flush the metrics - if err := l.extensionManager.Flush(); err != nil { - logger.Error(fmt.Errorf("error while flushing the metrics: %s", err)) + if l.config.LocalTest { + if err := l.extensionManager.Flush(); err != nil { + logger.Error(fmt.Errorf("error while flushing the metrics: %s", err)) + } } } else { // use the api diff --git a/internal/metrics/listener_test.go b/internal/metrics/listener_test.go index 07b90998..40d908d7 100644 --- a/internal/metrics/listener_test.go +++ b/internal/metrics/listener_test.go @@ -13,6 +13,8 @@ import ( "context" "encoding/json" "errors" + "fmt" + "net" "net/http" "net/http/httptest" "os" @@ -221,3 +223,27 @@ func TestSubmitEnhancedMetricsOnlyErrors(t *testing.T) { expected := "{\"m\":\"aws.lambda.enhanced.errors\",\"v\":1," assert.True(t, strings.Contains(output, expected)) } + +func TestListenerHandlerFinishedFlushes(t *testing.T) { + var called bool + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + })) + ts.Listener.Close() + ts.Listener, _ = net.Listen("tcp", "127.0.0.1:8124") + + ts.Start() + defer ts.Close() + + listener := MakeListener(Config{}, extension.BuildExtensionManager(false)) + listener.isAgentRunning = true + for _, localTest := range []bool{true, false} { + t.Run(fmt.Sprintf("%#v", localTest), func(t *testing.T) { + called = false + listener.config.LocalTest = localTest + listener.HandlerFinished(context.TODO(), nil) + assert.Equal(t, called, localTest) + }) + } +}