Skip to content

Commit

Permalink
Refactor replication slot cleanup code
Browse files Browse the repository at this point in the history
Signed-off-by: Hai Yan <oeyh@amazon.com>
  • Loading branch information
oeyh committed Feb 23, 2025
1 parent 82538d0 commit e5d0c2b
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public void start(Buffer<Record<Event>> buffer) {
}

streamScheduler = new StreamScheduler(
sourceCoordinator, sourceConfig, s3PathPrefix, replicationLogClientFactory, buffer, pluginMetrics, acknowledgementSetManager, pluginConfigObservable, schemaManager);
sourceCoordinator, sourceConfig, s3PathPrefix, replicationLogClientFactory, buffer, pluginMetrics, acknowledgementSetManager, pluginConfigObservable);
runnableList.add(streamScheduler);

if (sourceConfig.getEngine().isMySql()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator;
import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourcePartition;
import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig;
import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.ExportPartition;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.GlobalState;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.LeaderPartition;
Expand Down Expand Up @@ -47,6 +48,7 @@ public class LeaderScheduler implements Runnable {
private final DbTableMetadata dbTableMetadataMetadata;

private LeaderPartition leaderPartition;
private StreamPartition streamPartition = null;
private volatile boolean shutdownRequested = false;

public LeaderScheduler(final EnhancedSourceCoordinator sourceCoordinator,
Expand Down Expand Up @@ -111,6 +113,21 @@ public void run() {

public void shutdown() {
shutdownRequested = true;

// Clean up publication and replication slot for Postgres
if (streamPartition != null) {
streamPartition.getProgressState().ifPresent(progressState -> {
if (EngineType.fromString(progressState.getEngineType()).isPostgres()) {
final PostgresStreamState postgresStreamState = progressState.getPostgresStreamState();
final String publicationName = postgresStreamState.getPublicationName();
final String replicationSlotName = postgresStreamState.getReplicationSlotName();
LOG.info("Cleaned up logical replication slot {} and publication {}",
replicationSlotName, publicationName);
((PostgresSchemaManager) schemaManager).deleteLogicalReplicationSlot(
publicationName, replicationSlotName);
}
});
}
}

private void init() {
Expand Down Expand Up @@ -184,7 +201,7 @@ private void createStreamPartition(RdsSourceConfig sourceConfig) {
postgresStreamState.setReplicationSlotName(slotName);
progressState.setPostgresStreamState(postgresStreamState);
}
StreamPartition streamPartition = new StreamPartition(sourceConfig.getDbIdentifier(), progressState);
streamPartition = new StreamPartition(sourceConfig.getDbIdentifier(), progressState);
sourceCoordinator.createPartition(streamPartition);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void connect() {
LOG.debug("Logical replication stream started. ");

if (eventProcessor != null) {
while (!disconnectRequested) {
while (!disconnectRequested && !Thread.currentThread().isInterrupted()) {
try {
// Read changes
ByteBuffer msg = stream.readPending();
Expand All @@ -85,16 +85,17 @@ public void connect() {
stream.setAppliedLSN(lsn);
} catch (Exception e) {
LOG.error("Exception while processing Postgres replication stream. ", e);
stream.close();
LOG.debug("Replication stream closed.");
throw e;
}
}
}

stream.close();
LOG.debug("Replication stream closed.");

disconnectRequested = false;
if (eventProcessor != null) {
eventProcessor.stopCheckpointManager();
}
LOG.debug("Replication stream closed successfully.");
} catch (Exception e) {
LOG.error("Exception while creating Postgres replication stream. ", e);
Expand All @@ -106,6 +107,11 @@ public void connect() {
public void disconnect() {
disconnectRequested = true;
LOG.debug("Requested to disconnect logical replication stream.");

if (eventProcessor != null) {
eventProcessor.stopCheckpointManager();
LOG.debug("Stopped checkpoint manager.");
}
}

public void setEventProcessor(LogicalReplicationEventProcessor eventProcessor) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourcePartition;
import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition;
import org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -40,7 +39,6 @@ public class StreamScheduler implements Runnable {
private final PluginMetrics pluginMetrics;
private final AcknowledgementSetManager acknowledgementSetManager;
private final PluginConfigObservable pluginConfigObservable;
private final SchemaManager schemaManager;
private StreamWorkerTaskRefresher streamWorkerTaskRefresher;

private volatile boolean shutdownRequested = false;
Expand All @@ -52,8 +50,7 @@ public StreamScheduler(final EnhancedSourceCoordinator sourceCoordinator,
final Buffer<Record<Event>> buffer,
final PluginMetrics pluginMetrics,
final AcknowledgementSetManager acknowledgementSetManager,
final PluginConfigObservable pluginConfigObservable,
final SchemaManager schemaManager) {
final PluginConfigObservable pluginConfigObservable) {
this.sourceCoordinator = sourceCoordinator;
this.sourceConfig = sourceConfig;
this.s3Prefix = s3Prefix;
Expand All @@ -62,7 +59,6 @@ public StreamScheduler(final EnhancedSourceCoordinator sourceCoordinator,
this.pluginMetrics = pluginMetrics;
this.acknowledgementSetManager = acknowledgementSetManager;
this.pluginConfigObservable = pluginConfigObservable;
this.schemaManager = schemaManager;
}

@Override
Expand All @@ -86,7 +82,7 @@ public void run() {
streamWorkerTaskRefresher = StreamWorkerTaskRefresher.create(
sourceCoordinator, streamPartition, streamCheckpointer, s3Prefix, replicationLogClientFactory, buffer,
() -> Executors.newSingleThreadExecutor(BackgroundThreadFactory.defaultExecutorThreadFactory("rds-source-stream-worker")),
acknowledgementSetManager, pluginMetrics, schemaManager);
acknowledgementSetManager, pluginMetrics);

streamWorkerTaskRefresher.initialize(sourceConfig);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@
import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.GlobalState;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition;
import org.opensearch.dataprepper.plugins.source.rds.coordination.state.StreamProgressState;
import org.opensearch.dataprepper.plugins.source.rds.model.DbTableMetadata;
import org.opensearch.dataprepper.plugins.source.rds.resync.CascadingActionDetector;
import org.opensearch.dataprepper.plugins.source.rds.schema.PostgresSchemaManager;
import org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -46,7 +43,6 @@ public class StreamWorkerTaskRefresher implements PluginConfigObserver<RdsSource
private final Supplier<ExecutorService> executorServiceSupplier;
private final PluginMetrics pluginMetrics;
private final AcknowledgementSetManager acknowledgementSetManager;
private final SchemaManager schemaManager;
private final Counter credentialsChangeCounter;
private final Counter taskRefreshErrorsCounter;

Expand All @@ -61,8 +57,7 @@ public StreamWorkerTaskRefresher(final EnhancedSourceCoordinator sourceCoordinat
final Buffer<Record<Event>> buffer,
final Supplier<ExecutorService> executorServiceSupplier,
final AcknowledgementSetManager acknowledgementSetManager,
final PluginMetrics pluginMetrics,
final SchemaManager schemaManager) {
final PluginMetrics pluginMetrics) {
this.sourceCoordinator = sourceCoordinator;
this.streamPartition = streamPartition;
this.streamCheckpointer = streamCheckpointer;
Expand All @@ -73,7 +68,6 @@ public StreamWorkerTaskRefresher(final EnhancedSourceCoordinator sourceCoordinat
this.pluginMetrics = pluginMetrics;
this.acknowledgementSetManager = acknowledgementSetManager;
this.replicationLogClientFactory = replicationLogClientFactory;
this.schemaManager = schemaManager;
this.credentialsChangeCounter = pluginMetrics.counter(CREDENTIALS_CHANGED);
this.taskRefreshErrorsCounter = pluginMetrics.counter(TASK_REFRESH_ERRORS);
}
Expand All @@ -86,10 +80,9 @@ public static StreamWorkerTaskRefresher create(final EnhancedSourceCoordinator s
final Buffer<Record<Event>> buffer,
final Supplier<ExecutorService> executorServiceSupplier,
final AcknowledgementSetManager acknowledgementSetManager,
final PluginMetrics pluginMetrics,
final SchemaManager schemaManager) {
final PluginMetrics pluginMetrics) {
return new StreamWorkerTaskRefresher(sourceCoordinator, streamPartition, streamCheckpointer, s3Prefix,
binlogClientFactory, buffer, executorServiceSupplier, acknowledgementSetManager, pluginMetrics, schemaManager);
binlogClientFactory, buffer, executorServiceSupplier, acknowledgementSetManager, pluginMetrics);
}

public void initialize(RdsSourceConfig sourceConfig) {
Expand Down Expand Up @@ -120,15 +113,6 @@ public void update(RdsSourceConfig sourceConfig) {
}

public void shutdown() {
// Clean up publication and replication slot for Postgres
if (schemaManager instanceof PostgresSchemaManager) {
Optional<StreamProgressState> progressState = streamPartition.getProgressState();
progressState.ifPresent(state -> {
final String publicationName = state.getPostgresStreamState().getPublicationName();
final String replicationSlotName = state.getPostgresStreamState().getReplicationSlotName();
((PostgresSchemaManager) schemaManager).deleteLogicalReplicationSlot(publicationName, replicationSlotName);
});
}
executorService.shutdownNow();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator;
import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition;
import org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager;

import java.time.Duration;
import java.util.Optional;
Expand Down Expand Up @@ -70,9 +69,6 @@ class StreamSchedulerTest {
@Mock
private StreamWorkerTaskRefresher streamWorkerTaskRefresher;

@Mock
private SchemaManager schemaManager;

private String s3Prefix;
private StreamScheduler objectUnderTest;

Expand Down Expand Up @@ -105,7 +101,7 @@ void test_given_stream_partition_then_start_stream() throws InterruptedException
executorService.submit(() -> {
try (MockedStatic<StreamWorkerTaskRefresher> streamWorkerTaskRefresherMockedStatic = mockStatic(StreamWorkerTaskRefresher.class)) {
streamWorkerTaskRefresherMockedStatic.when(() -> StreamWorkerTaskRefresher.create(eq(sourceCoordinator), eq(streamPartition), any(StreamCheckpointer.class),
eq(s3Prefix), eq(replicationLogClientFactory), eq(buffer), any(Supplier.class), eq(acknowledgementSetManager), eq(pluginMetrics), eq(schemaManager)))
eq(s3Prefix), eq(replicationLogClientFactory), eq(buffer), any(Supplier.class), eq(acknowledgementSetManager), eq(pluginMetrics)))
.thenReturn(streamWorkerTaskRefresher);
objectUnderTest.run();
}
Expand Down Expand Up @@ -135,7 +131,6 @@ void test_shutdown() throws InterruptedException {

private StreamScheduler createObjectUnderTest() {
return new StreamScheduler(
sourceCoordinator, sourceConfig, s3Prefix, replicationLogClientFactory, buffer, pluginMetrics,
acknowledgementSetManager, pluginConfigObservable, schemaManager);
sourceCoordinator, sourceConfig, s3Prefix, replicationLogClientFactory, buffer, pluginMetrics, acknowledgementSetManager, pluginConfigObservable);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,9 @@
import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.GlobalState;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition;
import org.opensearch.dataprepper.plugins.source.rds.coordination.state.StreamProgressState;
import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata;
import org.opensearch.dataprepper.plugins.source.rds.model.DbTableMetadata;
import org.opensearch.dataprepper.plugins.source.rds.resync.CascadingActionDetector;
import org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager;
import org.opensearch.dataprepper.plugins.source.rds.schema.PostgresSchemaManager;

import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -71,9 +68,6 @@ class StreamWorkerTaskRefresherTest {
@Mock
private ReplicationLogClientFactory replicationLogClientFactory;

@Mock
private ReplicationLogClient replicationLogClient;

@Mock
private BinlogClientWrapper binaryLogClientWrapper;

Expand Down Expand Up @@ -122,12 +116,6 @@ class StreamWorkerTaskRefresherTest {
@Mock
private GlobalState globalState;

@Mock
private MySqlSchemaManager mySqlSchemaManager;

@Mock
private PostgresSchemaManager postgresSchemaManager;

private StreamWorkerTaskRefresher streamWorkerTaskRefresher;

@Nested
Expand Down Expand Up @@ -253,7 +241,7 @@ private StreamWorkerTaskRefresher createObjectUnderTest() {

return new StreamWorkerTaskRefresher(
sourceCoordinator, streamPartition, streamCheckpointer, s3Prefix, replicationLogClientFactory, buffer,
executorServiceSupplier, acknowledgementSetManager, pluginMetrics, mySqlSchemaManager);
executorServiceSupplier, acknowledgementSetManager, pluginMetrics);
}
}

Expand Down Expand Up @@ -357,16 +345,7 @@ void test_update_when_credentials_unchanged_then_do_nothing() {

@Test
void test_shutdown() {
final StreamProgressState streamProgressState = mock(StreamProgressState.class, RETURNS_DEEP_STUBS);
final String publicationName = UUID.randomUUID().toString();
final String replicationSlotName = UUID.randomUUID().toString();
when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState));
when(streamProgressState.getPostgresStreamState().getPublicationName()).thenReturn(publicationName);
when(streamProgressState.getPostgresStreamState().getReplicationSlotName()).thenReturn(replicationSlotName);

streamWorkerTaskRefresher.shutdown();

verify(postgresSchemaManager).deleteLogicalReplicationSlot(publicationName, replicationSlotName);
verify(executorService).shutdownNow();
}

Expand All @@ -375,7 +354,7 @@ private StreamWorkerTaskRefresher createObjectUnderTest() {

return new StreamWorkerTaskRefresher(
sourceCoordinator, streamPartition, streamCheckpointer, s3Prefix, replicationLogClientFactory, buffer,
executorServiceSupplier, acknowledgementSetManager, pluginMetrics, postgresSchemaManager);
executorServiceSupplier, acknowledgementSetManager, pluginMetrics);
}
}

Expand Down

0 comments on commit e5d0c2b

Please sign in to comment.