From 83270283170cb749130a39aa563d487c71df0ce7 Mon Sep 17 00:00:00 2001 From: Sam Perrin Date: Thu, 19 Aug 2021 15:23:37 +0100 Subject: [PATCH] Support custom root trace extraction from Lambda body content (#80) * feat: support custom root trace extraction from Lambda body * feat: run go mod tidy * feat: code review improvements * Rename Context -> TraceContext to remove ambiguity with Go context. * Rename the DefaultTraceExtractor func to getHeadersFromEventHeaders. * feat: fix conflict and add documentation --- README.md | 19 +++++++++ ddlambda.go | 9 ++++- ddlambda_example_test.go | 54 ++++++++++++++++++++++++++ internal/trace/context.go | 71 ++++++++++++++++++++-------------- internal/trace/context_test.go | 18 ++++----- internal/trace/listener.go | 21 +++++----- 6 files changed, 143 insertions(+), 49 deletions(-) create mode 100644 ddlambda_example_test.go diff --git a/README.md b/README.md index 49e981fe..e4a0d26d 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,25 @@ func handleRequest(ctx context.Context, ev events.APIGatewayProxyRequest) (event If you are also using AWS X-Ray to trace your Lambda functions, you can set the `DD_MERGE_XRAY_TRACES` environment variable to `true`, and Datadog will merge your Datadog and X-Ray traces into a single, unified trace. +### Trace Context Extraction + +To link your distributed traces, datadog-lambda-go looks for the `x-datadog-trace-id`, `x-datadog-parent-id` and `x-datadog-sampling-priority` trace `headers` in the Lambda event payload. +If the headers are found it will set the parent trace to the trace context extracted from the headers. + +It is possible to configure your own trace context extractor function if the default extractor does not support your event. + +```go +myExtractorFunc := func(ctx context.Context, ev json.RawMessage) map[string]string { + // extract x-datadog-trace-id, x-datadog-parent-id and x-datadog-sampling-priority. +} + +cfg := &ddlambda.Config{ + TraceContextExtractor: myExtractorFunc, +} +ddlambda.WrapFunction(handler, cfg) +``` + +A more complete example can be found in the `ddlambda_example_test.go` file. ## Environment Variables diff --git a/ddlambda.go b/ddlambda.go index 7b19c308..1111b7f5 100644 --- a/ddlambda.go +++ b/ddlambda.go @@ -68,6 +68,9 @@ type ( // the counter will get totally reset after CircuitBreakerInterval // default: 4 CircuitBreakerTotalFailures uint32 + // TraceContextExtractor is the function that extracts a root/parent trace context from the Lambda event body. + // See trace.DefaultTraceExtractor for an example. + TraceContextExtractor trace.ContextExtractor } ) @@ -179,7 +182,6 @@ func InvokeDryRun(callback func(ctx context.Context), cfg *Config) (interface{}, } func (cfg *Config) toTraceConfig() trace.Config { - traceConfig := trace.Config{ DDTraceEnabled: false, MergeXrayTraces: false, @@ -188,6 +190,11 @@ func (cfg *Config) toTraceConfig() trace.Config { if cfg != nil { traceConfig.DDTraceEnabled = cfg.DDTraceEnabled traceConfig.MergeXrayTraces = cfg.MergeXrayTraces + traceConfig.TraceContextExtractor = cfg.TraceContextExtractor + } + + if traceConfig.TraceContextExtractor == nil { + traceConfig.TraceContextExtractor = trace.DefaultTraceExtractor } if !traceConfig.DDTraceEnabled { diff --git a/ddlambda_example_test.go b/ddlambda_example_test.go new file mode 100644 index 00000000..d82b1683 --- /dev/null +++ b/ddlambda_example_test.go @@ -0,0 +1,54 @@ +package ddlambda_test + +import ( + "context" + "encoding/json" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "strings" + "testing" + + "github.com/aws/aws-lambda-go/events" + + ddlambda "github.com/DataDog/datadog-lambda-go" +) + +var exampleSQSExtractor = func(ctx context.Context, ev json.RawMessage) map[string]string { + eh := events.SQSEvent{} + + headers := map[string]string{} + + if err := json.Unmarshal(ev, &eh); err != nil { + return headers + } + + // Using SQS as a trigger with a batchSize=1 so its important we check for this as a single SQS message + // will drive the execution of the handler. + if len(eh.Records) != 1 { + return headers + } + + record := eh.Records[0] + + lowercaseHeaders := map[string]string{} + for k, v := range record.MessageAttributes { + if v.StringValue != nil { + lowercaseHeaders[strings.ToLower(k)] = *v.StringValue + } + } + + return lowercaseHeaders +} + +func TestCustomExtractorExample(t *testing.T) { + handler := func(ctx context.Context, event events.SQSEvent) error { + // Use the parent span retrieved from the SQS Message Attributes. + span, _ := tracer.SpanFromContext(ctx) + span.SetTag("key", "value") + return nil + } + + cfg := &ddlambda.Config{ + TraceContextExtractor: exampleSQSExtractor, + } + ddlambda.WrapFunction(handler, cfg) +} diff --git a/internal/trace/context.go b/internal/trace/context.go index 156100e8..66f82d42 100644 --- a/internal/trace/context.go +++ b/internal/trace/context.go @@ -29,21 +29,25 @@ type ( Headers map[string]string `json:"headers"` } - // TraceContext is map of headers containing a Datadog trace context + // TraceContext is map of headers containing a Datadog trace context. TraceContext map[string]string + + // ContextExtractor is a func type for extracting a root TraceContext. + ContextExtractor func(ctx context.Context, ev json.RawMessage) map[string]string ) type contextKeytype int -// traceContextKey is the key used to store a TraceContext in a Context object +// traceContextKey is the key used to store a TraceContext in a TraceContext object var traceContextKey = new(contextKeytype) -var datadogTraceContextFromEvent TraceContext +// DefaultTraceExtractor is the default trace extractor. Extracts root trace from API Gateway headers. +var DefaultTraceExtractor = getHeadersFromEventHeaders // contextWithRootTraceContext uses the incoming event and context object payloads to determine // the root TraceContext and then adds that TraceContext to the context object. -func contextWithRootTraceContext(ctx context.Context, ev json.RawMessage, mergeXrayTraces bool) (context.Context, error) { - datadogTraceContext, gotDatadogTraceContext := getDatadogTraceContextFromEvent(ctx, ev) +func contextWithRootTraceContext(ctx context.Context, ev json.RawMessage, mergeXrayTraces bool, extractor ContextExtractor) (context.Context, error) { + datadogTraceContext, gotDatadogTraceContext := getTraceContext(extractor(ctx, ev)) xrayTraceContext, errGettingXrayContext := convertXrayTraceContextFromLambdaContext(ctx) if errGettingXrayContext != nil { @@ -119,43 +123,50 @@ func createDummySubsegmentForXrayConverter(ctx context.Context, traceCtx TraceCo return nil } -// getDatadogTraceContextFromEvent extracts the Datadog trace context from an incoming Lambda event payload -// and creates a dummy X-Ray subsegment containing this information -func getDatadogTraceContextFromEvent(ctx context.Context, ev json.RawMessage) (TraceContext, bool) { - eh := eventWithHeaders{} - - traceCtx := map[string]string{} - - err := json.Unmarshal(ev, &eh) - if err != nil { - return traceCtx, false - } +func getTraceContext(context map[string]string) (TraceContext, bool) { + tc := TraceContext{} - lowercaseHeaders := map[string]string{} - for k, v := range eh.Headers { - lowercaseHeaders[strings.ToLower(k)] = v - } - - traceID, ok := lowercaseHeaders[traceIDHeader] + traceID, ok := context[traceIDHeader] if !ok { - return traceCtx, false + return tc, false } - parentID, ok := lowercaseHeaders[parentIDHeader] + parentID, ok := context[parentIDHeader] if !ok { - return traceCtx, false + return tc, false } - samplingPriority, ok := lowercaseHeaders[samplingPriorityHeader] + samplingPriority, ok := context[samplingPriorityHeader] if !ok { samplingPriority = "1" //sampler-keep } - traceCtx[samplingPriorityHeader] = samplingPriority - traceCtx[traceIDHeader] = traceID - traceCtx[parentIDHeader] = parentID + tc[samplingPriorityHeader] = samplingPriority + tc[traceIDHeader] = traceID + tc[parentIDHeader] = parentID + + return tc, true +} + +// getHeadersFromEventHeaders extracts the Datadog trace context from an incoming Lambda event payload +// and creates a dummy X-Ray subsegment containing this information. +// This is used as the DefaultTraceExtractor. +func getHeadersFromEventHeaders(ctx context.Context, ev json.RawMessage) map[string]string { + eh := eventWithHeaders{} + + headers := map[string]string{} + + err := json.Unmarshal(ev, &eh) + if err != nil { + return headers + } + + lowercaseHeaders := map[string]string{} + for k, v := range eh.Headers { + lowercaseHeaders[strings.ToLower(k)] = v + } - return traceCtx, true + return lowercaseHeaders } func convertXrayTraceContextFromLambdaContext(ctx context.Context) (TraceContext, error) { diff --git a/internal/trace/context_test.go b/internal/trace/context_test.go index 574bbd77..f48db166 100644 --- a/internal/trace/context_test.go +++ b/internal/trace/context_test.go @@ -58,7 +58,7 @@ func TestGetDatadogTraceContextForTraceMetadataNonProxyEvent(t *testing.T) { ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/apig-event-with-headers.json") - headers, ok := getDatadogTraceContextFromEvent(ctx, *ev) + headers, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev)) assert.True(t, ok) expected := TraceContext{ @@ -73,7 +73,7 @@ func TestGetDatadogTraceContextForTraceMetadataWithMixedCaseHeaders(t *testing.T ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/non-proxy-with-mixed-case-headers.json") - headers, ok := getDatadogTraceContextFromEvent(ctx, *ev) + headers, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev)) assert.True(t, ok) expected := TraceContext{ @@ -88,7 +88,7 @@ func TestGetDatadogTraceContextForTraceMetadataWithMissingSamplingPriority(t *te ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/non-proxy-with-missing-sampling-priority.json") - headers, ok := getDatadogTraceContextFromEvent(ctx, *ev) + headers, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev)) assert.True(t, ok) expected := TraceContext{ @@ -103,7 +103,7 @@ func TestGetDatadogTraceContextForInvalidData(t *testing.T) { ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/invalid.json") - _, ok := getDatadogTraceContextFromEvent(ctx, *ev) + _, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev)) assert.False(t, ok) } @@ -111,7 +111,7 @@ func TestGetDatadogTraceContextForMissingData(t *testing.T) { ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/non-proxy-no-headers.json") - _, ok := getDatadogTraceContextFromEvent(ctx, *ev) + _, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev)) assert.False(t, ok) } @@ -177,7 +177,7 @@ func TestContextWithRootTraceContextNoDatadogContext(t *testing.T) { ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/apig-event-no-headers.json") - newCTX, _ := contextWithRootTraceContext(ctx, *ev, false) + newCTX, _ := contextWithRootTraceContext(ctx, *ev, false, DefaultTraceExtractor) traceContext, _ := newCTX.Value(traceContextKey).(TraceContext) expected := TraceContext{} @@ -188,7 +188,7 @@ func TestContextWithRootTraceContextWithDatadogContext(t *testing.T) { ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/apig-event-with-headers.json") - newCTX, _ := contextWithRootTraceContext(ctx, *ev, false) + newCTX, _ := contextWithRootTraceContext(ctx, *ev, false, DefaultTraceExtractor) traceContext, _ := newCTX.Value(traceContextKey).(TraceContext) expected := TraceContext{ @@ -203,7 +203,7 @@ func TestContextWithRootTraceContextMergeXrayTracesNoDatadogContext(t *testing.T ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/apig-event-no-headers.json") - newCTX, _ := contextWithRootTraceContext(ctx, *ev, true) + newCTX, _ := contextWithRootTraceContext(ctx, *ev, true, DefaultTraceExtractor) traceContext, _ := newCTX.Value(traceContextKey).(TraceContext) expected := TraceContext{ @@ -218,7 +218,7 @@ func TestContextWithRootTraceContextMergeXrayTracesWithDatadogContext(t *testing ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/apig-event-with-headers.json") - newCTX, _ := contextWithRootTraceContext(ctx, *ev, true) + newCTX, _ := contextWithRootTraceContext(ctx, *ev, true, DefaultTraceExtractor) traceContext, _ := newCTX.Value(traceContextKey).(TraceContext) expected := TraceContext{ diff --git a/internal/trace/listener.go b/internal/trace/listener.go index dc10e4e6..f6519276 100644 --- a/internal/trace/listener.go +++ b/internal/trace/listener.go @@ -25,15 +25,17 @@ import ( type ( // Listener creates a function execution span and injects it into the context Listener struct { - ddTraceEnabled bool - mergeXrayTraces bool - extensionManager *extension.ExtensionManager + ddTraceEnabled bool + mergeXrayTraces bool + extensionManager *extension.ExtensionManager + traceContextExtractor ContextExtractor } // Config gives options for how the Listener should work Config struct { - DDTraceEnabled bool - MergeXrayTraces bool + DDTraceEnabled bool + MergeXrayTraces bool + TraceContextExtractor ContextExtractor } ) @@ -46,9 +48,10 @@ var tracerInitialized = false func MakeListener(config Config, extensionManager *extension.ExtensionManager) Listener { return Listener{ - ddTraceEnabled: config.DDTraceEnabled, - mergeXrayTraces: config.MergeXrayTraces, - extensionManager: extensionManager, + ddTraceEnabled: config.DDTraceEnabled, + mergeXrayTraces: config.MergeXrayTraces, + extensionManager: extensionManager, + traceContextExtractor: config.TraceContextExtractor, } } @@ -58,7 +61,7 @@ func (l *Listener) HandlerStarted(ctx context.Context, msg json.RawMessage) cont return ctx } - ctx, _ = contextWithRootTraceContext(ctx, msg, l.mergeXrayTraces) + ctx, _ = contextWithRootTraceContext(ctx, msg, l.mergeXrayTraces, l.traceContextExtractor) if !tracerInitialized { tracer.Start(