From 28b04e246439c745ea3713bec5233c092ed8d20a Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 6 Feb 2025 15:06:59 -0800 Subject: [PATCH 1/2] Fixes #3718 --- csrc/scheduler/registry.cpp | 6 ++++++ csrc/scheduler/registry_utils.cpp | 21 +++++++++++++++++++ csrc/scheduler/registry_utils.h | 2 ++ csrc/scheduler/resize.cpp | 2 +- csrc/scheduler/tools/resize_utils.cpp | 8 +++++++ csrc/scheduler/tools/resize_utils.h | 4 ++++ tests/cpp/test_resize.cpp | 30 +++++++++++++++++++++++++++ 7 files changed, 72 insertions(+), 1 deletion(-) diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 039d94b38af..3939577e0f8 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -61,6 +61,12 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) { return false; } + if (registry_utils::SchedulerTopologyChecker::hasResizeAndIndexOps(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + scheduler_type, "has resize-based ops and index ops"); + return false; + } + return true; } diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index d95d7f0942e..3d40ea3c7ba 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include namespace nvfuser { @@ -1008,6 +1009,26 @@ bool SchedulerTopologyChecker::hasGatherToBroadcastBeforeReduction( }); } +bool SchedulerTopologyChecker::hasResizeAndIndexOps(Fusion* fusion) { + bool has_resize = false; + bool has_index_op = false; + + for (auto expr : fusion->exprs()) { + if (scheduler_tools::isResizeBasedOp(expr)) { + has_resize = true; + } else if ( + expr->isOneOf()) { + has_index_op = true; + } + + if (has_resize && has_index_op) { + return true; + } + } + + return false; +} + } // namespace registry_utils } // namespace nvfuser diff --git a/csrc/scheduler/registry_utils.h b/csrc/scheduler/registry_utils.h index fbcfa08b399..48e1462fcb2 100644 --- 
a/csrc/scheduler/registry_utils.h +++ b/csrc/scheduler/registry_utils.h @@ -104,6 +104,8 @@ class SchedulerTopologyChecker { static bool hasGatherToBroadcastBeforeReduction( Fusion* fusion, const std::vector& reduction_tvs); + + static bool hasResizeAndIndexOps(Fusion* fusion); }; } // namespace registry_utils diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index ef400301487..1c198aee8dc 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -68,7 +68,7 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) { return false; } - if (!ir_utils::hasOpsOfType(fusion)) { + if (!scheduler_tools::hasResizeBasedOps(fusion)) { scheduler_debug_utils::canScheduleRejectReason( schedulerType(), "No resize op to schedule"); return false; diff --git a/csrc/scheduler/tools/resize_utils.cpp b/csrc/scheduler/tools/resize_utils.cpp index ddecf6bcb13..f281b40fe79 100644 --- a/csrc/scheduler/tools/resize_utils.cpp +++ b/csrc/scheduler/tools/resize_utils.cpp @@ -18,6 +18,14 @@ namespace nvfuser { namespace scheduler_tools { +bool isResizeBasedOp(Expr* expr) { + return expr->isOneOf(); +} + +bool hasResizeBasedOps(Fusion* fusion) { + return ir_utils::hasOpsOfType(fusion); +} + void propagateResizeToInputs(Expr* resize_tensor_op) { NVF_ERROR( resize_tensor_op->isA() || resize_tensor_op->isA(), diff --git a/csrc/scheduler/tools/resize_utils.h b/csrc/scheduler/tools/resize_utils.h index b9afed5effa..99e03153a37 100644 --- a/csrc/scheduler/tools/resize_utils.h +++ b/csrc/scheduler/tools/resize_utils.h @@ -16,6 +16,10 @@ class TensorView; namespace scheduler_tools { +bool isResizeBasedOp(Expr* expr); + +bool hasResizeBasedOps(Fusion* fusion); + // For a given resize-based tensor op such as SliceOp and PadOp, make the loop // domain of each dependent producer tensor exact-mapped by propagating // the iter-domain ops of the output tensor of the given op. 
Note that diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 38acf56950f..345dd815851 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -5754,4 +5754,34 @@ TEST_F(ResizeTest, PadAndCacheUses) { // // TODO: check vectorization factor // } +// Mixing resize and index ops is not supported yet. Specifically, +// resize requires TensorIndexer, which is based on IdModel, but index +// ops like take_along_axis are not yet supported by IdModel. +TEST_F(ResizeTest, DoNotFuseResizeAndIndexOps) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({128, 4095}); + fusion.addInput(tv0); + auto tv1 = makeContigConcreteTensor({1, 4096}, DataType::Int); + fusion.addInput(tv1); + auto tv2 = slice( + tv1, + {{IrBuilder::create(0L), IrBuilder::create(1L)}, + {IrBuilder::create(1L), IrBuilder::create(4096)}}); + auto tv3 = takeAlongAxis(tv0, tv2, 0); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + auto t0 = at::randn({128, 4095}, options); + auto t1 = at::randint(0, 128, {1, 4096}, options_int); + std::vector inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); +} + } // namespace nvfuser From 100acadde4a4f8f735f09f8ec002c8576daf764f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 6 Feb 2025 15:22:21 -0800 Subject: [PATCH 2/2] validation --- tests/cpp/test_resize.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 345dd815851..95310674a2d 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -5782,6 +5782,27 @@ TEST_F(ResizeTest, 
DoNotFuseResizeAndIndexOps) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto outputs = executor_cache.runFusionWithInputs(inputs); testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + + EXPECT_EQ(runtime->fusionSegments()->groups().size(), 2) + << "Unexpected segmentation"; + + // Make sure two ops are separated into their own segments + for (auto segmented_group : runtime->fusionSegments()->groups()) { + bool has_resize = false; + bool has_index_op = false; + for (auto expr : segmented_group->exprs()) { + if (scheduler_tools::isResizeBasedOp(expr)) { + has_resize = true; + } else if ( + expr->isOneOf()) { + has_index_op = true; + } + } + + EXPECT_NE(has_resize, has_index_op); + } } } // namespace nvfuser