From fa93d4fac9fd53f5fdf5a54fdfe7dfa7832f6262 Mon Sep 17 00:00:00 2001 From: Jiabao Sun Date: Thu, 23 Nov 2023 07:04:57 -0600 Subject: [PATCH] [FLINK-33610][runtime] Fix unstable test CoGroupTaskTest.testCancelCoGroupTaskWhileCoGrouping (#23770) --- .../runtime/operators/CoGroupTaskTest.java | 258 ++++++------------ .../operators/CombineTaskExternalITCase.java | 3 +- .../operators/ReduceTaskExternalITCase.java | 2 +- .../runtime/operators/ReduceTaskTest.java | 2 +- .../operators/testutils/DriverTestBase.java | 12 +- 5 files changed, 88 insertions(+), 189 deletions(-) diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskTest.java index 7dc0d9493c616..dc6164c67d66e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskTest.java @@ -21,6 +21,8 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.functions.RichCoGroupFunction; +import org.apache.flink.core.testutils.CheckedThread; +import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.operators.CoGroupTaskExternalITCase.MockCoGroupStub; import org.apache.flink.runtime.operators.testutils.DelayingInfinitiveInputIterator; import org.apache.flink.runtime.operators.testutils.DriverTestBase; @@ -36,11 +38,8 @@ import org.junit.jupiter.api.TestTemplate; -import java.util.concurrent.atomic.AtomicBoolean; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.fail; class CoGroupTaskTest extends DriverTestBase> { private static final long SORT_MEM = 3 * 1024 * 1024; @@ -62,7 +61,7 @@ class CoGroupTaskTest extends DriverTestBase testTask = - new CoGroupDriver(); - - try { - addInputSorted( - new UniformRecordGenerator(keyCnt1, valCnt1, false), - this.comparator1.duplicate()); - addInputSorted( - new UniformRecordGenerator(keyCnt2, valCnt2, false), - this.comparator2.duplicate()); - testDriver(testTask, MockCoGroupStub.class); - } catch (Exception e) { - e.printStackTrace(); - fail("The test caused an exception."); - } + final CoGroupDriver testTask = new CoGroupDriver<>(); + + addInputSorted( + new UniformRecordGenerator(keyCnt1, valCnt1, false), this.comparator1.duplicate()); + addInputSorted( + new UniformRecordGenerator(keyCnt2, valCnt2, false), this.comparator2.duplicate()); + testDriver(testTask, MockCoGroupStub.class); assertThat(this.output.getNumberOfRecords()) .withFailMessage("Wrong result set size.") @@ -103,7 +94,7 @@ void testSortBoth1CoGroupTask() { } @TestTemplate - void testSortBoth2CoGroupTask() { + void testSortBoth2CoGroupTask() throws Exception { int keyCnt1 = 200; int valCnt1 = 2; @@ -122,21 +113,13 @@ void testSortBoth2CoGroupTask() { getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP); - final CoGroupDriver testTask = - new CoGroupDriver(); - - try { - addInputSorted( - new UniformRecordGenerator(keyCnt1, valCnt1, false), - this.comparator1.duplicate()); - addInputSorted( - new UniformRecordGenerator(keyCnt2, valCnt2, false), - this.comparator2.duplicate()); - testDriver(testTask, MockCoGroupStub.class); - } catch (Exception e) { - e.printStackTrace(); - fail("The test caused an exception."); - } + final CoGroupDriver testTask = new CoGroupDriver<>(); + + addInputSorted( + new UniformRecordGenerator(keyCnt1, valCnt1, false), this.comparator1.duplicate()); + addInputSorted( + new UniformRecordGenerator(keyCnt2, valCnt2, false), this.comparator2.duplicate()); + testDriver(testTask, MockCoGroupStub.class); assertThat(this.output.getNumberOfRecords()) .withFailMessage("Wrong result set size.") @@ -144,7 +127,7 @@ void testSortBoth2CoGroupTask() { } @TestTemplate - void testSortFirstCoGroupTask() { + void testSortFirstCoGroupTask() throws Exception { int keyCnt1 = 200; int valCnt1 = 2; @@ -163,19 +146,12 @@ void testSortFirstCoGroupTask() { getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP); - final CoGroupDriver testTask = - new CoGroupDriver(); - - try { - addInputSorted( - new UniformRecordGenerator(keyCnt1, valCnt1, false), - this.comparator1.duplicate()); - addInput(new UniformRecordGenerator(keyCnt2, valCnt2, true)); - testDriver(testTask, MockCoGroupStub.class); - } catch (Exception e) { - e.printStackTrace(); - fail("The test caused an exception."); - } + final CoGroupDriver testTask = new CoGroupDriver<>(); + + addInputSorted( + new UniformRecordGenerator(keyCnt1, valCnt1, false), this.comparator1.duplicate()); + addInput(new UniformRecordGenerator(keyCnt2, valCnt2, true)); + testDriver(testTask, MockCoGroupStub.class); assertThat(this.output.getNumberOfRecords()) .withFailMessage("Wrong result set size.") @@ -183,7 +159,7 @@ void testSortFirstCoGroupTask() { } @TestTemplate - void testSortSecondCoGroupTask() { + void testSortSecondCoGroupTask() throws Exception { int keyCnt1 = 200; int valCnt1 = 2; @@ -202,19 +178,12 @@ void testSortSecondCoGroupTask() { getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP); - final CoGroupDriver testTask = - new CoGroupDriver(); - - try { - addInput(new UniformRecordGenerator(keyCnt1, valCnt1, true)); - addInputSorted( - new UniformRecordGenerator(keyCnt2, valCnt2, false), - this.comparator2.duplicate()); - testDriver(testTask, MockCoGroupStub.class); - } catch (Exception e) { - e.printStackTrace(); - fail("The test caused an exception."); - } + final CoGroupDriver testTask = new CoGroupDriver<>(); + + addInput(new UniformRecordGenerator(keyCnt1, valCnt1, true)); + addInputSorted( + new UniformRecordGenerator(keyCnt2, valCnt2, false), this.comparator2.duplicate()); + testDriver(testTask, MockCoGroupStub.class); assertThat(this.output.getNumberOfRecords()) .withFailMessage("Wrong result set size.") @@ -222,7 +191,7 @@ void testSortSecondCoGroupTask() { } @TestTemplate - void testMergeCoGroupTask() { + void testMergeCoGroupTask() throws Exception { int keyCnt1 = 200; int valCnt1 = 2; @@ -245,15 +214,9 @@ void testMergeCoGroupTask() { getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP); - final CoGroupDriver testTask = - new CoGroupDriver(); + final CoGroupDriver testTask = new CoGroupDriver<>(); - try { - testDriver(testTask, MockCoGroupStub.class); - } catch (Exception e) { - e.printStackTrace(); - fail("The test caused an exception."); - } + testDriver(testTask, MockCoGroupStub.class); assertThat(this.output.getNumberOfRecords()) .withFailMessage("Wrong result set size.") @@ -278,15 +241,14 @@ void testFailingSortCoGroupTask() { getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP); - final CoGroupDriver testTask = - new CoGroupDriver(); + final CoGroupDriver testTask = new CoGroupDriver<>(); assertThatThrownBy(() -> testDriver(testTask, MockFailingCoGroupStub.class)) .isInstanceOf(ExpectedTestException.class); } @TestTemplate - void testCancelCoGroupTaskWhileSorting1() { + void testCancelCoGroupTaskWhileSorting1() throws Exception { int keyCnt = 10; int valCnt = 2; @@ -298,29 +260,16 @@ void testCancelCoGroupTaskWhileSorting1() { getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP); - final CoGroupDriver testTask = - new CoGroupDriver(); - - try { - addInputSorted(new DelayingInfinitiveInputIterator(1000), this.comparator1.duplicate()); - addInput(new UniformRecordGenerator(keyCnt, valCnt, true)); - } catch (Exception e) { - e.printStackTrace(); - fail("The test caused an exception."); - } + final CoGroupDriver testTask = new CoGroupDriver<>(); - final AtomicBoolean success = new AtomicBoolean(false); + addInputSorted(new DelayingInfinitiveInputIterator(1000), this.comparator1.duplicate()); + addInput(new UniformRecordGenerator(keyCnt, valCnt, true)); - Thread taskRunner = - new Thread() { + CheckedThread taskRunner = + new CheckedThread() { @Override - public void run() { - try { - testDriver(testTask, MockCoGroupStub.class); - success.set(true); - } catch (Exception ie) { - ie.printStackTrace(); - } + public void go() throws Exception { + testDriver(testTask, MockCoGroupStub.class); } }; taskRunner.start(); @@ -328,18 +277,12 @@ public void run() { TaskCancelThread tct = new TaskCancelThread(1, taskRunner, this); tct.start(); - try { - tct.join(); - taskRunner.join(); - } catch (InterruptedException ie) { - fail("Joining threads failed"); - } - - assertThat(success).withFailMessage("The test task was not properly canceled.").isTrue(); + tct.join(); + taskRunner.sync(); } @TestTemplate - void testCancelCoGroupTaskWhileSorting2() { + void testCancelCoGroupTaskWhileSorting2() throws Exception { int keyCnt = 10; int valCnt = 2; @@ -351,29 +294,16 @@ void testCancelCoGroupTaskWhileSorting2() { getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP); - final CoGroupDriver testTask = - new CoGroupDriver(); - - try { - addInput(new UniformRecordGenerator(keyCnt, valCnt, true)); - addInputSorted(new DelayingInfinitiveInputIterator(1000), this.comparator2.duplicate()); - } catch (Exception e) { - e.printStackTrace(); - fail("The test caused an exception."); - } + final CoGroupDriver testTask = new CoGroupDriver<>(); - final AtomicBoolean success = new AtomicBoolean(false); + addInput(new UniformRecordGenerator(keyCnt, valCnt, true)); + addInputSorted(new DelayingInfinitiveInputIterator(1000), this.comparator2.duplicate()); - Thread taskRunner = - new Thread() { + CheckedThread taskRunner = + new CheckedThread() { @Override - public void run() { - try { - testDriver(testTask, MockCoGroupStub.class); - success.set(true); - } catch (Exception ie) { - ie.printStackTrace(); - } + public void go() throws Exception { + testDriver(testTask, MockCoGroupStub.class); } }; taskRunner.start(); @@ -381,20 +311,12 @@ public void run() { TaskCancelThread tct = new TaskCancelThread(1, taskRunner, this); tct.start(); - try { - tct.join(); - taskRunner.join(); - } catch (InterruptedException ie) { - fail("Joining threads failed"); - } - - assertThat(success) - .withFailMessage("Test threw an exception even though it was properly canceled.") - .isTrue(); + tct.join(); + taskRunner.sync(); } @TestTemplate - void testCancelCoGroupTaskWhileCoGrouping() { + void testCancelCoGroupTaskWhileCoGrouping() throws Exception { int keyCnt = 100; int valCnt = 5; @@ -406,29 +328,19 @@ void testCancelCoGroupTaskWhileCoGrouping() { getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP); - final CoGroupDriver testTask = - new CoGroupDriver(); + final CoGroupDriver testTask = new CoGroupDriver<>(); - try { - addInput(new UniformRecordGenerator(keyCnt, valCnt, true)); - addInput(new UniformRecordGenerator(keyCnt, valCnt, true)); - } catch (Exception e) { - e.printStackTrace(); - fail("The test caused an exception."); - } + addInput(new UniformRecordGenerator(keyCnt, valCnt, true)); + addInput(new UniformRecordGenerator(keyCnt, valCnt, true)); - final AtomicBoolean success = new AtomicBoolean(false); + final OneShotLatch delayCoGroupProcessingLatch = new OneShotLatch(); - Thread taskRunner = - new Thread() { + CheckedThread taskRunner = + new CheckedThread() { @Override - public void run() { - try { - testDriver(testTask, MockDelayingCoGroupStub.class); - success.set(true); - } catch (Exception ie) { - ie.printStackTrace(); - } + public void go() throws Exception { + testDriver( + testTask, new MockDelayingCoGroupStub(delayCoGroupProcessingLatch)); } }; taskRunner.start(); @@ -436,16 +348,9 @@ public void run() { TaskCancelThread tct = new TaskCancelThread(1, taskRunner, this); tct.start(); - try { - tct.join(); - taskRunner.join(); - } catch (InterruptedException ie) { - fail("Joining threads failed"); - } - - assertThat(success) - .withFailMessage("Test threw an exception even though it was properly canceled.") - .isTrue(); + tct.join(); + delayCoGroupProcessingLatch.trigger(); + taskRunner.sync(); } public static class MockFailingCoGroupStub extends RichCoGroupFunction { @@ -488,24 +393,17 @@ public static final class MockDelayingCoGroupStub extends RichCoGroupFunction { private static final long serialVersionUID = 1L; - @SuppressWarnings("unused") - @Override - public void coGroup( - Iterable records1, Iterable records2, Collector out) { + private final OneShotLatch delayCoGroupProcessingLatch; - for (Record r : records1) { - try { - Thread.sleep(100); - } catch (InterruptedException e) { - } - } + public MockDelayingCoGroupStub(OneShotLatch delayCoGroupProcessingLatch) { + this.delayCoGroupProcessingLatch = delayCoGroupProcessingLatch; + } - for (Record r : records2) { - try { - Thread.sleep(100); - } catch (InterruptedException e) { - } - } + @Override + public void coGroup( + Iterable records1, Iterable records2, Collector out) + throws InterruptedException { + delayCoGroupProcessingLatch.await(); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java index 2b7bd639f6dd2..5a90152e2c93f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java @@ -21,7 +21,6 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.GroupCombineFunction; import org.apache.flink.api.common.functions.GroupReduceFunction; -import org.apache.flink.api.common.functions.RichGroupReduceFunction; import org.apache.flink.runtime.operators.testutils.DriverTestBase; import org.apache.flink.runtime.operators.testutils.UniformRecordGenerator; import org.apache.flink.runtime.testutils.recordutils.RecordComparator; @@ -38,7 +37,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; -class CombineTaskExternalITCase extends DriverTestBase> { +class CombineTaskExternalITCase extends DriverTestBase> { private static final long COMBINE_MEM = 3 * 1024 * 1024; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskExternalITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskExternalITCase.java index 4928e34718eeb..3f3fc78319b87 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskExternalITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskExternalITCase.java @@ -44,7 +44,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; -class ReduceTaskExternalITCase extends DriverTestBase> { +class ReduceTaskExternalITCase extends DriverTestBase> { private static final Logger LOG = LoggerFactory.getLogger(ReduceTaskExternalITCase.class); @SuppressWarnings("unchecked") diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskTest.java index 12e6e05ec13c7..7611bedf6789f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskTest.java @@ -50,7 +50,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.fail; -public class ReduceTaskTest extends DriverTestBase> { +public class ReduceTaskTest extends DriverTestBase> { private static final Logger LOG = LoggerFactory.getLogger(ReduceTaskTest.class); @SuppressWarnings("unchecked") diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DriverTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DriverTestBase.java index dd08a08eeaf6d..91f52c6b46f7b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DriverTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DriverTestBase.java @@ -190,18 +190,20 @@ protected void setNumFileHandlesForSort(int numFileHandles) { this.numFileHandles = numFileHandles; } - @SuppressWarnings("rawtypes") - protected void testDriver(Driver driver, Class stubClass) throws Exception { - testDriverInternal(driver, stubClass); + /** @deprecated Use {@link #testDriver(Driver, Function)} instead. */ + @Deprecated + @SuppressWarnings({"rawtypes"}) + protected void testDriver(Driver driver, Class stubClass) throws Exception { + testDriver(driver, stubClass.getDeclaredConstructor().newInstance()); } @SuppressWarnings({"unchecked", "rawtypes"}) - protected void testDriverInternal(Driver driver, Class stubClass) throws Exception { + protected void testDriver(Driver driver, S stub) throws Exception { this.driver = driver; driver.setup(this); - this.stub = (S) stubClass.newInstance(); + this.stub = stub; // regular running logic boolean stubOpen = false;