diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index 54f1303bc16..6fefa47d68d 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -237,6 +237,25 @@ int64_t getShardedLogicalAxis(
   return logical_id_to_axis.at(id);
 }
 
+at::Tensor shardTensor(
+    at::Tensor tensor,
+    const int64_t axis,
+    const DeviceMesh& mesh,
+    const DeviceIdxType device_id) {
+  auto i = mesh.idxOf(device_id);
+  auto extent = tensor.size(axis);
+  auto nslices = mesh.size();
+  NVF_CHECK(
+      extent % nslices == 0, "Sharded axis must be evenly divisible by mesh");
+  auto stride = extent / nslices;
+  // TODO: returning slice 0 temporarily when device is not in the mesh.
+  i = (i < 0) ? 0 : i;
+  // The following slicing is problematic when DID is on an inner split (cf.
+  // MultiDeviceTest.ShardTensor_InnerSplit). We currently disallow that and
+  // it's enforced by getShardedLogicalAxis.
+  return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
+}
+
 std::vector<int64_t> unshardedSizes(
     const TensorView* tv,
     c10::IntArrayRef sizes) {
diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h
index ef88fbdcf80..50500a7a5fc 100644
--- a/csrc/multidevice/utils.h
+++ b/csrc/multidevice/utils.h
@@ -134,6 +134,14 @@ void unshard(TensorView*);
 // extent if that IterDomain is sharded.
 int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type);
 
+// Shards the input tensor along `axis`. How the tensor gets sliced along
+// `axis` is determined by `mesh` and `device_id`. Returns the sharded tensor.
+at::Tensor shardTensor(
+    at::Tensor tensor,
+    int64_t axis,
+    const DeviceMesh& mesh,
+    DeviceIdxType device_id);
+
 // Reorders a TensorView so that the DID parallelized axis are in front.
 void reorderDIDToFront(TensorView*);
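For intuition, the slicing that the new shardTensor helper performs can be sketched in a few lines of Python. The sketch is illustrative only: the function name and the plain list standing in for a DeviceMesh are not part of the patch.

    import torch

    def shard_tensor_sketch(
        tensor: torch.Tensor, axis: int, mesh_devices: list, device_id: int
    ) -> torch.Tensor:
        # Mirrors the C++ helper: locate this device in the mesh, cut the
        # axis into mesh-size equal slices, and return this device's slice.
        extent = tensor.size(axis)
        nslices = len(mesh_devices)
        assert extent % nslices == 0, "sharded axis must divide evenly"
        stride = extent // nslices
        # Like the TODO in the C++ code, fall back to slice 0 when the
        # device is not in the mesh.
        i = mesh_devices.index(device_id) if device_id in mesh_devices else 0
        return tensor.narrow(axis, i * stride, stride).contiguous()

    # With mesh [0, 1, 2, 3] and an axis extent of 8, device 2 receives
    # elements [4, 6).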
diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp
index 9efdebc17ad..9003ec09351 100644
--- a/csrc/python_frontend/fusion_definition.cpp
+++ b/csrc/python_frontend/fusion_definition.cpp
@@ -314,6 +314,17 @@ void FusionDefinition::finalizeSchedule(
   // Users can access schedule objects after scheduling the fusion.
 }
 
+void FusionDefinition::setupMultideviceSchedule() {
+  // FusionDefinition.multidevice_schedule may create new Exprs (e.g. DID
+  // splits), which will be added to the presched fusion.
+  prev_fusion_ = FusionGuard::getCurFusion();
+  FusionGuard::setCurFusion(preschedFusion());
+}
+
+void FusionDefinition::finalizeMultideviceSchedule() {
+  FusionGuard::setCurFusion(prev_fusion_);
+}
+
 void FusionDefinition::print(std::ostream& os) const {
   if (id().has_value()) {
     os << "\ndef nvfuser_fusion_id" << id().value();
diff --git a/csrc/python_frontend/fusion_definition.h b/csrc/python_frontend/fusion_definition.h
index c359352c565..6157704f86b 100644
--- a/csrc/python_frontend/fusion_definition.h
+++ b/csrc/python_frontend/fusion_definition.h
@@ -184,6 +184,11 @@ class NVF_API FusionDefinition : public FusionState {
   //! Finalized use scheduling of a fusion
   //! resets FusionGuard, lowers IR to a kernel, compiles kernel
   NVF_API void finalizeSchedule(const at::ArrayRef<c10::IValue>& inputs);
+  //! A hook that gets called right before
+  //! FusionDefinition.multidevice_schedule.
+  NVF_API void setupMultideviceSchedule();
+  //! A hook that gets called right after FusionDefinition.multidevice_schedule.
+  NVF_API void finalizeMultideviceSchedule();
   //! Prints a python function representing the definition
   NVF_API void print(std::ostream& os) const;
   //! Executes a fusion if a valid definition or cache lookup occurred prior
diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp
index d72115b8043..c95397c7a78 100644
--- a/csrc/python_frontend/python_bindings.cpp
+++ b/csrc/python_frontend/python_bindings.cpp
@@ -1013,11 +1013,23 @@ void initNvFuserPythonBindings(PyObject* module) {
   scalar_class.def(pybind11::self != pybind11::self);
 
   py::class_<DeviceMesh> device_mesh_class(nvfuser, "DeviceMesh");
+  device_mesh_class.def(py::init<std::vector<DeviceIdxType>>());
   device_mesh_class.def("__repr__", [](const DeviceMesh& self) {
     std::stringstream ss;
     ss << self;
     return ss.str();
   });
+  device_mesh_class.def(
+      "shard_tensor",
+      [](const DeviceMesh& self,
+         at::Tensor tensor,
+         const int64_t axis,
+         int64_t device_id) -> at::Tensor {
+        return shardTensor(tensor, axis, self, device_id);
+      },
+      py::arg("tensor"),
+      py::arg("axis"),
+      py::arg("device_id"));
 
   py::class_<Vector> vector_class(nvfuser, "Vector");
   vector_class.def("__repr__", [](Vector& self) {
@@ -1091,6 +1103,12 @@ void initNvFuserPythonBindings(PyObject* module) {
         // Mark the end of a schedule
         inst::Trace::instance()->endEvent(nullptr);
       })
+      .def(
+          "_setup_multidevice_schedule",
+          [](FusionDefinition& self) { self.setupMultideviceSchedule(); })
+      .def(
+          "_finalize_multidevice_schedule",
+          [](FusionDefinition& self) { self.finalizeMultideviceSchedule(); })
      .def("inputs", [](FusionDefinition& self) { return self.inputs(); })
      .def("outputs", [](FusionDefinition& self) { return self.outputs(); })
      .def("extents", [](FusionDefinition& self) { return self.extents(); })
@@ -3575,13 +3593,6 @@ void initNvFuserPythonBindings(PyObject* module) {
       },
       py::return_value_policy::reference);
   //! experimental API for multidevice support
-  nvf_sched.def(
-      "_create_device_mesh",
-      [](FusionDefinition::SchedOperators& self,
-         const std::vector<DeviceIdxType>& devices) { return DeviceMesh(devices); },
-      py::arg("devices"),
-      py::return_value_policy::reference);
-  //! experimental API for multidevice support
   nvf_sched.def(
       "_set_device_mesh",
       [](FusionDefinition::SchedOperators& self,
@@ -3596,7 +3607,6 @@ void initNvFuserPythonBindings(PyObject* module) {
       },
       py::arg("tensor"),
       py::arg("mesh"));
-  //! experimental API for multidevice support
   nvf_sched.def(
       "parallelize",
       [](FusionDefinition::SchedOperators& self,
@@ -3683,6 +3693,18 @@ void initNvFuserPythonBindings(PyObject* module) {
       py::arg("dim"),
       py::arg("factor"),
       py::arg("inner_split") = true);
+  nvf_sched.def(
+      "set_allocation_as_loop",
+      [](FusionDefinition::SchedOperators& self, Tensor arg) {
+        FUSER_PERF_SCOPE("SchedOperators.set_allocation_as_loop");
+        NVF_CHECK(
+            self.validUse(),
+            "Attempting to use a SchedOperators Op prior to definition!");
+        FusionDefinition* fd = self.fusion_definition;
+        auto* tv = fd->getFusionState(arg.index)->template as<TensorView>();
+        tv->setAllocationDomain(tv->getLoopDomain(), true);
+      },
+      py::arg("arg"));
   nvf_sched.def(
       "cache_after",
       [](FusionDefinition::SchedOperators& self,
diff --git a/nvfuser/__init__.py b/nvfuser/__init__.py
index 967be681b7f..d65198c21a7 100644
--- a/nvfuser/__init__.py
+++ b/nvfuser/__init__.py
@@ -284,7 +284,9 @@ def execute(
             #
             # Note: there's a plan to embed multidevice schedules into FusionDefinition
             # as annotating nodes. This may eventually replace `multidevice_schedule`.
+            self._setup_multidevice_schedule()
             self.multidevice_schedule()
+            self._finalize_multidevice_schedule()
 
         # If schedule is defined by child class and schedule is not defined for
         # inputs, make a schedule.
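Taken together, the bindings above let Python construct a mesh and shard a tensor directly, while the private hooks bracket `multidevice_schedule` during `execute`. A minimal usage sketch of the new surface (device indices and shapes are illustrative, not from the patch):

    import torch
    import nvfuser

    mesh = nvfuser.DeviceMesh([0, 1])  # the new py::init binding
    t = torch.randn(4, 8)              # unsharded tensor, kept on CPU
    # Device 1's slice along axis 0, i.e. rows [2, 4) of `t`.
    shard = mesh.shard_tensor(t, axis=0, device_id=1)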
diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp
index 22897dc5311..e67b507b5f2 100644
--- a/tests/cpp/multidevice.cpp
+++ b/tests/cpp/multidevice.cpp
@@ -136,21 +136,10 @@ at::Tensor MultiDeviceTest::shardTensor(at::Tensor tensor, TensorView* tv) {
 
 at::Tensor MultiDeviceTest::shardTensor(
     at::Tensor tensor,
-    int64_t axis,
+    const int64_t axis,
     const DeviceMesh& mesh) {
   const auto device_id = communicator_->deviceId();
-  auto i = mesh.idxOf(device_id);
-  auto extent = tensor.size(axis);
-  auto nslices = mesh.size();
-  NVF_CHECK(
-      extent % nslices == 0, "Sharded axis must be evenly divisble by mesh");
-  auto stride = extent / nslices;
-  // TODO: returning slice 0 temporarily when device is not in the mesh.
-  i = (i < 0) ? 0 : i;
-  // The following slicing is problematic when DID is on an inner split (cf.
-  // MultiDeviceTest.ShardTensor_InnerSplit). We currently disallow that and
-  // it's enforced by getShardedLogicalAxis.
-  return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
+  return nvfuser::shardTensor(tensor, axis, mesh, device_id);
 }
 
 } // namespace nvfuser
diff --git a/tests/cpp/multidevice.h b/tests/cpp/multidevice.h
index 24d1c323215..1831eb46bbb 100644
--- a/tests/cpp/multidevice.h
+++ b/tests/cpp/multidevice.h
@@ -33,6 +33,7 @@ class MultiDeviceTest : public NVFuserTest {
   // tensor. Currently, we don't support this, so for now it returns a slice.
   at::Tensor shardTensor(at::Tensor tensor, TensorView* tv);
 
+  // A lower-level helper that doesn't require a TensorView.
   at::Tensor shardTensor(
       at::Tensor tensor,
       int64_t axis,
diff --git a/tests/python/mpi_fixtures.py b/tests/python/mpi_fixtures.py
index d0f80f20a62..2193b29a0b9 100644
--- a/tests/python/mpi_fixtures.py
+++ b/tests/python/mpi_fixtures.py
@@ -5,6 +5,7 @@
 import os
 import pytest
 import torch
+import nvfuser
 
 from mpi4py import MPI
 
@@ -38,6 +39,16 @@ def local_rank(self):
     def barrier(self):
         self._communicator.barrier()
 
+    def shard_tensor(
+        self, t: torch.Tensor, dim: int, mesh: nvfuser.DeviceMesh
+    ) -> torch.Tensor:
+        assert t.is_cpu, (
+            "This is not strictly required but it's a general good practice "
+            "for unit tests to create unsharded data on CPU to reduce GPU "
+            "memory footprint."
+        )
+        return mesh.shard_tensor(t, dim, self.rank).cuda(self.local_rank)
+
 
 @pytest.fixture(scope="session")
 def mpi_test():
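The new allgather test below relies on an outer split: `split(tv, 0, num_devices, False)` (the final `False` selecting an outer rather than inner split) reshapes an extent-`num_devices * 4` axis into `[num_devices, 4]`, and only the input's outer dimension is parallelized on the mesh. A plain-PyTorch sketch of that shape reasoning, with a hypothetical `d` standing in for the device count:

    import torch

    d = 2                        # hypothetical device count
    t = torch.arange(d * 4)      # an extent-(d * 4) axis
    outer_split = t.view(d, 4)   # outer split: [d, 4], outer extent d
    # Rank r's input shard is outer_split[r]; the unparallelized output
    # keeps the whole axis, which forces an allgather.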
diff --git a/tests/python/test_communication.py b/tests/python/test_communication.py
new file mode 100644
index 00000000000..d0cea846669
--- /dev/null
+++ b/tests/python/test_communication.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+import pytest
+import torch
+
+import mpi_fixtures
+import nvfuser
+from nvfuser import DataType, FusionDefinition
+
+
+mpi_test = mpi_fixtures.mpi_test
+
+
+@pytest.mark.mpi
+def test_allgather(mpi_test):
+    num_devices = mpi_test.size
+    mesh = nvfuser.DeviceMesh(range(num_devices))
+
+    class Model(FusionDefinition):
+        def definition(self):
+            self.inp = self.define_tensor(
+                (num_devices * 4,), contiguity=True, dtype=DataType.Float
+            )
+            self.out = self.ops.set(self.inp)
+            self.add_output(self.out)
+
+        def multidevice_schedule(self):
+            self.sched._set_device_mesh(self.inp, mesh)
+            self.sched._set_device_mesh(self.out, mesh)
+
+            self.sched.split(self.inp, 0, num_devices, False)
+            self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
+            self.sched.set_allocation_as_loop(self.inp)
+
+            self.sched.split(self.out, 0, num_devices, False)
+            self.sched.set_allocation_as_loop(self.out)
+
+    unsharded = torch.randn(num_devices * 4)
+    sharded = mpi_test.shard_tensor(unsharded, 0, mesh)
+
+    fd = Model()
+    outputs = fd.execute([sharded])
+    torch.testing.assert_close(outputs[0].cpu(), unsharded)
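Why every rank expects the full tensor back: the input is sharded along the outer dimension while the output is left unparallelized, so the runtime must gather all shards. The check at the end of the test is equivalent to this MPI-free sketch (with hypothetical `d` and `n` for the device count and shard size):

    import torch

    d, n = 2, 4
    unsharded = torch.randn(d * n)
    shards = [unsharded[r * n : (r + 1) * n] for r in range(d)]  # rank r's input
    gathered = torch.cat(shards)   # every rank's output after the allgather
    assert torch.equal(gathered, unsharded)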
diff --git a/tests/python/test_multidevice.py b/tests/python/test_multidevice.py
index a54353e38a5..953561d9dd2 100644
--- a/tests/python/test_multidevice.py
+++ b/tests/python/test_multidevice.py
@@ -33,12 +33,7 @@ def test_sizes_and_ranks(mpi_test):
 
 @pytest.mark.mpi
 def test_pointwise(mpi_test):
     num_devices = mpi_test.size
-    rank = mpi_test.rank
-
-    torch.cuda.set_device(mpi_test.local_rank)
-
-    unsharded_input = torch.randn(num_devices, 4, device="cuda")
-    sharded_input = unsharded_input[rank : rank + 1]
+    mesh = nvfuser.DeviceMesh(range(num_devices))
 
     class Model(FusionDefinition):
         def definition(self):
@@ -50,15 +45,17 @@ def definition(self):
             self.add_output(self.t2)
 
         def multidevice_schedule(self):
-            mesh = self.sched._create_device_mesh(range(num_devices))
             self.sched._set_device_mesh(self.t0, mesh)
             self.sched._set_device_mesh(self.t1, mesh)
             self.sched._set_device_mesh(self.t2, mesh)
             self.sched.parallelize(self.t0, 0, nvfuser.ParallelType.mesh_x)
 
+    unsharded_input = torch.randn(num_devices, 4)
+    sharded_input = mpi_test.shard_tensor(unsharded_input, 0, mesh)
+
     fd = Model()
     outputs = fd.execute([sharded_input])
-    torch.testing.assert_close(outputs[0], unsharded_input.relu() * 2)
+    torch.testing.assert_close(outputs[0].cpu(), unsharded_input.relu() * 2)
 
 
 @pytest.mark.mpi
@@ -80,7 +77,7 @@ def definition(self):
             self.add_output(out)
 
         def multidevice_schedule(self):
-            mesh = self.sched._create_device_mesh(range(self._num_devices))
+            mesh = nvfuser.DeviceMesh(range(self._num_devices))
             for t in [self.inp, self.weight, self.bias]:
                 self.sched._set_device_mesh(t, mesh)
             for t in [self.weight, self.bias]:
@@ -135,7 +132,7 @@ def definition(self) -> None:
             self.add_output(in_grad)
 
         def multidevice_schedule(self) -> None:
-            mesh = self.sched._create_device_mesh(range(d))
+            mesh = nvfuser.DeviceMesh(range(d))
             for t in [self.out_grad, self.weight]:
                 self.sched._set_device_mesh(t, mesh)
                 self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x)
@@ -230,7 +227,7 @@ def definition(self) -> None:
             self.add_output(grad)
 
         def multidevice_schedule(self) -> None:
-            mesh = self.sched._create_device_mesh(range(d))
+            mesh = nvfuser.DeviceMesh(range(d))
             for t in [self.q, self.k, self.v, self.out_grad]:
                 self.sched._set_device_mesh(t, mesh)
                 self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x)
@@ -649,7 +646,7 @@ def definition(self) -> None:
             self.add_output(out)
 
         def multidevice_schedule(self):
-            mesh = self.sched._create_device_mesh(range(self._num_devices))
+            mesh = nvfuser.DeviceMesh(range(self._num_devices))
             # Assign the mesh to inputs and weights. nvFuser will propagate it to
             # downstream tensors.
             for in_tv in [
@@ -1241,7 +1238,7 @@ def definition(self) -> None:
             self.add_output(inp_grad)
 
         def multidevice_schedule(self):
-            mesh = self.sched._create_device_mesh(range(self._num_devices))
+            mesh = nvfuser.DeviceMesh(range(self._num_devices))
             for in_tv in [
                 self.mlp_linear0_out,
                 self.out_grad,