Skip to content

Commit

Permalink
DID-parallelize a loop split in Python. (#3503)
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue authored Dec 10, 2024
1 parent bfd2a6a commit 98352c4
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 34 deletions.
19 changes: 19 additions & 0 deletions csrc/multidevice/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,25 @@ int64_t getShardedLogicalAxis(
return logical_id_to_axis.at(id);
}

at::Tensor shardTensor(
at::Tensor tensor,
const int64_t axis,
const DeviceMesh& mesh,
const DeviceIdxType device_id) {
auto i = mesh.idxOf(device_id);
auto extent = tensor.size(axis);
auto nslices = mesh.size();
NVF_CHECK(
extent % nslices == 0, "Sharded axis must be evenly divisble by mesh");
auto stride = extent / nslices;
// TODO: returning slice 0 temporarily when device is not in the mesh.
i = (i < 0) ? 0 : i;
// The following slicing is problematic when DID is on an inner split (cf.
// MultiDeviceTest.ShardTensor_InnerSplit). We currently disallow that and
// it's enforced by getShardedLogicalAxis.
return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
}

std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes) {
Expand Down
8 changes: 8 additions & 0 deletions csrc/multidevice/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ void unshard(TensorView*);
// extent if that IterDomain is sharded.
int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type);

// Shards the input tensor along `axis`. How the tensor gets sliced along `axis`
// is determined by `mesh` and `device_id`. Returns the sharded tensor.
at::Tensor shardTensor(
at::Tensor tensor,
int64_t axis,
const DeviceMesh& mesh,
DeviceIdxType device_id);

// Reorders a TensorView so that the DID parallelized axis are in front.
void reorderDIDToFront(TensorView*);

Expand Down
11 changes: 11 additions & 0 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,17 @@ void FusionDefinition::finalizeSchedule(
// Users can access schedule objects after scheduling the fusion.
}

void FusionDefinition::setupMultideviceSchedule() {
// FusionDefinition.multidevice_schedule may create new Exprs (e.g. DID
// splits), which will be added to the presched fusion.
prev_fusion_ = FusionGuard::getCurFusion();
FusionGuard::setCurFusion(preschedFusion());
}

void FusionDefinition::finalizeMultideviceSchedule() {
FusionGuard::setCurFusion(prev_fusion_);
}

void FusionDefinition::print(std::ostream& os) const {
if (id().has_value()) {
os << "\ndef nvfuser_fusion_id" << id().value();
Expand Down
5 changes: 5 additions & 0 deletions csrc/python_frontend/fusion_definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ class NVF_API FusionDefinition : public FusionState {
//! Finalized use scheduling of a fusion
//! resets FusionGuard, lowers IR to a kernel, compiles kernel
NVF_API void finalizeSchedule(const at::ArrayRef<c10::IValue>& inputs);
//! A hook that gets called right before
//! FusionDefinition.multidevice_schedule.
NVF_API void setupMultideviceSchedule();
//! A hook that gets called right after FusionDefinition.multidevice_schedule.
NVF_API void finalizeMultideviceSchedule();
//! Prints a python function representing the definition
NVF_API void print(std::ostream& os) const;
//! Executes a fusion if a valid definition or cache lookup occurred prior
Expand Down
38 changes: 30 additions & 8 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1013,11 +1013,23 @@ void initNvFuserPythonBindings(PyObject* module) {
scalar_class.def(pybind11::self != pybind11::self);

py::class_<DeviceMesh> device_mesh_class(nvfuser, "DeviceMesh");
device_mesh_class.def(py::init<std::vector<int64_t>>());
device_mesh_class.def("__repr__", [](const DeviceMesh& self) {
std::stringstream ss;
ss << self;
return ss.str();
});
device_mesh_class.def(
"shard_tensor",
[](const DeviceMesh& self,
at::Tensor tensor,
const int64_t axis,
int64_t device_id) -> at::Tensor {
return shardTensor(tensor, axis, self, device_id);
},
py::arg("tensor"),
py::arg("axis"),
py::arg("device_id"));

py::class_<Vector> vector_class(nvfuser, "Vector");
vector_class.def("__repr__", [](Vector& self) {
Expand Down Expand Up @@ -1091,6 +1103,12 @@ void initNvFuserPythonBindings(PyObject* module) {
// Mark the end of a schedule
inst::Trace::instance()->endEvent(nullptr);
})
.def(
"_setup_multidevice_schedule",
[](FusionDefinition& self) { self.setupMultideviceSchedule(); })
.def(
"_finalize_multidevice_schedule",
[](FusionDefinition& self) { self.finalizeMultideviceSchedule(); })
.def("inputs", [](FusionDefinition& self) { return self.inputs(); })
.def("outputs", [](FusionDefinition& self) { return self.outputs(); })
.def("extents", [](FusionDefinition& self) { return self.extents(); })
Expand Down Expand Up @@ -3575,13 +3593,6 @@ void initNvFuserPythonBindings(PyObject* module) {
},
py::return_value_policy::reference);
//! experimental API for multidevice support
nvf_sched.def(
"_create_device_mesh",
[](FusionDefinition::SchedOperators& self,
const std::vector<int64_t>& devices) { return DeviceMesh(devices); },
py::arg("devices"),
py::return_value_policy::reference);
//! experimental API for multidevice support
nvf_sched.def(
"_set_device_mesh",
[](FusionDefinition::SchedOperators& self,
Expand All @@ -3596,7 +3607,6 @@ void initNvFuserPythonBindings(PyObject* module) {
},
py::arg("tensor"),
py::arg("mesh"));
//! experimental API for multidevice support
nvf_sched.def(
"parallelize",
[](FusionDefinition::SchedOperators& self,
Expand Down Expand Up @@ -3683,6 +3693,18 @@ void initNvFuserPythonBindings(PyObject* module) {
py::arg("dim"),
py::arg("factor"),
py::arg("inner_split") = true);
nvf_sched.def(
"set_allocation_as_loop",
[](FusionDefinition::SchedOperators& self, Tensor arg) {
FUSER_PERF_SCOPE("SchedOperators.set_allocation_as_loop");
NVF_CHECK(
self.validUse(),
"Attempting to use a SchedOperators Op prior to definition!");
FusionDefinition* fd = self.fusion_definition;
auto* tv = fd->getFusionState(arg.index)->template as<TensorView>();
tv->setAllocationDomain(tv->getLoopDomain(), true);
},
py::arg("arg"));
nvf_sched.def(
"cache_after",
[](FusionDefinition::SchedOperators& self,
Expand Down
2 changes: 2 additions & 0 deletions nvfuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ def execute(
#
# Note: there's a plan to embed multidevice schedules into FusionDefinition
# as annotating nodes. This may eventually replace `multidevice_schedule`.
self._setup_multidevice_schedule()
self.multidevice_schedule()
self._finalize_multidevice_schedule()

# If schedule is defined by child class and schedule is not defined for
# inputs, make a schedule.
Expand Down
15 changes: 2 additions & 13 deletions tests/cpp/multidevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,21 +136,10 @@ at::Tensor MultiDeviceTest::shardTensor(at::Tensor tensor, TensorView* tv) {

at::Tensor MultiDeviceTest::shardTensor(
at::Tensor tensor,
int64_t axis,
const int64_t axis,
const DeviceMesh& mesh) {
const auto device_id = communicator_->deviceId();
auto i = mesh.idxOf(device_id);
auto extent = tensor.size(axis);
auto nslices = mesh.size();
NVF_CHECK(
extent % nslices == 0, "Sharded axis must be evenly divisble by mesh");
auto stride = extent / nslices;
// TODO: returning slice 0 temporarily when device is not in the mesh.
i = (i < 0) ? 0 : i;
// The following slicing is problematic when DID is on an inner split (cf.
// MultiDeviceTest.ShardTensor_InnerSplit). We currently disallow that and
// it's enforced by getShardedLogicalAxis.
return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
return nvfuser::shardTensor(tensor, axis, mesh, device_id);
}

} // namespace nvfuser
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/multidevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class MultiDeviceTest : public NVFuserTest {
// tensor. Currently, we don't support this, so for now it returns a slice.
at::Tensor shardTensor(at::Tensor tensor, TensorView* tv);

// A lower-level helper that doesn't require a TensorView.
at::Tensor shardTensor(
at::Tensor tensor,
int64_t axis,
Expand Down
11 changes: 11 additions & 0 deletions tests/python/mpi_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import pytest
import torch
import nvfuser

from mpi4py import MPI

Expand Down Expand Up @@ -38,6 +39,16 @@ def local_rank(self):
def barrier(self):
self._communicator.barrier()

def shard_tensor(
self, t: torch.Tensor, dim: int, mesh: nvfuser.DeviceMesh
) -> torch.Tensor:
assert t.is_cpu, (
"This is not strictly required but it's a general good practice "
"for unit tests to create unsharded data on CPU to reduce GPU "
"memory footprint."
)
return mesh.shard_tensor(t, dim, self.rank).cuda(self.local_rank)


@pytest.fixture(scope="session")
def mpi_test():
Expand Down
45 changes: 45 additions & 0 deletions tests/python/test_communication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import pytest
import torch

import mpi_fixtures
import nvfuser
from nvfuser import DataType, FusionDefinition


mpi_test = mpi_fixtures.mpi_test


@pytest.mark.mpi
def test_allgather(mpi_test):
num_devices = mpi_test.size
mesh = nvfuser.DeviceMesh(range(num_devices))

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor(
(num_devices * 4,), contiguity=True, dtype=DataType.Float
)
self.out = self.ops.set(self.inp)
self.add_output(self.out)

def multidevice_schedule(self):
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

self.sched.split(self.inp, 0, num_devices, False)
self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
self.sched.set_allocation_as_loop(self.inp)

self.sched.split(self.out, 0, num_devices, False)
self.sched.set_allocation_as_loop(self.out)

unsharded = torch.randn(num_devices * 4)
sharded = mpi_test.shard_tensor(unsharded, 0, mesh)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(outputs[0].cpu(), unsharded)
23 changes: 10 additions & 13 deletions tests/python/test_multidevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,7 @@ def test_sizes_and_ranks(mpi_test):
@pytest.mark.mpi
def test_pointwise(mpi_test):
num_devices = mpi_test.size
rank = mpi_test.rank

torch.cuda.set_device(mpi_test.local_rank)

unsharded_input = torch.randn(num_devices, 4, device="cuda")
sharded_input = unsharded_input[rank : rank + 1]
mesh = nvfuser.DeviceMesh(range(num_devices))

class Model(FusionDefinition):
def definition(self):
Expand All @@ -50,15 +45,17 @@ def definition(self):
self.add_output(self.t2)

def multidevice_schedule(self):
mesh = self.sched._create_device_mesh(range(num_devices))
self.sched._set_device_mesh(self.t0, mesh)
self.sched._set_device_mesh(self.t1, mesh)
self.sched._set_device_mesh(self.t2, mesh)
self.sched.parallelize(self.t0, 0, nvfuser.ParallelType.mesh_x)

unsharded_input = torch.randn(num_devices, 4)
sharded_input = mpi_test.shard_tensor(unsharded_input, 0, mesh)

fd = Model()
outputs = fd.execute([sharded_input])
torch.testing.assert_close(outputs[0], unsharded_input.relu() * 2)
torch.testing.assert_close(outputs[0].cpu(), unsharded_input.relu() * 2)


@pytest.mark.mpi
Expand All @@ -80,7 +77,7 @@ def definition(self):
self.add_output(out)

def multidevice_schedule(self):
mesh = self.sched._create_device_mesh(range(self._num_devices))
mesh = nvfuser.DeviceMesh(range(self._num_devices))
for t in [self.inp, self.weight, self.bias]:
self.sched._set_device_mesh(t, mesh)
for t in [self.weight, self.bias]:
Expand Down Expand Up @@ -135,7 +132,7 @@ def definition(self) -> None:
self.add_output(in_grad)

def multidevice_schedule(self) -> None:
mesh = self.sched._create_device_mesh(range(d))
mesh = nvfuser.DeviceMesh(range(d))
for t in [self.out_grad, self.weight]:
self.sched._set_device_mesh(t, mesh)
self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x)
Expand Down Expand Up @@ -230,7 +227,7 @@ def definition(self) -> None:
self.add_output(grad)

def multidevice_schedule(self) -> None:
mesh = self.sched._create_device_mesh(range(d))
mesh = nvfuser.DeviceMesh(range(d))
for t in [self.q, self.k, self.v, self.out_grad]:
self.sched._set_device_mesh(t, mesh)
self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x)
Expand Down Expand Up @@ -649,7 +646,7 @@ def definition(self) -> None:
self.add_output(out)

def multidevice_schedule(self):
mesh = self.sched._create_device_mesh(range(self._num_devices))
mesh = nvfuser.DeviceMesh(range(self._num_devices))
# Assign the mesh to inputs and weights. nvFuser will propagate it to
# downstream tensors.
for in_tv in [
Expand Down Expand Up @@ -1241,7 +1238,7 @@ def definition(self) -> None:
self.add_output(inp_grad)

def multidevice_schedule(self):
mesh = self.sched._create_device_mesh(range(self._num_devices))
mesh = nvfuser.DeviceMesh(range(self._num_devices))
for in_tv in [
self.mlp_linear0_out,
self.out_grad,
Expand Down

0 comments on commit 98352c4

Please sign in to comment.