diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java
index a0df8af20b..65ff6af530 100644
--- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java
+++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java
@@ -8,13 +8,10 @@
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.List;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
-import java.util.function.BiFunction;
-import java.util.function.Function;
-
 import org.opensearch.dataprepper.model.event.Event;
 import org.opensearch.dataprepper.model.record.Record;
 import org.opensearch.dataprepper.model.sink.OutputCodecContext;
@@ -33,18 +30,16 @@ public class LambdaCommonHandler {
 
   private static final Logger LOG = LoggerFactory.getLogger(LambdaCommonHandler.class);
 
+  private LambdaCommonHandler() { }
 
   public static boolean isSuccess(InvokeResponse response) {
     int statusCode = response.statusCode();
-    if (statusCode < 200 || statusCode >= 300) {
-      return false;
-    }
-    return true;
+    return statusCode >= 200 && statusCode < 300;
   }
 
-  public static void waitForFutures(List<CompletableFuture<InvokeResponse>> futureList) {
+  public static void waitForFutures(Collection<CompletableFuture<InvokeResponse>> futureList) {
     if (!futureList.isEmpty()) {
       try {
@@ -83,50 +78,25 @@ private static List<Buffer> createBufferBatches(Collection<Record<Event>> record
     return batchedBuffers;
   }
 
-  public static List<Record<Event>> sendRecords(Collection<Record<Event>> records,
+  public static Map<Buffer, CompletableFuture<InvokeResponse>> sendRecords(
+      Collection<Record<Event>> records,
       LambdaCommonConfig config,
       LambdaAsyncClient lambdaAsyncClient,
-      final OutputCodecContext outputCodecContext,
-      BiFunction<Buffer, InvokeResponse, List<Record<Event>>> successHandler,
-      Function<Buffer, List<Record<Event>>> failureHandler) {
-    // Initialize here to void multi-threading issues
-    // Note: By default, one instance of processor is created across threads.
-    //List<Record<Event>> resultRecords = Collections.synchronizedList(new ArrayList<>());
-    List<Record<Event>> resultRecords = new ArrayList<>();
-    List<CompletableFuture<InvokeResponse>> futureList = new ArrayList<>();
-    int totalFlushedEvents = 0;
+      final OutputCodecContext outputCodecContext) {
 
     List<Buffer> batchedBuffers = createBufferBatches(records, config.getBatchOptions(), outputCodecContext);
 
-    Map<Buffer, CompletableFuture> bufferToFutureMap = new HashMap<>();
+    Map<Buffer, CompletableFuture<InvokeResponse>> bufferToFutureMap = new HashMap<>();
     LOG.debug("Batch Chunks created after threshold check: {}", batchedBuffers.size());
     for (Buffer buffer : batchedBuffers) {
       InvokeRequest requestPayload = buffer.getRequestPayload(config.getFunctionName(),
-              config.getInvocationType().getAwsLambdaValue());
+          config.getInvocationType().getAwsLambdaValue());
       CompletableFuture<InvokeResponse> future = lambdaAsyncClient.invoke(requestPayload);
-      futureList.add(future);
       bufferToFutureMap.put(buffer, future);
     }
-    waitForFutures(futureList);
-    for (Map.Entry<Buffer, CompletableFuture> entry : bufferToFutureMap.entrySet()) {
-      CompletableFuture future = entry.getValue();
-      Buffer buffer = entry.getKey();
-      try {
-        InvokeResponse response = (InvokeResponse) future.join();
-        if (isSuccess(response)) {
-          resultRecords.addAll(successHandler.apply(buffer, response));
-        } else {
-          LOG.error("Lambda invoke failed with error {} ", response.statusCode());
-          resultRecords.addAll(failureHandler.apply(buffer));
-        }
-      } catch (Exception e) {
-        LOG.error("Exception from Lambda invocation ", e);
-        resultRecords.addAll(failureHandler.apply(buffer));
-      }
-    }
-    return resultRecords;
-
+    waitForFutures(bufferToFutureMap.values());
+    return bufferToFutureMap;
   }
 }
diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java
index 2352678419..d456d644cb 100644
--- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java
+++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java
@@ -6,6 +6,7 @@
 package org.opensearch.dataprepper.plugins.lambda.processor;
 
 import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;
+import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess;
 
 import io.micrometer.core.instrument.Counter;
 import io.micrometer.core.instrument.DistributionSummary;
@@ -18,6 +19,8 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
 import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
 import org.opensearch.dataprepper.expression.ExpressionEvaluator;
@@ -74,25 +77,24 @@ public class LambdaProcessor extends AbstractProcessor<Record<Event>, Record<Event>> {
   public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {
       return records;
     }
-    List<Record<Event>> resultRecords = Collections.synchronizedList(new ArrayList());
+    List<Record<Event>> resultRecords = new ArrayList<>();
     List<Record<Event>> recordsToLambda = new ArrayList<>();
     for (Record<Event> record : records) {
       final Event event = record.getData();
       // If the condition is false, add the event to resultRecords as-is
-      if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) {
+      if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition,
+          event)) {
         resultRecords.add(record);
         continue;
       }
       recordsToLambda.add(record);
     }
-    try {
-      resultRecords.addAll(
-          lambdaCommonHandler.sendRecords(recordsToLambda, lambdaProcessorConfig, lambdaAsyncClient,
-              new OutputCodecContext(),
-              (inputBuffer, response) -> {
-                Duration latency = inputBuffer.stopLatencyWatch();
-                lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
-                numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount());
-                numberOfRequestsSuccessCounter.increment();
-                return convertLambdaResponseToEvent(inputBuffer, response);
-              },
-              (inputBuffer) -> {
-                Duration latency = inputBuffer.stopLatencyWatch();
-                lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
-                numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount());
-                numberOfRequestsFailedCounter.increment();
-                return addFailureTags(inputBuffer.getRecords());
-              })
-
-      );
-    } catch (Exception e) {
-      LOG.info("Exception in doExecute");
-      numberOfRecordsFailedCounter.increment(recordsToLambda.size());
-      resultRecords.addAll(addFailureTags(recordsToLambda));
+
+    Map<Buffer, CompletableFuture<InvokeResponse>> bufferToFutureMap = LambdaCommonHandler.sendRecords(
+        recordsToLambda, lambdaProcessorConfig, lambdaAsyncClient,
+        new OutputCodecContext());
+
+    for (Map.Entry<Buffer, CompletableFuture<InvokeResponse>> entry : bufferToFutureMap.entrySet()) {
+      CompletableFuture<InvokeResponse> future = entry.getValue();
+      Buffer inputBuffer = entry.getKey();
+      try {
+        InvokeResponse response = future.join();
+        Duration latency = inputBuffer.stopLatencyWatch();
+        lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
+        if (isSuccess(response)) {
+          numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount());
+          numberOfRequestsSuccessCounter.increment();
+          resultRecords.addAll(convertLambdaResponseToEvent(inputBuffer, response));
+        } else {
+          LOG.error("Lambda invoke failed with error {} ", response.statusCode());
+          resultRecords.addAll(addFailureTags(inputBuffer.getRecords()));
+        }
+      } catch (Exception e) {
+        LOG.error("Exception from Lambda invocation ", e);
+        numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount());
+        numberOfRequestsFailedCounter.increment();
+        resultRecords.addAll(addFailureTags(inputBuffer.getRecords()));
+      }
     }
     return resultRecords;
   }
@@ -182,7 +187,7 @@ public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {
    * 2. If it is not an array, then create one event per response.
    */
   List<Record<Event>> convertLambdaResponseToEvent(Buffer flushedBuffer,
-                                                   final InvokeResponse lambdaResponse) {
+      final InvokeResponse lambdaResponse) {
     InputCodec responseCodec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSetting);
     List<Record<Event>> originalRecords = flushedBuffer.getRecords();
     try {
@@ -191,8 +196,10 @@ List<Record<Event>> convertLambdaResponseToEvent(Buffer flushedBuffer,
       List<Record<Event>> resultRecords = new ArrayList<>();
       SdkBytes payload = lambdaResponse.payload();
       // Handle null or empty payload
-      if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) {
-        LOG.warn(NOISY, "Lambda response payload is null or empty, dropping the original events");
+      if (payload == null || payload.asByteArray() == null
+          || payload.asByteArray().length == 0) {
+        LOG.warn(NOISY,
+            "Lambda response payload is null or empty, dropping the original events");
       } else {
         InputStream inputStream = new ByteArrayInputStream(payload.asByteArray());
         //Convert to response codec
@@ -206,9 +213,10 @@ List<Record<Event>> convertLambdaResponseToEvent(Buffer flushedBuffer,
         }
 
         LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " +
-                "FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(),
-            flushedBuffer.getSize());
-        responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer);
+            "FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(),
+            flushedBuffer.getSize());
+        responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords,
+            flushedBuffer);
       }
       return resultRecords;
     } catch (Exception e) {
diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java
index 7f840c4cf5..25c4bf2e1c 100644
--- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java
+++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java
@@ -5,13 +5,16 @@
 package org.opensearch.dataprepper.plugins.lambda.sink;
 
+import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess;
+
 import io.micrometer.core.instrument.Counter;
 import io.micrometer.core.instrument.DistributionSummary;
 import io.micrometer.core.instrument.Timer;
 import java.time.Duration;
-import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
 import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
 import org.opensearch.dataprepper.expression.ExpressionEvaluator;
@@ -37,6 +40,7 @@
 import org.slf4j.LoggerFactory;
 import software.amazon.awssdk.core.SdkBytes;
 import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
+import software.amazon.awssdk.services.lambda.model.InvokeResponse;
 
 @DataPrepperPlugin(name = "aws_lambda", pluginType = Sink.class, pluginConfigurationType = LambdaSinkConfig.class)
 public class LambdaSink extends AbstractSink<Record<Event>> {
@@ -94,7 +98,7 @@ public LambdaSink(final PluginSetting pluginSetting,
     this.requestPayloadMetric = pluginMetrics.summary(REQUEST_PAYLOAD_SIZE);
     this.responsePayloadMetric = pluginMetrics.summary(RESPONSE_PAYLOAD_SIZE);
     ClientOptions clientOptions = lambdaSinkConfig.getClientOptions();
-    if(clientOptions == null){
+    if (clientOptions == null) {
       clientOptions = new ClientOptions();
     }
     this.lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient(
@@ -147,28 +151,37 @@ public void doOutput(final Collection<Record<Event>> records) {
     }
     //Result from lambda is not currently processes.
-    LambdaCommonHandler.sendRecords(records,
+    Map<Buffer, CompletableFuture<InvokeResponse>> bufferToFutureMap = LambdaCommonHandler.sendRecords(
+        records,
         lambdaSinkConfig,
         lambdaAsyncClient,
-        outputCodecContext,
-        (inputBuffer, invokeResponse) -> {
-          Duration latency = inputBuffer.stopLatencyWatch();
-          lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
+        outputCodecContext);
+
+    for (Map.Entry<Buffer, CompletableFuture<InvokeResponse>> entry : bufferToFutureMap.entrySet()) {
+      CompletableFuture<InvokeResponse> future = entry.getValue();
+      Buffer inputBuffer = entry.getKey();
+      try {
+        InvokeResponse response = future.join();
+        Duration latency = inputBuffer.stopLatencyWatch();
+        lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
+        if (isSuccess(response)) {
          numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount());
          numberOfRequestsSuccessCounter.increment();
          releaseEventHandlesPerBatch(true, inputBuffer);
-          return new ArrayList<>();
-        },
-        (inputBuffer) -> {
-          Duration latency = inputBuffer.stopLatencyWatch();
-          lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
-          numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount());
-          numberOfRequestsFailedCounter.increment();
+        } else {
+          LOG.error("Lambda invoke failed with error {} ", response.statusCode());
          handleFailure(new RuntimeException("failed"), inputBuffer);
-          return new ArrayList<>();
-        });
+        }
+      } catch (Exception e) {
+        LOG.error("Exception from Lambda invocation ", e);
+        numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount());
+        numberOfRequestsFailedCounter.increment();
+        handleFailure(new RuntimeException("failed"), inputBuffer);
+      }
+    }
   }
 
+
   void handleFailure(Throwable throwable, Buffer flushedBuffer) {
     try {
       if (flushedBuffer.getEventCount() > 0) {
diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java
index 2c7f27654a..44be8f5dd1 100644
--- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java
+++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java
@@ -4,13 +4,20 @@
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.ExtendWith;
-import org.mockito.Mock;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.mockito.junit.jupiter.MockitoSettings;
 import org.mockito.quality.Strictness;
@@ -26,15 +33,6 @@
 import software.amazon.awssdk.services.lambda.model.InvokeRequest;
 import software.amazon.awssdk.services.lambda.model.InvokeResponse;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.function.BiFunction;
-import java.util.function.Function;
-
-import static org.mockito.ArgumentMatchers.any;
-
 
 @ExtendWith(MockitoExtension.class)
 @MockitoSettings(strictness = Strictness.LENIENT)
 class LambdaCommonHandlerTest {
@@ -85,18 +83,18 @@ void testSendRecords() {
     when(config.getFunctionName()).thenReturn("testFunction");
     when(config.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE);
     when(lambdaAsyncClient.invoke(any(InvokeRequest.class)))
-        .thenReturn(CompletableFuture.completedFuture(InvokeResponse.builder().statusCode(200).build()));
+        .thenReturn(
+            CompletableFuture.completedFuture(InvokeResponse.builder().statusCode(200).build()));
 
     Event mockEvent = mock(Event.class);
     when(mockEvent.toMap()).thenReturn(Collections.singletonMap("testKey", "testValue"));
     List<Record<Event>> records = Collections.singletonList(new Record<>(mockEvent));
 
-    BiFunction<Buffer, InvokeResponse, List<Record<Event>>> successHandler = (buffer, response) -> new ArrayList<>();
-    Function<Buffer, List<Record<Event>>> failureHandler = (buffer) -> new ArrayList<>();
-
-    List<Record<Event>> result = LambdaCommonHandler.sendRecords(records, config, lambdaAsyncClient, outputCodecContext, successHandler, failureHandler);
+    Map<Buffer, CompletableFuture<InvokeResponse>> bufferCompletableFutureMap = LambdaCommonHandler.sendRecords(
+        records, config, lambdaAsyncClient,
+        outputCodecContext);
 
-    assertNotNull(result);
+    assertNotNull(bufferCompletableFutureMap);
     verify(lambdaAsyncClient, atLeastOnce()).invoke(any(InvokeRequest.class));
   }
 
@@ -112,11 +110,8 @@ void testSendRecordsWithNullKeyName() {
     when(mockEvent.toMap()).thenReturn(Collections.singletonMap("testKey", "testValue"));
     List<Record<Event>> records = Collections.singletonList(new Record<>(mockEvent));
 
-    BiFunction<Buffer, InvokeResponse, List<Record<Event>>> successHandler = (buffer, response) -> new ArrayList<>();
-    Function<Buffer, List<Record<Event>>> failureHandler = (buffer) -> new ArrayList<>();
-
     assertThrows(NullPointerException.class, () ->
-        LambdaCommonHandler.sendRecords(records, config, lambdaAsyncClient, outputCodecContext, successHandler, failureHandler)
+        LambdaCommonHandler.sendRecords(records, config, lambdaAsyncClient, outputCodecContext)
     );
   }
 
@@ -128,17 +123,16 @@ void testSendRecordsWithFailure() {
     when(config.getFunctionName()).thenReturn("testFunction");
     when(config.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE);
     when(lambdaAsyncClient.invoke(any(InvokeRequest.class)))
-            .thenReturn(CompletableFuture.failedFuture(new RuntimeException("Test exception")));
+        .thenReturn(CompletableFuture.failedFuture(new RuntimeException("Test exception")));
 
     List<Record<Event>> records = new ArrayList<>();
     records.add(new Record<>(mock(Event.class)));
 
-    BiFunction<Buffer, InvokeResponse, List<Record<Event>>> successHandler = (buffer, response) -> new ArrayList<>();
-    Function<Buffer, List<Record<Event>>> failureHandler = (buffer) -> new ArrayList<>();
-
-    List<Record<Event>> result = LambdaCommonHandler.sendRecords(records, config, lambdaAsyncClient, outputCodecContext, successHandler, failureHandler);
+    Map<Buffer, CompletableFuture<InvokeResponse>> bufferCompletableFutureMap = LambdaCommonHandler.sendRecords(
+        records, config, lambdaAsyncClient,
+        outputCodecContext);
 
-    assertNotNull(result);
+    assertNotNull(bufferCompletableFutureMap);
     verify(lambdaAsyncClient, atLeastOnce()).invoke(any(InvokeRequest.class));
   }
 }
diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferTest.java
index 89489e7ab1..49d7ffc379 100644
--- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferTest.java
+++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferTest.java
@@ -46,6 +46,10 @@ class InMemoryBufferTest {
   @Mock
   private LambdaAsyncClient lambdaAsyncClient;
 
+  public static Record<Event> getSampleRecord() {
+    Event event = JacksonEvent.fromMessage(String.valueOf(UUID.randomUUID()));
+    return new Record<>(event);
+  }
 
   @Test
   void test_with_write_event_into_buffer() {
@@ -87,11 +91,6 @@ void test_with_write_event_into_buffer_and_flush_toLambda() {
     });
   }
 
-  private Record<Event> getSampleRecord() {
-    Event event = JacksonEvent.fromMessage(String.valueOf(UUID.randomUUID()));
-    return new Record<>(event);
-  }
-
   @Test
   void test_uploadedToLambda_success() {
     // Mock the response of the invoke method
@@ -146,11 +145,4 @@ void test_uploadedToLambda_fails() {
 
   }
 
-  private byte[] generateByteArray() {
-    byte[] bytes = new byte[1000];
-    for (int i = 0; i < 1000; i++) {
-      bytes[i] = (byte) i;
-    }
-    return bytes;
-  }
 }
diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java
index 15749b853e..70c0135179 100644
--- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java
+++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java
@@ -21,12 +21,12 @@
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
-import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
+import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA;
 import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED;
 import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS;
 import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA;
-import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA;
-
+import static org.opensearch.dataprepper.plugins.lambda.sink.LambdaSinkTest.getSampleRecord;
+import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.createLambdaConfigurationFromYaml;
 
 import io.micrometer.core.instrument.Counter;
 import io.micrometer.core.instrument.Timer;
@@ -42,8 +42,8 @@
 import java.util.function.Consumer;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
-import org.mockito.ArgumentCaptor;
-import org.mockito.Captor;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.mockito.junit.jupiter.MockitoSettings;
@@ -53,8 +53,6 @@
 import org.opensearch.dataprepper.metrics.PluginMetrics;
 import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
 import org.opensearch.dataprepper.model.codec.InputCodec;
-import org.opensearch.dataprepper.model.codec.OutputCodec;
-import org.opensearch.dataprepper.model.configuration.PluginModel;
 import org.opensearch.dataprepper.model.configuration.PluginSetting;
 import org.opensearch.dataprepper.model.event.DefaultEventHandle;
 import org.opensearch.dataprepper.model.event.Event;
@@ -66,6 +64,7 @@
 import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
 import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
 import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions;
+import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
 import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType;
 import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions;
 import software.amazon.awssdk.core.SdkBytes;
@@ -76,6 +75,7 @@
 @MockitoSettings(strictness = Strictness.LENIENT)
 public class LambdaProcessorTest {
 
+
   // Mock dependencies
   @Mock
   private AwsAuthenticationOptions awsAuthenticationOptions;
@@ -104,8 +104,6 @@ public class LambdaProcessorTest {
   @Mock
   private InputCodec responseCodec;
 
-  @Mock
-  private OutputCodec requestCodec;
 
   @Mock
   private Counter numberOfRecordsSuccessCounter;
@@ -124,9 +122,6 @@ public class LambdaProcessorTest {
   @Mock
   private Timer lambdaLatencyMetric;
 
-  @Captor
-  private ArgumentCaptor<Consumer<Record<Event>>> consumerCaptor;
-
   // The class under test
   private LambdaProcessor lambdaProcessor;
 
@@ -140,9 +135,9 @@ public void setUp() throws Exception {
     when(pluginMetrics.counter(eq(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED))).thenReturn(
         numberOfRecordsFailedCounter);
     when(pluginMetrics.counter(eq(NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA))).thenReturn(
-            numberOfRecordsSuccessCounter);
+        numberOfRecordsSuccessCounter);
     when(pluginMetrics.counter(eq(NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA))).thenReturn(
-            numberOfRecordsFailedCounter);
+        numberOfRecordsFailedCounter);
     when(pluginMetrics.timer(anyString())).thenReturn(lambdaLatencyMetric);
     when(pluginMetrics.gauge(anyString(), any(AtomicLong.class))).thenAnswer(
         invocation -> invocation.getArgument(1));
@@ -150,7 +145,8 @@ public void setUp() throws Exception {
     ClientOptions clientOptions = new ClientOptions();
     when(lambdaProcessorConfig.getClientOptions()).thenReturn(clientOptions);
     when(lambdaProcessorConfig.getFunctionName()).thenReturn("test-function");
-    when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
+    when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(
+        awsAuthenticationOptions);
     when(lambdaProcessorConfig.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE);
     BatchOptions batchOptions = mock(BatchOptions.class);
     ThresholdOptions thresholdOptions = mock(ThresholdOptions.class);
@@ -169,18 +165,6 @@ public void setUp() throws Exception {
     when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.ofSeconds(30));
     when(batchOptions.getKeyName()).thenReturn("key");
 
-    // Mock Response Codec Configuration
-    PluginModel responseCodecConfig = lambdaProcessorConfig.getResponseCodecConfig();
-    PluginSetting responseCodecPluginSetting;
-
-    if (responseCodecConfig == null) {
-      // Default to JsonInputCodec with default settings
-      responseCodecPluginSetting = new PluginSetting("json", Collections.emptyMap());
-    } else {
-      responseCodecPluginSetting = new PluginSetting(responseCodecConfig.getPluginName(),
-          responseCodecConfig.getPluginSettings());
-    }
-
     // Mock PluginFactory to return the mocked responseCodec
     when(pluginFactory.loadPlugin(eq(InputCodec.class), any(PluginSetting.class))).thenReturn(
         responseCodec);
@@ -213,16 +197,18 @@ private void populatePrivateFields() throws Exception {
     setPrivateField(lambdaProcessor, "numberOfRecordsSuccessCounter",
         numberOfRecordsSuccessCounter);
     setPrivateField(lambdaProcessor, "numberOfRequestsSuccessCounter",
-            numberOfRequestsSuccessCounter);
-    setPrivateField(lambdaProcessor, "numberOfRecordsFailedCounter", numberOfRecordsFailedCounter);
-    setPrivateField(lambdaProcessor, "numberOfRequestsFailedCounter", numberOfRequestsFailedCounter);
+        numberOfRequestsSuccessCounter);
+    setPrivateField(lambdaProcessor, "numberOfRecordsFailedCounter",
+        numberOfRecordsFailedCounter);
+    setPrivateField(lambdaProcessor, "numberOfRequestsFailedCounter",
+        numberOfRequestsFailedCounter);
     setPrivateField(lambdaProcessor, "tagsOnMatchFailure", tagsOnMatchFailure);
     setPrivateField(lambdaProcessor, "lambdaCommonHandler", lambdaCommonHandler);
   }
 
   // Helper method to set private fields via reflection
   private void setPrivateField(Object targetObject, String fieldName, Object value)
-          throws Exception {
+      throws Exception {
     Field field = targetObject.getClass().getDeclaredField(fieldName);
     field.setAccessible(true);
     field.set(targetObject, value);
@@ -245,9 +231,11 @@ public void testProcessorDefaults() {
     // Test ClientOptions defaults
     ClientOptions clientOptions = defaultConfig.getClientOptions();
     assertNotNull(clientOptions);
-    assertEquals(ClientOptions.DEFAULT_CONNECTION_RETRIES, clientOptions.getMaxConnectionRetries());
+    assertEquals(ClientOptions.DEFAULT_CONNECTION_RETRIES,
+        clientOptions.getMaxConnectionRetries());
     assertEquals(ClientOptions.DEFAULT_API_TIMEOUT, clientOptions.getApiCallTimeout());
-    assertEquals(ClientOptions.DEFAULT_CONNECTION_TIMEOUT, clientOptions.getConnectionTimeout());
+    assertEquals(ClientOptions.DEFAULT_CONNECTION_TIMEOUT,
+        clientOptions.getConnectionTimeout());
     assertEquals(ClientOptions.DEFAULT_MAXIMUM_CONCURRENCY, clientOptions.getMaxConcurrency());
     assertEquals(ClientOptions.DEFAULT_BASE_DELAY, clientOptions.getBaseDelay());
     assertEquals(ClientOptions.DEFAULT_MAX_BACKOFF, clientOptions.getMaxBackoff());
@@ -257,12 +245,16 @@ public void testProcessorDefaults() {
     assertNotNull(batchOptions);
   }
 
-  @Test
-  public void testDoExecute_WithExceptionDuringProcessing() throws Exception {
+  @ParameterizedTest
+  @ValueSource(strings = {"lambda-processor-success-config.yaml"})
+  public void testDoExecute_WithExceptionDuringProcessing(String configFileName) {
     // Arrange
-    Event event = mock(Event.class);
-    Record<Event> record = new Record<>(event);
-    List<Record<Event>> records = Collections.singletonList(record);
+    List<Record<Event>> records = Collections.singletonList(getSampleRecord());
+    LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml(
+        configFileName);
+    LambdaProcessor lambdaProcessor = new LambdaProcessor(pluginFactory, pluginMetrics,
+        lambdaProcessorConfig,
+        awsCredentialsSupplier, expressionEvaluator);
 
     // make batch options null to generate exception
     when(lambdaProcessorConfig.getBatchOptions()).thenReturn(null);
@@ -347,7 +339,8 @@ public void testDoExecute_WhenConditionFalse() {
     Collection<Record<Event>> result = lambdaProcessor.doExecute(records);
 
     // Assert
-    assertEquals(1, result.size(), "Result should contain one record as the condition is false.");
+    assertEquals(1, result.size(),
+        "Result should contain one record as the condition is false.");
     verify(numberOfRecordsSuccessCounter, never()).increment(anyDouble());
     verify(numberOfRecordsFailedCounter, never()).increment(anyDouble());
   }
 
@@ -377,16 +370,16 @@ public void testDoExecute_SuccessfulProcessing() throws Exception {
     when(bufferMock.getRecords()).thenReturn(Collections.singletonList(record));
 
     doAnswer(invocation -> {
-          InputStream inputStream = invocation.getArgument(0);
-          @SuppressWarnings("unchecked")
-          Consumer<Record<Event>> consumer = invocation.getArgument(1);
+      invocation.getArgument(0);
+      @SuppressWarnings("unchecked")
+      Consumer<Record<Event>> consumer = invocation.getArgument(1);
 
-          // Simulate parsing by providing a mocked event
-          Event parsedEvent = mock(Event.class);
-          Record<Event> parsedRecord = new Record<>(parsedEvent);
-          consumer.accept(parsedRecord);
+      // Simulate parsing by providing a mocked event
+      Event parsedEvent = mock(Event.class);
+      Record<Event> parsedRecord = new Record<>(parsedEvent);
+      consumer.accept(parsedRecord);
 
-          return null;
+      return null;
     }).when(responseCodec).parse(any(InputStream.class), any(Consumer.class));
 
     // Act
@@ -398,7 +391,8 @@ public void testDoExecute_SuccessfulProcessing() throws Exception {
   }
 
   @Test
-  public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProcessing() throws Exception {
+  public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProcessing()
+      throws Exception {
     // Arrange
     when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true);
 
@@ -410,7 +404,7 @@ public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProc
     // Mock the responseCodec.parse to add two events
     doAnswer(invocation -> {
-      InputStream inputStream = invocation.getArgument(0);
+      invocation.getArgument(0);
       @SuppressWarnings("unchecked")
       Consumer<Record<Event>> consumer = invocation.getArgument(1);
       Event parsedEvent1 = mock(Event.class);
@@ -436,7 +430,8 @@ public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProc
     when(bufferMock.getEventCount()).thenReturn(2);
 
     // Act
-    List<Record<Event>> resultRecords = lambdaProcessor.convertLambdaResponseToEvent(bufferMock, invokeResponse);
+    List<Record<Event>> resultRecords = lambdaProcessor.convertLambdaResponseToEvent(bufferMock,
+        invokeResponse);
 
     // Assert
     assertEquals(2, resultRecords.size(), "ResultRecords should contain two records.");
@@ -446,7 +441,8 @@ public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProc
   }
 
   @Test
-  public void testConvertLambdaResponseToEvent_WithUnequalEventCounts_SuccessfulProcessing() throws Exception {
+  public void testConvertLambdaResponseToEvent_WithUnequalEventCounts_SuccessfulProcessing()
+      throws Exception {
     // Arrange
     // Set responseEventsMatch to false
     when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(false);
@@ -459,7 +455,7 @@ public void testConvertLambdaResponseToEvent_WithUnequalEventCounts_SuccessfulPr
     // Mock the responseCodec.parse to add three parsed events
     doAnswer(invocation -> {
-      InputStream inputStream = invocation.getArgument(0);
+      invocation.getArgument(0);
       @SuppressWarnings("unchecked")
       Consumer<Record<Event>> consumer = invocation.getArgument(1);
 
@@ -497,7 +493,8 @@ public void testConvertLambdaResponseToEvent_WithUnequalEventCounts_SuccessfulPr
     when(bufferMock.getEventCount()).thenReturn(2);
 
     // Act
-    List<Record<Event>> resultRecords = lambdaProcessor.convertLambdaResponseToEvent(bufferMock, invokeResponse);
+    List<Record<Event>> resultRecords = lambdaProcessor.convertLambdaResponseToEvent(bufferMock,
+        invokeResponse);
 
     // Assert
     // Verify that three records are added to the result
     assertEquals(3, resultRecords.size(), "ResultRecords should contain three records.");
diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaTestSetupUtil.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaTestSetupUtil.java
new file mode 100644
index 0000000000..93551acda6
--- /dev/null
+++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaTestSetupUtil.java
@@ -0,0 +1,42 @@
+package org.opensearch.dataprepper.plugins.lambda.utils;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
+import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator;
+import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
+import java.io.IOException;
+import java.io.InputStream;
+import org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessorConfig;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class LambdaTestSetupUtil {
+
+  private static final Logger log = LoggerFactory.getLogger(LambdaTestSetupUtil.class);
+
+  public static ObjectMapper getObjectMapper() {
+    return new ObjectMapper(
+        new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)).registerModule(
+        new JavaTimeModule());
+  }
+
+  private static InputStream getResourceAsStream(String resourceName) {
+    InputStream inputStream = Thread.currentThread().getContextClassLoader()
+        .getResourceAsStream(resourceName);
+    if (inputStream == null) {
+      inputStream = LambdaTestSetupUtil.class.getResourceAsStream("/" + resourceName);
+    }
+    return inputStream;
+  }
+
+  public static LambdaProcessorConfig createLambdaConfigurationFromYaml(String fileName) {
+    ObjectMapper objectMapper = getObjectMapper();
+    try (InputStream inputStream = getResourceAsStream(fileName)) {
+      return objectMapper.readValue(inputStream, LambdaProcessorConfig.class);
+    } catch (IOException ex) {
+      log.error("Failed to parse pipeline Yaml", ex);
+      throw new RuntimeException(ex);
+    }
+  }
+
+}
diff --git a/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-success-config.yaml b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-success-config.yaml
new file mode 100644
index 0000000000..e6661d72fd
--- /dev/null
+++ b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-success-config.yaml
@@ -0,0 +1,13 @@
+function_name: "lambdaProcessorTest"
+response_events_match: true
+tags_on_failure: [ "lambda_failure" ]
+batch:
+  key_name: "osi_key"
+  threshold:
+    event_count: 100
+    maximum_size: 1mb
+    event_collect_timeout: 335
+aws:
+  region: "us-east-1"
+  sts_role_arn: "arn:aws:iam::1234567890:role/sample-pipeine-role"
+