Host irs: Allocate (#3524)
# What
- Add `Allocate` node support to `HostIrEvaluator`.
- Update `MultiDeviceExecutor`'s lowering and execution to use this host node. This greatly simplifies the implementation and resolves several longstanding TODOs. It also means host allocation is exercised by every multi-device test that uses `MultiDeviceExecutor`; a condensed sketch of the new node in action follows below.
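
A minimal sketch of the new capability, condensed from the `AllocationTest.HostIr` test added in this commit (nvFuser includes and namespace qualifiers elided):

```cpp
// Build a host program whose only top-level expression is an Allocate,
// then evaluate it: the evaluator materializes an at::Tensor for the
// buffer and binds it to the TensorView.
auto hic = std::make_unique<HostIrContainer>();
FusionGuard fg(hic.get());

TensorView* tv = makeConcreteTensor({8, 64});
tv->setMemoryType(MemoryType::Global);
auto* allocate = IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
hic->addOutput(tv);
hic->pushBackTopLevelExprs(allocate);

HostIrEvaluator hie(std::move(hic));
auto outputs = hie.runWithInput({}); // outputs.at(0) is a freshly allocated 8x64 tensor
```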


# Example
Running
```
mpirun -x NVFUSER_DUMP=host_ir -np 8 test_multidevice --gtest_filter=Gather/PipelineTestTwoStages.Communication/87
```
prints
```
%HostIrContainer { (T0_g_float[iS0{3}, ideviceIdx.x1{3}, iS2{3}, iS3{5}] (DeviceMesh{0 2 3})) -> (T3_g_float[iS11{3}, iS12{3}, iS13{3}] (DeviceMesh{0 2 3})) :
  PostOnStream (HostUnit0, Inputs:{T0_g_float[iS0{3}, ideviceIdx.x1{3}, iS2{3}, iS3{5}] (DeviceMesh{0 2 3}), }, Outputs:{T4_g_float[ideviceIdx.x15{3}, iS14{3}, iS16{3}] (DeviceMesh{0 2 3}), })
  T5_g_float[iS17{3}, iS18{3}, iS19{3}] (DeviceMesh{0 2 3}) = ALLOCATE(buffer=T5_g_float[iS17{3}, iS18{3}, iS19{3}] (DeviceMesh{0 2 3}), mem_type=global, size=27, zero_init=false, resets_to_zero=false)
  Communication 3 (type=Allgather, team=(0 2 3), input=T4_g_float[ideviceIdx.x15{3}, iS14{3}, iS16{3}] (DeviceMesh{0 2 3}), output=T5_g_float[iS17{3}, iS18{3}, iS19{3}] (DeviceMesh{0 2 3}))
  Wait Communication 3
  PostOnStream (HostUnit26, Inputs:{T5_g_float[iS17{3}, iS18{3}, iS19{3}] (DeviceMesh{0 2 3}), }, Outputs:{T3_g_float[iS11{3}, iS12{3}, iS13{3}] (DeviceMesh{0 2 3}), })

HostUnit26: Inputs={T5_g_float[iS17{3}, iS18{3}, iS19{3}] (DeviceMesh{0 2 3}), } -> Outputs={T3_g_float[iS11{3}, iS12{3}, iS13{3}] (DeviceMesh{0 2 3}), }
%kernel {
T6_l_float[iS21{3}, iS20{3}, iS22{3}] (DeviceMesh{0 2 3})
   = Set.Permute( T5_g_float[iS17{3}, iS18{3}, iS19{3}] (DeviceMesh{0 2 3}), cache_op=Streaming )
T3_g_float[iS11{3}, iS12{3}, iS13{3}] (DeviceMesh{0 2 3})
   = T6_l_float[iS21{3}, iS20{3}, iS22{3}] (DeviceMesh{0 2 3})
   + T6_l_float[iS21{3}, iS20{3}, iS22{3}] (DeviceMesh{0 2 3});
} // %kernel

HostUnit0: Inputs={T0_g_float[iS0{3}, ideviceIdx.x1{3}, iS2{3}, iS3{5}] (DeviceMesh{0 2 3}), } -> Outputs={T4_g_float[ideviceIdx.x15{3}, iS14{3}, iS16{3}] (DeviceMesh{0 2 3}), }
%kernel {
T1_l_float[iS4{3}, ideviceIdx.x5{3}, iS6{3}, rS7{5}] (DeviceMesh{0 2 3})
   = reduction( T0_g_float[iS0{3}, ideviceIdx.x1{3}, iS2{3}, iS3{5}] (DeviceMesh{0 2 3}), op = add, initial value = float(0), allreduce = false )
T4_g_float[ideviceIdx.x15{3}, iS14{3}, iS16{3}] (DeviceMesh{0 2 3})
   = Set.Permute( T1_l_float[iS4{3}, ideviceIdx.x5{3}, iS6{3}, rS7{5}] (DeviceMesh{0 2 3}), cache_op=Streaming )
} // %kernel
} // %HostIrContainer
```
samnordmann authored Dec 10, 2024
1 parent 829f879 commit bfd2a6a
Showing 9 changed files with 98 additions and 88 deletions.
`csrc/host_ir/executor.cpp` (16 additions, 0 deletions)

```diff
@@ -467,6 +467,22 @@ void HostIrEvaluator::handle(MatmulOp* matmul) {
   }
 }
 
+void HostIrEvaluator::handle(kir::Allocate* allocate) {
+  NVF_ERROR(
+      allocate->buffer()->isA<TensorView>(),
+      "Allocation must be on a TensorView but got ",
+      allocate->buffer());
+  TensorView* tv = allocate->buffer()->as<TensorView>();
+  GlobalBufferInfo info =
+      getBufferInfos(expr_evaluator_, PrimDataType::Int, {tv}).at(0);
+  AliasInfo alias_info = {
+      .type = AllocationType::New, .aliased_io = nullptr, .hide_output = false};
+  c10::Device device =
+      communicator_ ? communicator_->device() : at::Device("cuda:0");
+  at::Tensor tensor = allocateTensor(info, alias_info, device, expr_evaluator_);
+  expr_evaluator_.bind(tv, tensor);
+}
+
 void HostIrEvaluator::unhandled(Statement* stmt) {
   NVF_ERROR(stmt->isA<Expr>(), stmt, " must be an Expr");
   auto* expr = stmt->as<Expr>();
```
`csrc/host_ir/executor.h` (1 addition, 0 deletions)

```diff
@@ -114,6 +114,7 @@ class HostIrEvaluator final : public OptOutDispatch {
   void handle(EndCoalescing* end_coalescing) override;
   void handle(kir::IfThenElse* if_then_else) override;
   void handle(MatmulOp* matmul) override;
+  void handle(kir::Allocate* allocate) override;
   void unhandled(Statement* stmt) override;
 
   c10::cuda::CUDAStream getCUDAStream(Stream* stream);
```
`csrc/kernel_ir.cpp` (2 additions, 2 deletions)

```diff
@@ -142,8 +142,8 @@ Allocate::Allocate(
     : Expr(passkey) {
   NVF_ERROR(passkey.ir_container_ != nullptr);
   NVF_ERROR(
-      passkey.ir_container_->isA<kir::Kernel>(),
-      "IR type only valid for Kernel container.");
+      (passkey.ir_container_->isOneOf<kir::Kernel, hir::HostIrContainer>()),
+      "IR type only valid for Kernel or HostIr container.");
   if (!shape.empty()) {
     NVF_ERROR(
         (shape.size() == 1 && shape[0]->isOneInt()) ||
```
`csrc/multidevice/executor.cpp` (8 additions, 74 deletions)

```diff
@@ -27,51 +27,6 @@
 
 namespace nvfuser {
 
-namespace {
-
-// returns a copied fusion where the original outputs have been replaced by
-// the ones given as argument
-std::unique_ptr<Fusion> copyFusionAndChangeOutputs(
-    Fusion* fusion,
-    const std::vector<Val*>& outputs) {
-  std::unique_ptr<Fusion> fusion_copy = std::make_unique<Fusion>();
-  std::unordered_map<Val*, Val*> copy_to_original_map;
-  auto original_to_copy_cloner = Fusion::copy(fusion, fusion_copy.get());
-
-  auto original_outputs = fusion_copy->outputs();
-
-  // Remove original outputs
-  std::for_each(
-      original_outputs.begin(), original_outputs.end(), [&](auto& output) {
-        fusion_copy->removeOutput(output);
-      });
-
-  // Add new outputs
-  std::for_each(outputs.begin(), outputs.end(), [&](Val* const& output) {
-    fusion_copy->addOutput(original_to_copy_cloner.clone(output));
-  });
-
-  return fusion_copy;
-}
-
-// Used in distributed setting where we only want to allocate output space and
-// receive output data from a different rank instead of computing them.
-std::vector<at::Tensor> allocateOutputSpace(
-    const at::ArrayRef<c10::IValue>& inputs,
-    Fusion* fusion,
-    const c10::Device& device) {
-  FUSER_PERF_SCOPE("multidevice::executor::allocateOutputSpace");
-  auto fusion_inputs = KernelArgumentHolder::createKernelArgumentHolder(inputs);
-  auto expr_eval = executor_utils::bindInputs(fusion_inputs, fusion);
-
-  auto output_info =
-      getBufferInfos(expr_eval, PrimDataType::Int, fusion->outputs());
-
-  return allocateOutputs(fusion, output_info, device, expr_eval);
-}
-
-} // namespace
-
 MultiDeviceExecutor::MultiDeviceExecutor(
     std::unique_ptr<Fusion> fusion,
     Communicator& comm,
@@ -138,8 +93,15 @@ MultiDeviceExecutor::MultiDeviceExecutor(
       std::vector<Communication*> communications =
          lowerCommunication(ir_cloner.clone(group->exprs().at(0)));
       for (Communication* communication : communications) {
-        auto wait = IrBuilder::create<hir::Wait>(communication);
+        // Allocate the recv buffers of communications
+        TensorView* tv = communication->out();
+        if (tv->getDeviceMesh().has(comm_.deviceId())) {
+          auto* allocate =
+              IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
+          hic->pushBackTopLevelExprs(allocate);
+        }
         hic->pushBackTopLevelExprs(communication);
+        auto wait = IrBuilder::create<hir::Wait>(communication);
         hic->pushBackTopLevelExprs(wait);
       }
     } else {
@@ -160,27 +122,6 @@
   // Create the HostIrEvaluator representing the host program
   host_ir_executor_ =
       std::make_unique<hir::HostIrEvaluator>(std::move(hic), &comm, params);
-
-  // Allocator setup
-  // vals_to_allocate_ stores the tensors that need to be allocated at runtime,
-  // which correspond to the destination buffers of interdevice communications.
-  // TODO: reuse allocated buffers and support inplace collectives
-  // TODO: handle allocation as Host Ir
-  for (SegmentedGroup* group : staged_fusion->groups()) {
-    if (isResharding(group->exprs().at(0))) {
-      NVF_ERROR(group->exprs().at(0)->outputs().size() == 1);
-      auto val = group->exprs().at(0)->outputs().at(0);
-      NVF_ERROR(val->isA<TensorView>());
-      auto tv = val->as<TensorView>();
-      NVF_ERROR(tv->hasDeviceMesh());
-      if (tv->getDeviceMesh().has(comm_.deviceId())) {
-        vals_to_allocate_.push_back(val);
-      }
-    }
-  }
-  allocator_fusion_ = copyFusionAndChangeOutputs(
-      staged_fusion->completeFusion(), vals_to_allocate_);
-  vals_to_allocate_ = clone(vals_to_allocate_);
 }
 
 std::vector<at::Tensor> MultiDeviceExecutor::runWithInput(
@@ -203,13 +144,6 @@ std::vector<at::Tensor> MultiDeviceExecutor::runWithInput(
         inputs.at(input_idx);
   }
 
-  auto allocations =
-      allocateOutputSpace(inputs, allocator_fusion_.get(), comm()->device());
-  NVF_ERROR(vals_to_allocate_.size() == allocations.size());
-  for (auto i : c10::irange(allocations.size())) {
-    val_to_IValue[vals_to_allocate_.at(i)] = allocations.at(i);
-  }
-
  return host_ir_executor_->runWithInput(val_to_IValue);
 }
```
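
Net effect of this lowering change: for each lowered communication whose output tensor lives on the local rank's mesh, the host program now contains an `ALLOCATE`, the `Communication`, and a `Wait`, in that order. This matches the `ALLOCATE` → `Communication 3` → `Wait Communication 3` sequence in the example dump above. The per-communication pattern, pulled out of the constructor hunk as a standalone sketch (assuming `hic`, `communication`, and `comm_` are in scope as in the constructor):

```cpp
// Allocate the recv buffer only on ranks that own it, post the
// communication, then wait for it to complete.
TensorView* tv = communication->out();
if (tv->getDeviceMesh().has(comm_.deviceId())) {
  auto* allocate = IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
  hic->pushBackTopLevelExprs(allocate);
}
hic->pushBackTopLevelExprs(communication);
hic->pushBackTopLevelExprs(IrBuilder::create<hir::Wait>(communication));
```
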
`csrc/multidevice/executor.h` (0 additions, 6 deletions)

```diff
@@ -107,12 +107,6 @@ class MultiDeviceExecutor {
   std::unique_ptr<Fusion> complete_fusion_;
   // holds the HostIrEvaluator used for execution
   std::unique_ptr<hir::HostIrEvaluator> host_ir_executor_;
-  // Cached objects used for MultiDevice allocation
-  // TODO: remove and handle the allocation through Host Irs
-  std::unique_ptr<Fusion> allocator_fusion_;
-  // Cache the tensors that need to be allocated at runtime, which correspond to
-  // the destination buffers of interdevice communications.
-  std::vector<Val*> vals_to_allocate_;
 };
 
 } // namespace nvfuser
```
`csrc/multidevice/lower_communication.cpp` (3 additions, 0 deletions)

```diff
@@ -238,6 +238,9 @@ std::vector<Communication*> lowerCommunication(Expr* c) {
   auto* input_tv = c->input(0)->as<TensorView>();
   auto* output_tv = c->output(0)->as<TensorView>();
 
+  input_tv->setMemoryType(MemoryType::Global);
+  output_tv->setMemoryType(MemoryType::Global);
+
   const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
   const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
   const bool same_mesh = sender_mesh == receiver_mesh;
```
`csrc/runtime/allocations.cpp` (3 additions, 6 deletions)

```diff
@@ -243,14 +243,12 @@ void fillTensorWithNan(at::Tensor& t) {
   }
 }
 
-namespace {
-// Allocate an `at::Tensor` for `out_info` or compute it as an alias.
-at::Tensor allocateOutput(
+at::Tensor allocateTensor(
     const GlobalBufferInfo& out_info,
     const AliasInfo& alias_info,
     const c10::Device& device,
     ExpressionEvaluator& ee) {
-  FUSER_PERF_SCOPE("fusion_executor::allocations::allocateOutput");
+  FUSER_PERF_SCOPE("fusion_executor::allocations::allocateTensor");
   // Handle a fusion with duplicated outputs.
   TensorView* out_tv = out_info.tv;
   if (ee.isKnown(out_tv)) {
@@ -312,7 +310,6 @@ at::Tensor allocateOutput(
     NVF_THROW("Unrecognized AllocationType.");
   }
 }
-} // namespace
 
 std::vector<at::Tensor> allocateOutputs(
     const Fusion* fusion,
@@ -354,7 +351,7 @@ std::vector<at::Tensor> allocateOutputs(
 
   std::vector<at::Tensor> out_tensors(num_outs);
   for (const auto& [out_index, out] : sorted_outs) {
-    at::Tensor out_tensor = allocateOutput(
+    at::Tensor out_tensor = allocateTensor(
         output_info[out_index], fusion->getOutputAlias(out), device, ee);
     // Bind `out_tensor` so
     // 1. duplicated outputs map to the same tensor,
```
`csrc/runtime/allocations.h` (7 additions, 0 deletions)

```diff
@@ -61,6 +61,13 @@ std::pair<std::vector<int64_t>, std::vector<int64_t>> inferShapeOfOutput(
     TensorView* tv,
     ExpressionEvaluator& expr_eval);
 
+// Allocate an `at::Tensor` for `out_info` or compute it as an alias.
+at::Tensor allocateTensor(
+    const GlobalBufferInfo& out_info,
+    const AliasInfo& alias_info,
+    const c10::Device& device,
+    ExpressionEvaluator& ee);
+
 // Allocate output tensors for a given fusion. Outputs may alias inputs, in
 // that case output tensors are shallow copies of the aliased inputs
 std::vector<at::Tensor> allocateOutputs(
```
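
For reference, a sketch of how the newly exported helper is consumed, mirroring the new `HostIrEvaluator::handle(kir::Allocate*)` in `csrc/host_ir/executor.cpp` above (`tv`, `device`, and `expr_evaluator_` are assumed to be in scope):

```cpp
// Infer the buffer's sizes and strides via the expression evaluator, allocate
// a fresh non-aliasing global tensor, and bind it to the TensorView so later
// expressions (e.g. a Communication writing into it) can find it.
GlobalBufferInfo info =
    getBufferInfos(expr_evaluator_, PrimDataType::Int, {tv}).at(0);
AliasInfo alias_info = {
    .type = AllocationType::New, .aliased_io = nullptr, .hide_output = false};
at::Tensor tensor = allocateTensor(info, alias_info, device, expr_evaluator_);
expr_evaluator_.bind(tv, tensor);
```
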
`tests/cpp/test_host_irs.cpp` (58 additions, 0 deletions)

```diff
@@ -1012,6 +1012,64 @@ TEST_F(IfThenElseTest, HostIr) {
   }
 }
 
+using AllocationTest = NVFuserTest;
+
+TEST_F(AllocationTest, HostIr) {
+  const std::vector<int64_t> sizes = {8, 64};
+
+  auto hic = std::make_unique<HostIrContainer>();
+  FusionGuard fg(hic.get());
+
+  auto* tv = makeConcreteTensor(sizes);
+  tv->setMemoryType(MemoryType::Global);
+  auto* allocate = IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
+  hic->addOutput(tv);
+  hic->pushBackTopLevelExprs(allocate);
+
+  HostIrEvaluator hie(std::move(hic));
+
+  auto outputs = hie.runWithInput({});
+
+  EXPECT_EQ(sizes, outputs.at(0).sizes());
+}
+
+TEST_F(AllocationTest, inHostForLoop) {
+  constexpr int64_t kForLoopStop = 4;
+  const std::vector<int64_t> sizes = {8, 64};
+
+  auto hic = std::make_unique<HostIrContainer>();
+  FusionGuard fg(hic.get());
+
+  auto* for_loop = IrBuilder::create<ForLoop>(
+      /*IterDomain=*/makeContigConcreteTensor({0})->axis(0), // unused
+      /*index=*/IrBuilder::create<Val>(DataType::Index),
+      /*start=*/hic->zeroVal(),
+      /*stop=*/IrBuilder::create<Val>(kForLoopStop, DataType::Index),
+      /*step=*/hic->oneVal(),
+      /*vectorize=*/false,
+      /*vectorize_shift=*/nullptr,
+      /*unroll_required=*/false,
+      CircularBufferLoopStage::NotApplicable,
+      /*circular_buffer_loop_stage_depth=*/0);
+
+  TensorView* tv0 = makeConcreteTensor(sizes);
+  tv0->setMemoryType(MemoryType::Global);
+  auto* allocate = IrBuilder::create<kir::Allocate>(tv0, MemoryType::Global);
+  TensorView* tv1 = abs(tv0);
+
+  for_loop->body().push_back(allocate);
+  for_loop->body().push_back(tv1->definition());
+
+  hic->pushBackTopLevelExprs(for_loop);
+  hic->addOutput(tv1);
+
+  HostIrEvaluator hie(std::move(hic));
+
+  auto outputs = hie.runWithInput({});
+
+  EXPECT_EQ(sizes, outputs.at(0).sizes());
+}
+
 } // namespace hir
 
 } // namespace nvfuser
```
