Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move getTvsWithDifferentSharding to cpp #3858

Merged
merged 1 commit into from
Feb 9, 2025
Merged

Move getTvsWithDifferentSharding to cpp #3858

merged 1 commit into from
Feb 9, 2025

Conversation

wujingyue
Copy link
Collaborator

The function no longer needs to be template and is large enough to justify putting in cpp.

@wujingyue wujingyue requested a review from zasdfgbnm February 9, 2025 01:22
@wujingyue
Copy link
Collaborator Author

!test

Copy link

github-actions bot commented Feb 9, 2025

Description

  • Moved getTvsWithDifferentSharding from header to source file

  • Removed template parameter from function signature


Changes walkthrough 📝

Relevant files
Enhancement
utils.cpp
Added `getTvsWithDifferentSharding` implementation             

csrc/multidevice/utils.cpp

  • Added getTvsWithDifferentSharding function implementation
+35/-0   
utils.h
Updated `getTvsWithDifferentSharding` signature                   

csrc/multidevice/utils.h

  • Removed getTvsWithDifferentSharding function implementation
  • Updated function signature to use const std::vector& instead of
    template
  • +1/-33   

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Performance Impact

    Moving the function to C++ may have implications on performance. Ensure that the move does not introduce any performance regressions.

    std::unordered_set<TensorView*> getTvsWithDifferentSharding(
        TensorView* ref,
        const std::vector<TensorView*>& tvs) {
      std::unordered_set<TensorView*> ret;
      const auto& reference_dom = ref->getLoopDomain();
      FusionGuard fg(ref->fusion());
      auto ca_map = ComputeAtMap(FusionGuard::getCurFusion());
      std::unordered_map<IterDomain*, IterDomain*> concrete_to_reference_map;
      for (auto id : reference_dom) {
        auto ca_id =
            ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE_RESIZE);
        concrete_to_reference_map[ca_id] = id;
      }
    
      for (TensorView* tv : tvs) {
        if (ref->getDeviceMesh().vector() != tv->getDeviceMesh().vector()) {
          ret.insert(tv);
          continue;
        }
        for (auto id : tv->getLoopDomain()) {
          auto ca_id =
              ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE_RESIZE);
          if (concrete_to_reference_map.count(ca_id) > 0) {
            auto ref_id = concrete_to_reference_map.at(ca_id);
            if ((ref_id->isDeviceDim() || id->isDeviceDim()) &&
                ref_id->getParallelType() != id->getParallelType()) {
              ret.insert(tv);
              break;
            }
          }
        }
      }
      return ret;
    }
    Template Removal

    The function was previously templated. Verify that removing the template does not affect the flexibility or correctness of the function.

    std::unordered_set<TensorView*> getTvsWithDifferentSharding(
        TensorView* ref,
        const std::vector<TensorView*>& tvs);

    @wujingyue wujingyue merged commit 39bc83a into main Feb 9, 2025
    44 of 46 checks passed
    @wujingyue wujingyue deleted the wjy/move branch February 9, 2025 06:28
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
    Labels
    None yet
    Projects
    None yet
    Development

    Successfully merging this pull request may close these issues.

    2 participants