Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not fuse resize-based ops and index ops (yet) #3845

Merged
merged 3 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions csrc/scheduler/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) {
return false;
}

if (registry_utils::SchedulerTopologyChecker::hasResizeAndIndexOps(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
scheduler_type, "has resize-based ops and index ops");
return false;
}

return true;
}

Expand Down
21 changes: 21 additions & 0 deletions csrc/scheduler/registry_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <runtime/executor_kernel_arg.h>
#include <scheduler/debug_utils.h>
#include <scheduler/registry_utils.h>
#include <scheduler/tools/resize_utils.h>
#include <scheduler/utils.h>

namespace nvfuser {
Expand Down Expand Up @@ -1008,6 +1009,26 @@ bool SchedulerTopologyChecker::hasGatherToBroadcastBeforeReduction(
});
}

// Returns true if the fusion contains at least one resize-based op
// (slice/pad) AND at least one index op (gather/scatter/index_select/select).
// Such a mix cannot currently be scheduled together, since resize requires
// the IdModel-based TensorIndexer while the index ops are not yet supported
// by IdModel.
bool SchedulerTopologyChecker::hasResizeAndIndexOps(Fusion* fusion) {
  bool found_resize = false;
  bool found_index = false;

  // Single pass over all expressions. The two categories are disjoint
  // concrete Expr types, so each expression can raise at most one flag.
  for (Expr* expr : fusion->exprs()) {
    found_resize = found_resize || scheduler_tools::isResizeBasedOp(expr);
    found_index = found_index ||
        expr->isOneOf<TorchGatherOp, ScatterOp, IndexSelectOp, SelectOp>();

    // Early exit as soon as both categories have been observed.
    if (found_resize && found_index) {
      return true;
    }
  }

  return false;
}

} // namespace registry_utils

} // namespace nvfuser
2 changes: 2 additions & 0 deletions csrc/scheduler/registry_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class SchedulerTopologyChecker {
static bool hasGatherToBroadcastBeforeReduction(
Fusion* fusion,
const std::vector<TensorView*>& reduction_tvs);

static bool hasResizeAndIndexOps(Fusion* fusion);
};

} // namespace registry_utils
Expand Down
2 changes: 1 addition & 1 deletion csrc/scheduler/resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
return false;
}

if (!ir_utils::hasOpsOfType<SliceOp, PadOp>(fusion)) {
if (!scheduler_tools::hasResizeBasedOps(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(), "No resize op to schedule");
return false;
Expand Down
8 changes: 8 additions & 0 deletions csrc/scheduler/tools/resize_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@
namespace nvfuser {
namespace scheduler_tools {

// A resize-based tensor op is currently either a slice or a pad.
bool isResizeBasedOp(Expr* expr) {
  return expr->isA<SliceOp>() || expr->isA<PadOp>();
}

// Returns true if the fusion contains any resize-based op (slice or pad).
bool hasResizeBasedOps(Fusion* fusion) {
  // Scan every expression, delegating the per-op check to isResizeBasedOp
  // so the definition of "resize-based" stays in one place.
  for (Expr* expr : fusion->exprs()) {
    if (isResizeBasedOp(expr)) {
      return true;
    }
  }
  return false;
}

void propagateResizeToInputs(Expr* resize_tensor_op) {
NVF_ERROR(
resize_tensor_op->isA<SliceOp>() || resize_tensor_op->isA<PadOp>(),
Expand Down
4 changes: 4 additions & 0 deletions csrc/scheduler/tools/resize_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class TensorView;

namespace scheduler_tools {

bool isResizeBasedOp(Expr* expr);

bool hasResizeBasedOps(Fusion* fusion);

// For a given resize-based tensor op such as SliceOp and PadOp, make the loop
// domain of each dependent producer tensor exact-mapped by propagating
// the iter-domain ops of the output tensor of the given op. Note that
Expand Down
51 changes: 51 additions & 0 deletions tests/cpp/test_resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5912,4 +5912,55 @@ TEST_F(ResizeTest, Repro3801) {
testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
}

// Mixing resize and index ops is not supported yet. Specifically,
// resize requires TensorIndexer, which is based on IdModel, but index
// ops like take_along_axis are not yet supported by IdModel.
TEST_F(ResizeTest, DoNotFuseResizeAndIndexOps) {
  auto fusion_ptr = std::make_unique<Fusion>();
  auto& fusion = *fusion_ptr;
  FusionGuard fg(fusion_ptr.get());

  // Build a fusion where a resize-based op (slice) feeds an index op
  // (take_along_axis), so the two must end up in separate segments.
  auto tv0 = makeContigConcreteTensor({128, 4095});
  fusion.addInput(tv0);
  auto tv1 = makeContigConcreteTensor({1, 4096}, DataType::Int);
  fusion.addInput(tv1);
  // Slice the index tensor from [1, 4096] down to [1, 4095] to match tv0.
  // Note: use the `L` suffix consistently so every bound is constructed
  // from int64_t (the bare `4096` previously went through int).
  auto tv2 = slice(
      tv1,
      {{IrBuilder::create<Val>(0L), IrBuilder::create<Val>(1L)},
       {IrBuilder::create<Val>(1L), IrBuilder::create<Val>(4096L)}});
  auto tv3 = takeAlongAxis(tv0, tv2, 0);
  fusion.addOutput(tv3);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
  auto t0 = at::randn({128, 4095}, options);
  // Indices must stay within tv0's first axis extent (128).
  auto t1 = at::randint(0, 128, {1, 4096}, options_int);
  std::vector<c10::IValue> inputs({t0, t1});

  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto outputs = executor_cache.runFusionWithInputs(inputs);
  testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);

  FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();

  // The resize op and the index op must not be fused, so exactly two
  // segments are expected.
  EXPECT_EQ(runtime->fusionSegments()->groups().size(), 2)
      << "Unexpected segmentation";

  // Make sure the two op categories are separated into their own segments:
  // each group must contain exactly one of the two (XOR). A group with both
  // would mean the ops were fused; a group with neither is unexpected here.
  for (auto segmented_group : runtime->fusionSegments()->groups()) {
    bool has_resize = false;
    bool has_index_op = false;
    for (auto expr : segmented_group->exprs()) {
      if (scheduler_tools::isResizeBasedOp(expr)) {
        has_resize = true;
      } else if (
          expr->isOneOf<TorchGatherOp, ScatterOp, IndexSelectOp, SelectOp>()) {
        has_index_op = true;
      }
    }

    EXPECT_NE(has_resize, has_index_op);
  }
}

} // namespace nvfuser