[WIP][Feature] Implement native fused MoE layer #121

Open · wants to merge 2 commits into base: main
69 changes: 69 additions & 0 deletions tests/ops/test_fused_moe.py
@@ -0,0 +1,69 @@
import pytest

[CI annotation: GitHub Actions / ruff (3.12) reports F401 at tests/ops/test_fused_moe.py:1:8: `pytest` imported but unused]
import torch
from vllm_ascend.ops.fused_moe import fused_moe


def test_fused_moe():
    # The function is implemented with native PyTorch operations, so the most reliable
    # ground truth is a manually computed output. Hardcoded inputs let us validate the
    # result against a known reference.

# Step 1: Constructing inputs
hidden_states = torch.tensor([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])

# w1: [3, 4, 3] (num_experts=3, intermediate_size*2=4, hidden_size=3)
w1 = torch.tensor(
[
[[1.0, 0.0, -1.0], [2.0, 1.0, 0.0], [1.0, 1.0, -1.0], [1.0, -1.0, 1.0]],
[[-1.0, 1.0, 1.0], [1.0, -1.0, 1.0], [2.0, -2.0, 2.0], [1.0, 0.0, -1.0]],
[[-2.0, -1.0, 1.0], [2.0, -1.0, 1.0], [-1.0, 2.0, 1.0], [1.0, 1.0, -1.0]],
]
)

# w2: [3, 3, 2] (num_experts=3, hidden_size=3, intermediate_size=2)
w2 = torch.tensor(
[
[[1.0, 0.5], [2.0, -1.0], [0.0, 1.0]],
[[1.0, 1.0], [-1.0, 1.0], [1.0, -0.0]],
[[-2.0, 1.0], [1.0, -1.0], [2.0, 1.0]],
]
)

# gating_output: [2, 3] (num_tokens=2, num_experts=3)
gating_output = torch.tensor([[0.0, 0.5, 0.5], [0.5, 0.5, 0.0]])

topk = 2

global_num_experts = 3

    # Only the first two experts are present locally; expert 2 maps to -1
expert_map = torch.tensor([0, 1, -1])

renormalize = False

use_grouped_topk = False

# Step 2: Expected output calculation

    # We use topk=2, so the top 2 experts per token are selected from softmax(gating_output).
    # For token 1, softmax([0.0, 0.5, 0.5]) ≈ [0.2327, 0.3836, 0.3836] -> experts 1, 2 with weights ≈ [0.3836, 0.3836]
    # For token 2, softmax([0.5, 0.5, 0.0]) ≈ [0.3836, 0.3836, 0.2327] -> experts 0, 1 with weights ≈ [0.3836, 0.3836]
    # Since expert_map = [0, 1, -1] drops expert 2, token 1 receives a contribution from expert 1 only.

    # For each selected expert e:
    # 1. Linear transformation of hidden_states with w1[e] -> F.linear(hidden_states, w1[e])
    # 2. SwiGLU activation -> F.silu(x[:, :intermediate_size]) * x[:, intermediate_size:]
    # 3. Second linear transformation with w2[e] -> F.linear(x, w2[e])
    # 4. Scale each expert's output by its topk weight and sum over the selected experts
    #    (see the reference sketch after this file)

expected_hidden_states = torch.tensor([[4.6763, -7.3797, 6.0280], [7.1232, 0.6220, 6.1364]])

# Step 3: Running the fused_moe function
final_output = fused_moe(
hidden_states, w1, w2, gating_output, topk, global_num_experts, expert_map, renormalize, use_grouped_topk
)

    # Step 4: Check the shape and values against the manually computed expected result
assert (
final_output.shape == hidden_states.shape
), f"Expected shape {hidden_states.shape}, but got {final_output.shape}"

assert torch.allclose(final_output, expected_hidden_states, atol=1e-4), "Output does not match expected result"
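
For reference, the hard-coded expectation above can be regenerated with a small standalone routine. This is a sketch in plain PyTorch; the name `reference_moe` is ours and not part of this PR. It mirrors the per-expert computation of `fused_moe_torch` below, just token by token:

import torch
import torch.nn.functional as F

def reference_moe(hidden_states, w1, w2, gating_output, topk, expert_map=None, renormalize=False):
    # Route each token to its top-k experts via softmax over the gating logits.
    weights = gating_output.softmax(dim=-1, dtype=torch.float)
    topk_weights, topk_ids = weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(hidden_states.dtype)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
    intermediate_size = w2.shape[-1]
    out = torch.zeros_like(hidden_states)
    for t in range(hidden_states.shape[0]):
        for weight, e in zip(topk_weights[t], topk_ids[t]):
            if e < 0:
                continue  # expert not present locally
            x = F.linear(hidden_states[t], w1[e])
            x = F.silu(x[:intermediate_size]) * x[intermediate_size:]
            out[t] += weight * F.linear(x, w2[e])
    return out

With the tensors from the test this reproduces expected_hidden_states to the test's atol=1e-4, and it can be reused to regenerate the expectation if the inputs ever change.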
131 changes: 101 additions & 30 deletions vllm_ascend/ops/fused_moe.py
@@ -18,12 +18,14 @@
from typing import Callable, Optional

import torch
import torch.nn.functional as F
import torch_npu
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod


def group_topk(hidden_states: torch.Tensor,
def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
@@ -140,37 +142,106 @@ def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
return hidden_states


def forward_oot(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
def fused_moe_torch(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
global_num_experts: int,
expert_map: Optional[torch.Tensor] = None,
renormalize: bool = False,
) -> torch.Tensor:
"""
Args:
hidden_states: [*, hidden_size]
w1: [num_experts, intermediate_size * 2, hidden_size]
w2: [num_experts, hidden_size, intermediate_size]
gating_output: [*, num_experts]
expert_map: [num_experts]
"""
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
num_experts = w1.shape[0]
intermediate_size = w2.shape[-1]
dtype = hidden_states.dtype

hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, global_num_experts)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)

if expert_map is not None:
selected_experts = expert_map[selected_experts]
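# Experts that are not present locally are remapped to -1 here; -1 never matches
# a local expert index in the loop below, so their weighted contribution is zero.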

final_hidden_states = None
for expert_idx in range(num_experts):
expert_w1 = w1[expert_idx]
expert_w2 = w2[expert_idx]
expert_mask = selected_experts == expert_idx
expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
x = F.linear(hidden_states, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states = final_hidden_states + current_hidden_states

return final_hidden_states.view(orig_shape) # type: ignore


def forward_oot(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
):
if envs.VLLM_TEST_ENABLE_EP:
return fused_moe_torch(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk=top_k,
gating_output=router_logits,
global_num_experts=global_num_experts,
expert_map=expert_map,
renormalize=renormalize,
)
else:
topk_weights, topk_ids = grouped_topk(
hidden_states=x,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k)

topk_weights, topk_ids = group_topk(
hidden_states=x,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k)


UnquantizedFusedMoEMethod.forward_oot = forward_oot
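
For completeness, here is a minimal usage sketch of the new pure-PyTorch path. It is illustrative only: the shapes follow the fused_moe_torch docstring above and the random tensors are ours. Within forward_oot this path is taken only when envs.VLLM_TEST_ENABLE_EP is set.

import torch
from vllm_ascend.ops.fused_moe import fused_moe_torch

num_tokens, hidden_size, intermediate_size, num_experts, topk = 4, 8, 16, 3, 2
hidden_states = torch.randn(num_tokens, hidden_size)
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)  # packed gate/up projection
w2 = torch.randn(num_experts, hidden_size, intermediate_size)      # down projection
gating_output = torch.randn(num_tokens, num_experts)
expert_map = torch.tensor([0, 1, -1])  # expert 2 is hosted on another rank

out = fused_moe_torch(hidden_states, w1, w2, gating_output,
                      topk=topk, global_num_experts=num_experts,
                      expert_map=expert_map, renormalize=True)
assert out.shape == hidden_states.shape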