Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Privatize up-cast ops for better segmentation (#3776)
## Problem This is to address yet another segmentation problem with RoPE. In particular, in Phi3 forward, there's two bfloat-to-float cast ops that are consumed by two segments. Find `T49` and `T36` below: ![phi3_fwd](https://github.com/user-attachments/assets/a3e1c256-6e58-4028-8fd2-a6725e5a810f) They are consumed by two segments, one with the blue color and another with the light purple color (not the one spanning vertically in the center of the graph). The problem here is that the upcast ops are grouped into the blue segment and their output float tensors are fed into the light purple segment. Specifically, we get this segment: ``` g{(resize) group id: 6 inputs: T0_g___bfloat[bS0{1}, iS1{8192}, iS2{9216}] __bfloat T34_g___bfloat[bS121{1}, bS122{1 ex 32}, iS123{8192}, iS124{96}] __bfloat T47_g___bfloat[bS177{1}, bS178{1 ex 32}, iS179{8192}, iS180{96}] __bfloat outputs: T36_g_float[bS129{1}, bS130{1 ex 32}, iS131{8192}, iS132{96}] float T49_g_float[bS185{1}, bS186{1 ex 32}, iS187{8192}, iS188{96}] float T52_g___bfloat[bS197{1}, iS198{32}, iS199{8192}, iS200{96}] __bfloat ``` which is followed by: ``` g{(resize) group id: 7 inputs: T0_g___bfloat[bS0{1}, iS1{8192}, iS2{9216}] __bfloat T36_g_float[bS129{1}, bS130{1 ex 32}, iS131{8192}, iS132{96}] float T49_g_float[bS185{1}, bS186{1 ex 32}, iS187{8192}, iS188{96}] float outputs: T66_g___bfloat[bS257{1}, iS258{32}, iS259{8192}, iS260{96}] __bfloat ``` Notice that the first segment produces `T36` and `T49`, which are just upcast of `T34` and `T47`, respectively, and then they are inputs of the following segment. This is not ideal. The second segment should just use `T34` and `T47` directly, and by doing so the first segment would not need to produce `T34` and `T47` as segment outputs. More concretely, in the current segmentation, there are two reads of bf16 tensors (`T34` and `T47`), two writes of fp32 tensor (`T36` and `T47`), and two reads of fp32 tensors (`T36` and `T47`). What we could do instead is just two reads of bf16 tensors (`T34` and `T47`) and another two reads of the same tensors. The fusion segmenter already addresses this problem partially by forwarding unary ops, but only for unary ops using fusion inputs, which doesn't apply to the Phi3 case. ## Fix The above problem with Phi3 wouldn't happen if `T49` and `T36` are not shared by the two segments. So, we first privatize all upcast tensors. This is done after the initial unsegmented trial and before the segmentation loop. https://github.com/NVIDIA/Fuser/pull/3776/files#diff-e2f2ad44a6dc03e4ad8e5f0f047be25eb1c142add431d48c1e046c968a577f3bR3958 That's all for the Phi3 case, but this privatization isn't necessary if the two segments were actually fused (which we don't support yet). If that actually happened, the fused segment would have something like: ``` T2 = bf16ToFp32(T0); T3 = bf16ToFp32(T0); T6 = T2 + T3 ``` Instead of: ``` T2 = bf16ToFp32(T0); T6 = T2 + T2 ``` This is functionally correct and shouldn't have any perf issue either, but just in case, we revert the privatization in the final segments. ## Perf benefit Current resize schedule on H100: ``` NVFUSER_ENABLE=resize_scheduler pytest benchmarks/python/test_rope.py --benchmark-thunder -k 'hf_phi3_rope and fwd' ---------------------------------------------------------------------------- benchmark: 1 tests --------------------------------------------------------------------------- Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_rope_fwd_benchmark[executor='thunder'-variation='hf_phi3_rope'] 129.9170 132.9290 131.3976 0.7926 131.2950 0.7330 2;1 7.6105 10 1 --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` With this PR: ``` -------------------------------------------------------------------------- benchmark: 1 tests ------------------------------------------------------------------------- Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_rope_fwd_benchmark[executor='thunder'-variation='hf_phi3_rope'] 97.0230 99.9030 98.9724 0.7649 99.1510 0.4780 2;1 10.1038 10 1 ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` It's also effective even without the resize scheduler. TOT: ``` ---------------------------------------------------------------------------- benchmark: 1 tests --------------------------------------------------------------------------- Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_rope_fwd_benchmark[executor='thunder'-variation='hf_phi3_rope'] 195.1030 196.3500 195.7106 0.3948 195.6955 0.5120 3;0 5.1096 10 1 --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` With this PR: ``` ---------------------------------------------------------------------------- benchmark: 1 tests --------------------------------------------------------------------------- Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_rope_fwd_benchmark[executor='thunder'-variation='hf_phi3_rope'] 141.1850 142.4950 141.7790 0.4813 141.7605 0.9600 5;0 7.0532 10 1 --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` --------- Co-authored-by: Liqiang Lu <liqiangxl@gmail.com>
- Loading branch information