diff --git a/tests/ops/test_fused_moe.py b/tests/ops/test_fused_moe.py
new file mode 100644
index 00000000..823e380e
--- /dev/null
+++ b/tests/ops/test_fused_moe.py
@@ -0,0 +1,69 @@
+import pytest
+import torch
+
+from vllm_ascend.ops.fused_moe import fused_moe_torch
+
+
+def test_fused_moe():
+    # Since the function uses only native PyTorch operations, the most
+    # reliable ground truth is a manually computed output. Hardcoded inputs
+    # let us validate the function against a known reference.
+
+    # Step 1: Construct the inputs
+    hidden_states = torch.tensor([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])
+
+    # w1: [3, 4, 3] (num_experts=3, intermediate_size*2=4, hidden_size=3)
+    w1 = torch.tensor(
+        [
+            [[1.0, 0.0, -1.0], [2.0, 1.0, 0.0], [1.0, 1.0, -1.0], [1.0, -1.0, 1.0]],
+            [[-1.0, 1.0, 1.0], [1.0, -1.0, 1.0], [2.0, -2.0, 2.0], [1.0, 0.0, -1.0]],
+            [[-2.0, -1.0, 1.0], [2.0, -1.0, 1.0], [-1.0, 2.0, 1.0], [1.0, 1.0, -1.0]],
+        ]
+    )
+
+    # w2: [3, 3, 2] (num_experts=3, hidden_size=3, intermediate_size=2)
+    w2 = torch.tensor(
+        [
+            [[1.0, 0.5], [2.0, -1.0], [0.0, 1.0]],
+            [[1.0, 1.0], [-1.0, 1.0], [1.0, -0.0]],
+            [[-2.0, 1.0], [1.0, -1.0], [2.0, 1.0]],
+        ]
+    )
+
+    # gating_output: [2, 3] (num_tokens=2, num_experts=3)
+    gating_output = torch.tensor([[0.0, 0.5, 0.5], [0.5, 0.5, 0.0]])
+
+    topk = 2
+
+    global_num_experts = 3
+
+    # Only the first two experts are mapped on this rank
+    expert_map = torch.tensor([0, 1, -1])
+
+    renormalize = False
+
+    # Step 2: Expected output calculation
+
+    # topk=2 selects the two largest softmax(gating_output) scores per token.
+    # For sample 1, softmax([0.0, 0.5, 0.5]) ~ [0.2327, 0.3836, 0.3836], so
+    # topk_weights ~ [0.3836, 0.3836] and the selected experts are 1, 2.
+    # For sample 2, softmax([0.5, 0.5, 0.0]) ~ [0.3836, 0.3836, 0.2327], so
+    # topk_weights ~ [0.3836, 0.3836] and the selected experts are 0, 1.
+
+    # 1. First linear transformation of hidden_states -> F.linear(hidden_states, w1[e])
+    # 2. Gating function on the first half -> F.silu(x[:, :intermediate_size])
+    # 3. Second linear transformation -> F.linear(x, w2[e])
+    # 4. Weight each selected expert's output by its topk_weight and sum,
+    #    as worked through below.
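+    # Worked numbers: for sample 1, expert_map sends expert 2 to -1, so only
+    # expert 1 contributes: F.linear(h, w1[1]) = [4, 2, 4, -2];
+    # [4, -2] * silu([4, 2]) ~ [15.7122, -3.5232]; after w2[1] this becomes
+    # [12.1890, -19.2354, 15.7122], and * 0.3836 ~ [4.6763, -7.3797, 6.0280].
+    # For sample 2, experts 0 and 1 yield [15.0437, -1.9019, 15.9946] and
+    # [3.5232, 3.5232, 0.0]; their sum * 0.3836 ~ [7.1232, 0.6220, 6.1364].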
+
+    expected_hidden_states = torch.tensor([[4.6763, -7.3797, 6.0280], [7.1232, 0.6220, 6.1364]])
+
+    # Step 3: Run the torch-native fused_moe_torch implementation
+    final_output = fused_moe_torch(
+        hidden_states, w1, w2, gating_output, topk, global_num_experts, expert_map, renormalize
+    )
+
+    # Step 4: Check shape and values against the manually computed reference
+    assert (
+        final_output.shape == hidden_states.shape
+    ), f"Expected shape {hidden_states.shape}, but got {final_output.shape}"
+
+    assert torch.allclose(final_output, expected_hidden_states, atol=1e-4), "Output does not match expected result"
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index cbb86224..83f58928 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -18,12 +18,14 @@ from typing import Callable, Optional
 import torch
+import torch.nn.functional as F
 import torch_npu
+import vllm.envs as envs
 from vllm.model_executor.layers.fused_moe.layer import \
     UnquantizedFusedMoEMethod
 
 
-def group_topk(hidden_states: torch.Tensor,
+def grouped_topk(hidden_states: torch.Tensor,
                gating_output: torch.Tensor,
                topk: int,
                renormalize: bool,
@@ -140,37 +142,106 @@ def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
     return hidden_states
 
 
-def forward_oot(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    use_grouped_topk: bool,
-    top_k: int,
-    router_logits: torch.Tensor,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None
+def fused_moe_torch(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    global_num_experts: int,
+    expert_map: Optional[torch.Tensor] = None,
+    renormalize: bool = False,
 ) -> torch.Tensor:
+    """Torch-native reference implementation of fused MoE.
+
+    Args:
+        hidden_states: [*, hidden_size]
+        w1: [num_experts, intermediate_size * 2, hidden_size]
+        w2: [num_experts, hidden_size, intermediate_size]
+        gating_output: [*, num_experts]
+        expert_map: [num_experts]
+    """
+    orig_shape = hidden_states.shape
+    hidden_size = hidden_states.shape[-1]
+    num_tokens = hidden_states.shape[:-1].numel()
+    num_experts = w1.shape[0]
+    intermediate_size = w2.shape[-1]
+    dtype = hidden_states.dtype
+
+    # Route each token to its topk experts by softmax score.
+    hidden_states = hidden_states.view(num_tokens, hidden_size)
+    gating_output = gating_output.view(num_tokens, global_num_experts)
+    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
+    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_weights = topk_weights.to(dtype)
+
+    # Map global expert ids to local ones; ids mapped to -1 are not on this
+    # rank and contribute nothing below.
+    if expert_map is not None:
+        selected_experts = expert_map[selected_experts]
+
+    final_hidden_states = None
+    for expert_idx in range(num_experts):
+        expert_w1 = w1[expert_idx]
+        expert_w2 = w2[expert_idx]
+        # Per-token routing weight for this expert (zero if not selected).
+        expert_mask = selected_experts == expert_idx
+        expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
+        # SwiGLU-style MLP: silu(gate half) * up half, then down projection.
+        x = F.linear(hidden_states, expert_w1)
+        gate = F.silu(x[:, :intermediate_size])
+        x = x[:, intermediate_size:] * gate
+        x = F.linear(x, expert_w2)
+        current_hidden_states = x * expert_weights
+        if final_hidden_states is None:
+            final_hidden_states = current_hidden_states
+        else:
+            final_hidden_states = final_hidden_states + current_hidden_states
+
+    return final_hidden_states.view(orig_shape)  # type: ignore
+
+
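+# When VLLM_TEST_ENABLE_EP is set, forward_oot routes through the torch-native
+# fused_moe_torch path above instead of the NPU fused kernels.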
+def forward_oot(
+    self,
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    custom_routing_function: Optional[Callable] = None,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
+):
+    if envs.VLLM_TEST_ENABLE_EP:
+        return fused_moe_torch(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk=top_k,
+            gating_output=router_logits,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            renormalize=renormalize,
+        )
+    else:
+        topk_weights, topk_ids = grouped_topk(
+            hidden_states=x,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        return fused_experts(hidden_states=x,
+                             w1=layer.w13_weight,
+                             w2=layer.w2_weight,
+                             topk_weights=topk_weights,
+                             topk_ids=topk_ids,
+                             top_k=top_k)
-
-    topk_weights, topk_ids = group_topk(
-        hidden_states=x,
-        gating_output=router_logits,
-        topk=top_k,
-        renormalize=renormalize,
-        num_expert_group=num_expert_group,
-        topk_group=topk_group,
-        scoring_func=scoring_func,
-        e_score_correction_bias=e_score_correction_bias)
-
-    return fused_experts(hidden_states=x,
-                         w1=layer.w13_weight,
-                         w2=layer.w2_weight,
-                         topk_weights=topk_weights,
-                         topk_ids=topk_ids,
-                         top_k=top_k)
 
 
 UnquantizedFusedMoEMethod.forward_oot = forward_oot
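As a sanity check on the hardcoded expectation, the first row of expected_hidden_states can be re-derived in a few lines of plain PyTorch. This is a standalone sketch, not part of the patch; hidden, weight, w1_e1, and w2_e1 are local names for the test's tensors.

    import torch
    import torch.nn.functional as F

    hidden = torch.tensor([1.0, 2.0, 3.0])
    # softmax([0.0, 0.5, 0.5]) gives ~0.3836 to experts 1 and 2, and
    # expert_map = [0, 1, -1] drops expert 2, so only expert 1 contributes.
    weight = torch.softmax(torch.tensor([0.0, 0.5, 0.5]), dim=-1)[1]
    w1_e1 = torch.tensor([[-1.0, 1.0, 1.0], [1.0, -1.0, 1.0],
                          [2.0, -2.0, 2.0], [1.0, 0.0, -1.0]])
    w2_e1 = torch.tensor([[1.0, 1.0], [-1.0, 1.0], [1.0, -0.0]])

    x = F.linear(hidden, w1_e1)         # tensor([ 4.,  2.,  4., -2.])
    x = x[2:] * F.silu(x[:2])           # up half * silu(gate half)
    print(F.linear(x, w2_e1) * weight)  # ~[4.6763, -7.3797, 6.0280]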