Update unit tests
Signed-off-by: Hai Yan <oeyh@amazon.com>
oeyh committed Feb 21, 2025
1 parent a4a787e commit 82538d0
Showing 9 changed files with 305 additions and 119 deletions.
@@ -87,9 +87,9 @@ public void setSourceSchema(String sourceSchema) {
 
     @JsonIgnore
     public String getFullSourceTableName() {
-        if (EngineType.fromString(engineType) == EngineType.MYSQL) {
+        if (EngineType.fromString(engineType).isMySql()) {
             return sourceDatabase + "." + sourceTable;
-        } else if (EngineType.fromString(engineType) == EngineType.POSTGRES) {
+        } else if (EngineType.fromString(engineType).isPostgres()) {
             return sourceDatabase + "." + sourceSchema + "." + sourceTable;
         } else {
             throw new RuntimeException("Unsupported engine type: " + engineType);
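
Note: replacing engineType == EngineType.MYSQL comparisons with isMySql()/isPostgres() predicates is the pattern this commit applies across the source. For orientation, here is a minimal sketch of what the EngineType enum plausibly provides; the string values and the exception choice in fromString are assumptions, not taken from this repository:

    public enum EngineType {
        MYSQL("mysql"),
        POSTGRES("postgres");

        private final String engineName;

        EngineType(final String engineName) {
            this.engineName = engineName;
        }

        public static EngineType fromString(final String name) {
            // Assumed lookup; the real implementation may differ
            for (EngineType type : values()) {
                if (type.engineName.equalsIgnoreCase(name)) {
                    return type;
                }
            }
            throw new IllegalArgumentException("Unknown engine type: " + name);
        }

        public boolean isMySql() {
            return this == MYSQL;
        }

        public boolean isPostgres() {
            return this == POSTGRES;
        }
    }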
@@ -177,15 +177,15 @@ public void run() {
     }
 
     private void transformEvent(final Event event, final String fullTableName, final EngineType engineType) {
-        if (engineType == EngineType.MYSQL) {
+        if (engineType.isMySql()) {
             Map<String, String> columnDataTypeMap = dbTableMetadata.getTableColumnDataTypeMap().get(fullTableName);
             for (Map.Entry<String, Object> entry : event.toMap().entrySet()) {
                 final Object data = MySQLDataTypeHelper.getDataByColumnType(MySQLDataType.byDataType(columnDataTypeMap.get(entry.getKey())), entry.getKey(),
                         entry.getValue(), null);
                 event.put(entry.getKey(), data);
             }
         }
-        if (engineType == EngineType.POSTGRES) {
+        if (engineType.isPostgres()) {
             Map<String, String> columnDataTypeMap = dbTableMetadata.getTableColumnDataTypeMap().get(fullTableName);
             for (Map.Entry<String, Object> entry : event.toMap().entrySet()) {
                 final Object data = PostgresDataTypeHelper.getDataByColumnType(PostgresDataType.byDataType(columnDataTypeMap.get(entry.getKey())), entry.getKey(),
@@ -144,6 +144,18 @@ public LogicalReplicationEventProcessor(final StreamPartition streamPartition,
         eventProcessingErrorCounter = pluginMetrics.counter(REPLICATION_LOG_PROCESSING_ERROR_COUNT);
     }
 
+    public static LogicalReplicationEventProcessor create(final StreamPartition streamPartition,
+                                                          final RdsSourceConfig sourceConfig,
+                                                          final Buffer<Record<Event>> buffer,
+                                                          final String s3Prefix,
+                                                          final PluginMetrics pluginMetrics,
+                                                          final LogicalReplicationClient logicalReplicationClient,
+                                                          final StreamCheckpointer streamCheckpointer,
+                                                          final AcknowledgementSetManager acknowledgementSetManager) {
+        return new LogicalReplicationEventProcessor(streamPartition, sourceConfig, buffer, s3Prefix, pluginMetrics,
+                logicalReplicationClient, streamCheckpointer, acknowledgementSetManager);
+    }
+
     public void process(ByteBuffer msg) {
         // Message processing logic:
         // If it's a BEGIN, note its LSN
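
Note: the new static create factory gives tests a single seam to intercept, the same pattern StreamWorkerTaskRefresher.create already uses in the scheduler test further down. A hedged sketch of how a test could stub it with Mockito; the mock variable and the all-any() matchers are illustrative, not from this commit:

    LogicalReplicationEventProcessor eventProcessor = mock(LogicalReplicationEventProcessor.class);
    try (MockedStatic<LogicalReplicationEventProcessor> mockedFactory =
                 mockStatic(LogicalReplicationEventProcessor.class)) {
        mockedFactory.when(() -> LogicalReplicationEventProcessor.create(
                        any(), any(), any(), any(), any(), any(), any(), any()))
                .thenReturn(eventProcessor);
        // Code under test that calls the factory now receives the mock processor
    }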
@@ -41,9 +41,9 @@ public StreamCheckpointer(final EnhancedSourceCoordinator sourceCoordinator,
     }
 
     public void checkpoint(final EngineType engineType, final ChangeEventStatus changeEventStatus) {
-        if (engineType == EngineType.MYSQL) {
+        if (engineType.isMySql()) {
             checkpoint(changeEventStatus.getBinlogCoordinate());
-        } else if (engineType == EngineType.POSTGRES) {
+        } else if (engineType.isPostgres()) {
             checkpoint(changeEventStatus.getLogSequenceNumber());
         } else {
             throw new IllegalArgumentException("Unsupported engine type " + engineType);
@@ -144,7 +144,7 @@ private void refreshTask(RdsSourceConfig sourceConfig) {
                     streamCheckpointer, acknowledgementSetManager, dbTableMetadata, cascadeActionDetector));
         } else {
             final LogicalReplicationClient logicalReplicationClient = (LogicalReplicationClient) replicationLogClient;
-            logicalReplicationClient.setEventProcessor(new LogicalReplicationEventProcessor(
+            logicalReplicationClient.setEventProcessor(LogicalReplicationEventProcessor.create(
                     streamPartition, sourceConfig, buffer, s3Prefix, pluginMetrics, logicalReplicationClient,
                     streamCheckpointer, acknowledgementSetManager));
         }
@@ -11,8 +11,9 @@
 import org.apache.parquet.avro.AvroParquetReader;
 import org.apache.parquet.hadoop.ParquetReader;
 import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
 import org.mockito.Answers;
 import org.mockito.Mock;
 import org.mockito.MockedStatic;
@@ -30,6 +31,7 @@
 import org.opensearch.dataprepper.model.record.Record;
 import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator;
 import org.opensearch.dataprepper.plugins.codec.parquet.ParquetInputCodec;
+import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType;
 import org.opensearch.dataprepper.plugins.source.rds.converter.ExportRecordConverter;
 import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.DataFilePartition;
 import org.opensearch.dataprepper.plugins.source.rds.coordination.state.DataFileProgressState;
@@ -113,14 +115,16 @@ void setUp() {
         when(pluginMetrics.summary(BYTES_PROCESSED)).thenReturn(bytesProcessedSummary);
     }
 
-    @Test
-    void test_run_success() throws Exception {
+    @ParameterizedTest
+    @EnumSource(EngineType.class)
+    void test_run_success(EngineType engineType) throws Exception {
         final String bucket = UUID.randomUUID().toString();
         final String key = UUID.randomUUID().toString();
         when(dataFilePartition.getBucket()).thenReturn(bucket);
         when(dataFilePartition.getKey()).thenReturn(key);
         final DataFileProgressState progressState = mock(DataFileProgressState.class, RETURNS_DEEP_STUBS);
         when(dataFilePartition.getProgressState()).thenReturn(Optional.of(progressState));
+        when(progressState.getEngineType()).thenReturn(engineType.toString());
 
         InputStream inputStream = mock(InputStream.class);
         when(s3ObjectReader.readFile(bucket, key)).thenReturn(inputStream);
@@ -162,14 +166,16 @@ void test_run_success() throws Exception {
         verify(exportRecordErrorCounter, never()).increment(1);
     }
 
-    @Test
-    void test_flush_failure_then_error_metric_updated() throws Exception {
+    @ParameterizedTest
+    @EnumSource(EngineType.class)
+    void test_flush_failure_then_error_metric_updated(EngineType engineType) throws Exception {
         final String bucket = UUID.randomUUID().toString();
         final String key = UUID.randomUUID().toString();
         when(dataFilePartition.getBucket()).thenReturn(bucket);
         when(dataFilePartition.getKey()).thenReturn(key);
         final DataFileProgressState progressState = mock(DataFileProgressState.class, RETURNS_DEEP_STUBS);
         when(dataFilePartition.getProgressState()).thenReturn(Optional.of(progressState));
+        when(progressState.getEngineType()).thenReturn(engineType.toString());
 
         InputStream inputStream = mock(InputStream.class);
         when(s3ObjectReader.readFile(bucket, key)).thenReturn(inputStream);
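
Note: @EnumSource(EngineType.class) runs each annotated test once per enum constant, so the MySQL and Postgres paths are both exercised without duplicating test bodies. A self-contained illustration of the mechanism; the enum here is a stand-in, not the plugin's class:

    import org.junit.jupiter.params.ParameterizedTest;
    import org.junit.jupiter.params.provider.EnumSource;

    import static org.junit.jupiter.api.Assertions.assertNotNull;

    class EnumSourceDemoTest {
        enum EngineType { MYSQL, POSTGRES }

        @ParameterizedTest
        @EnumSource(EngineType.class)
        void runs_once_per_constant(EngineType engineType) {
            // JUnit injects MYSQL on the first invocation and POSTGRES on the second
            assertNotNull(engineType.toString());
        }
    }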
@@ -27,6 +27,7 @@
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.SQLException;
+import java.sql.Statement;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
@@ -43,6 +44,8 @@
 import static org.mockito.Mockito.when;
 import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.COLUMN_NAME;
 import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.TYPE_NAME;
+import static org.opensearch.dataprepper.plugins.source.rds.schema.PostgresSchemaManager.DROP_PUBLICATION_SQL;
+import static org.opensearch.dataprepper.plugins.source.rds.schema.PostgresSchemaManager.DROP_SLOT_SQL;
 
 @ExtendWith(MockitoExtension.class)
 class PostgresSchemaManagerTest {
@@ -137,6 +140,24 @@ void test_createLogicalReplicationSlot_skip_creation_if_slot_exists() throws SQLException {
         verify(replicationConnection, never()).createReplicationSlot();
     }
 
+    @Test
+    void test_deleteLogicalReplicationSlot_success() throws SQLException {
+        final String publicationName = UUID.randomUUID().toString();
+        final String slotName = UUID.randomUUID().toString();
+        final PreparedStatement dropSlotStatement = mock(PreparedStatement.class);
+        final Statement dropPublicationStatement = mock(Statement.class);
+
+        when(connectionManager.getConnection()).thenReturn(connection);
+        when(connection.prepareStatement(DROP_SLOT_SQL)).thenReturn(dropSlotStatement);
+        when(connection.createStatement()).thenReturn(dropPublicationStatement);
+
+        schemaManager.deleteLogicalReplicationSlot(publicationName, slotName);
+
+        verify(dropSlotStatement).setString(1, slotName);
+        verify(dropSlotStatement).execute();
+        verify(dropPublicationStatement).execute(DROP_PUBLICATION_SQL + publicationName);
+    }
+
     @Test
     void test_getPrimaryKeys_returns_primary_keys() throws SQLException {
         final String database = UUID.randomUUID().toString();
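
Note: the new test pins down the expected JDBC interactions for replication-slot cleanup. A plausible sketch of the production method, inferred only from those verifications; the actual implementation may differ, for example in logging or error handling, and the SQL behind DROP_SLOT_SQL is an assumption:

    public void deleteLogicalReplicationSlot(final String publicationName, final String slotName) throws SQLException {
        try (Connection connection = connectionManager.getConnection()) {
            // Dropping the slot takes a bind parameter (DROP_SLOT_SQL is assumed to be
            // something like "SELECT pg_drop_replication_slot(?);")
            try (PreparedStatement dropSlotStatement = connection.prepareStatement(DROP_SLOT_SQL)) {
                dropSlotStatement.setString(1, slotName);
                dropSlotStatement.execute();
            }
            // DROP PUBLICATION does not accept bind parameters, so the name is appended
            try (Statement dropPublicationStatement = connection.createStatement()) {
                dropPublicationStatement.execute(DROP_PUBLICATION_SQL + publicationName);
            }
        }
    }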
@@ -23,6 +23,7 @@
 import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator;
 import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig;
 import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition;
+import org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager;
 
 import java.time.Duration;
 import java.util.Optional;
@@ -69,6 +70,9 @@ class StreamSchedulerTest {
     @Mock
     private StreamWorkerTaskRefresher streamWorkerTaskRefresher;
 
+    @Mock
+    private SchemaManager schemaManager;
+
     private String s3Prefix;
     private StreamScheduler objectUnderTest;
 
@@ -101,7 +105,7 @@ void test_given_stream_partition_then_start_stream() throws InterruptedException {
         executorService.submit(() -> {
             try (MockedStatic<StreamWorkerTaskRefresher> streamWorkerTaskRefresherMockedStatic = mockStatic(StreamWorkerTaskRefresher.class)) {
                 streamWorkerTaskRefresherMockedStatic.when(() -> StreamWorkerTaskRefresher.create(eq(sourceCoordinator), eq(streamPartition), any(StreamCheckpointer.class),
-                        eq(s3Prefix), eq(replicationLogClientFactory), eq(buffer), any(Supplier.class), eq(acknowledgementSetManager), eq(pluginMetrics)))
+                        eq(s3Prefix), eq(replicationLogClientFactory), eq(buffer), any(Supplier.class), eq(acknowledgementSetManager), eq(pluginMetrics), eq(schemaManager)))
                         .thenReturn(streamWorkerTaskRefresher);
                 objectUnderTest.run();
             }
@@ -131,6 +135,7 @@ void test_shutdown() throws InterruptedException {
 
     private StreamScheduler createObjectUnderTest() {
         return new StreamScheduler(
-                sourceCoordinator, sourceConfig, s3Prefix, replicationLogClientFactory, buffer, pluginMetrics, acknowledgementSetManager, pluginConfigObservable);
+                sourceCoordinator, sourceConfig, s3Prefix, replicationLogClientFactory, buffer, pluginMetrics,
+                acknowledgementSetManager, pluginConfigObservable, schemaManager);
     }
 }