Fix the normalization scheduler to accept DID loop split. #3853

Open
wants to merge 1 commit into base: bug3817

Conversation

wujingyue
Collaborator

I'm sure we'll need more tests to be confident, but this incremental PR feels good!

@@ -32,16 +32,6 @@ NVF_API bool distributedEnabled() {

namespace {

std::unordered_set<IterDomain*> getShardedIterDomains(TensorView* tv) {
Collaborator Author

Not used

@wujingyue wujingyue requested a review from naoyam February 8, 2025 01:03

github-actions bot commented Feb 8, 2025

Review updated until commit 5611117

Description

  • Added function to get sharded loop axis based on parallel type.

  • Updated reduction scheduler to handle sharded axes correctly.

  • Added test case for division by sum with sharded tensors.


Changes walkthrough 📝

Relevant files

Enhancement

utils.cpp (csrc/multidevice/utils.cpp): Add sharded loop axis function
  • Removed unused function getShardedIterDomains.
  • Added function getShardedLoopAxis to find the loop axis parallelized on a given parallel type.
  • +15/-10

reduction_utils.cpp (csrc/scheduler/reduction_utils.cpp): Update reduction scheduler for sharded axes
  • Updated scheduleReductionTV to use getShardedLoopAxis for determining sharded axes.
  • Added error checks for sharded axes in 3D scheduling.
  • +11/-4

utils.h (csrc/multidevice/utils.h): Declare sharded loop axis function
  • Added declaration for getShardedLoopAxis (see the sketch after this walkthrough).
  • +4/-0

Tests

test_multidevice_sharding.cpp (tests/cpp/test_multidevice_sharding.cpp): Add test for division by sum with sharding
  • Added test case DivideBySum to verify division by sum with sharded tensors.
  • +42/-0
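
For reference, here is a sketch of what the new declaration in csrc/multidevice/utils.h plausibly looks like, inferred from the definition shown later in this review; the signature is taken from the PR, while the comment wording is an assumption.

    // Returns the position of the loop-domain axis of `tv` that is
    // parallelized on `parallel_type`, or -1 if there is no such axis.
    // (Comment wording is illustrative; signature matches the definition below.)
    int64_t getShardedLoopAxis(const TensorView* tv, const ParallelType parallel_type);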

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Assumption Validation

The assumption that the sharded axis is always the outermost domain should be validated with more test cases, especially for different mesh configurations.

    int64_t sharded_axis = getShardedLoopAxis(reduction_tv, ParallelType::DIDx);
    if (sharded_axis >= 0) {
      NVF_ERROR(
          sharded_axis == 0,
          "Expect 1D mesh and DIDx only appear outermost in loop, but found: ",
          reduction_tv->getLoopDomain());
    }
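
For context on how this precondition is met, the new test below moves the DIDx split to the front of the loop domain before the scheduler runs. A minimal sketch of that step, using only the TensorView calls that appear in the test; the helper name is illustrative, not part of the PR:

    // Illustrative helper mirroring the scheduling in the DivideBySum test below.
    void shardOutermostOnDIDx(TensorView* tv, int64_t d) {
      // Loop domain starts as [b, h, s, s].
      tv->split(1, d, /*inner_split=*/false);        // [b, DIDx{d}, h/d, s, s]
      tv->axis(1)->parallelize(ParallelType::DIDx);
      tv->reorder({{1, 0}});                         // [DIDx{d}, b, h/d, s, s]
      // getShardedLoopAxis(tv, ParallelType::DIDx) now returns 0, so the
      // check above passes.
    }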
Error Handling

The error handling in getShardedLoopAxis should be reviewed to ensure it covers all edge cases, such as when no device dimension is found.

    int64_t getShardedLoopAxis(
        const TensorView* tv,
        const ParallelType parallel_type) {
      NVF_ERROR(
          isParallelTypeDeviceDim(parallel_type),
          "Expect a DID but found: ",
          parallel_type);
      for (int64_t i : c10::irange(tv->nDims())) {
        if (tv->getLoopDomain()[i]->isDeviceDim()) {
          return i;
        }
      }
      return -1;
    }
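
A minimal sketch of how a caller might treat the -1 sentinel, assuming callers interpret it as "tv is not sharded on this parallel type" (the scheduler snippet above only asserts on the >= 0 path):

    int64_t sharded_axis = getShardedLoopAxis(tv, ParallelType::DIDx);
    if (sharded_axis == -1) {
      // No loop IterDomain is parallelized on DIDx: treat tv as unsharded
      // along this parallel type.
    } else {
      // A DIDx axis exists; the reduction scheduler above further requires
      // it to be the outermost loop axis (sharded_axis == 0).
    }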

Test Coverage

The new test DivideBySum should be expanded to cover more scenarios, including edge cases and different mesh sizes.

    TEST_F(MultiDeviceTest, DivideBySum) {
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      const int64_t d = communicator_->size();
    
      // [b, h, s, s]
      TensorView* x = makeContigTensor(4);
      TensorView* sum_x = sum(x, {-1});
      TensorView* sum_x_broadcasted = broadcast(sum_x, {false, false, false, true});
      TensorView* y = div(x, sum_x_broadcasted);
      fusion->addInput(x);
      fusion->addOutput(y);
    
      auto mesh = DeviceMesh::createForNumDevices(d);
      for (auto* tv : {x, sum_x, sum_x_broadcasted, y}) {
        tv->setDeviceMesh(mesh);
        tv->split(1, d, /*inner_split=*/false);
        tv->axis(1)->parallelize(ParallelType::DIDx);
        tv->reorder({{1, 0}});
      }
      for (auto* tv : {x, y}) {
        tv->setAllocationDomain(tv->getLoopDomain(), true);
      }
    
      const int64_t b = 2;
      const int64_t h = d * 3;
      const int64_t s = 5;
      at::Tensor unsharded_x_tensor = at::randint(5, {b, h, s, s}, tensor_options);
      at::Tensor x_tensor = shardTensor(unsharded_x_tensor, x);
    
      FusionExecutorCache executor_cache(std::move(fusion));
      at::Tensor y_tensor = executor_cache.runFusionWithInputs({x_tensor})[0];
      testValidate(
          executor_cache.fusion(),
          {y_tensor},
          {x_tensor},
          {x_tensor / x_tensor.sum(-1, true)},
          __LINE__,
          __FILE__);
    }
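
One inexpensive extension in the direction suggested above would be to assert, right after the scheduling loop in DivideBySum, that the DID loop split indeed leaves DIDx outermost on every TensorView. A hedged sketch, not part of the PR:

    // Sketch only: could be placed after the scheduling loop above.
    for (auto* tv : {x, sum_x, sum_x_broadcasted, y}) {
      EXPECT_EQ(getShardedLoopAxis(tv, ParallelType::DIDx), 0);
    }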

@wujingyue wujingyue requested a review from Priya2698 February 8, 2025 01:03
@wujingyue
Collaborator Author

!test

@wujingyue wujingyue changed the base branch from wjy/gdb to bug3817 February 8, 2025 07:50
@wujingyue
Collaborator Author

!test
