feat: use sgl-kernel by default
zhyncs committed Jan 23, 2025
1 parent 54bac8a commit 113c855
Showing 7 changed files with 72 additions and 25 deletions.
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -27,7 +27,7 @@ runtime_common = [
 ]
 srt = [
     "sglang[runtime_common]", "cuda-python",
-    "sgl-kernel>=0.0.2.post14", "torch", "vllm==0.6.4.post1",
+    "sgl-kernel>=0.0.2.post16", "torch", "vllm==0.6.4.post1",
     "flashinfer==0.1.6"
 ]

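As a quick sanity check of the new lower bound, the installed sgl-kernel version can be read with importlib.metadata; this snippet is illustrative and assumes the sgl-kernel package is already installed.

import importlib.metadata

# The srt extra now requires sgl-kernel >= 0.0.2.post16.
print(importlib.metadata.version("sgl-kernel"))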
13 changes: 10 additions & 3 deletions python/sglang/srt/layers/activation.py
@@ -20,10 +20,17 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import (
+    enable_use_sgl_kernel_first,
+    is_cuda_available,
+    is_flashinfer_available,
+)

-if is_flashinfer_available():
-    from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+if enable_use_sgl_kernel_first:
+    if is_cuda_available():
+        from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+elif is_flashinfer_available():
+    from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

 from vllm.model_executor.custom_op import CustomOp

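The same flag-gated import preference (try the sgl_kernel CUDA ops first, otherwise fall back to FlashInfer) is repeated in layernorm.py, sampler.py, deepseek_v2.py, and minicpm3.py below. A minimal stand-alone sketch of that selection logic, not the repository's code: it checks availability with importlib instead of the is_cuda_available / is_flashinfer_available helpers, and the function name resolve_silu_and_mul is hypothetical.

import importlib.util
import os

# Same environment flag the commit adds in python/sglang/srt/utils.py (default "1").
enable_use_sgl_kernel_first = bool(int(os.getenv("ENABLE_USE_SGL_KERNEL_FIRST", "1")))


def resolve_silu_and_mul():
    """Pick a fused silu_and_mul kernel: sgl_kernel when preferred and installed,
    else flashinfer when installed, else None (caller uses a plain PyTorch path)."""
    if enable_use_sgl_kernel_first and importlib.util.find_spec("sgl_kernel"):
        from sgl_kernel import silu_and_mul  # fused CUDA op from sgl-kernel
        return silu_and_mul
    if importlib.util.find_spec("flashinfer"):
        from flashinfer.activation import silu_and_mul  # FlashInfer fallback
        return silu_and_mul
    return None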
30 changes: 21 additions & 9 deletions python/sglang/srt/layers/layernorm.py
@@ -19,15 +19,27 @@
 import torch
 import torch.nn as nn

-from sglang.srt.utils import is_flashinfer_available
-
-if is_flashinfer_available():
-    from flashinfer.norm import (
-        fused_add_rmsnorm,
-        gemma_fused_add_rmsnorm,
-        gemma_rmsnorm,
-        rmsnorm,
-    )
+from sglang.srt.utils import (
+    enable_use_sgl_kernel_first,
+    is_cuda_available,
+    is_flashinfer_available,
+)
+
+if enable_use_sgl_kernel_first:
+    if is_cuda_available():
+        from sgl_kernel import (
+            fused_add_rmsnorm,
+            gemma_fused_add_rmsnorm,
+            gemma_rmsnorm,
+            rmsnorm,
+        )
+elif is_flashinfer_available():
+    from flashinfer.norm import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )

 from vllm.model_executor.custom_op import CustomOp

24 changes: 17 additions & 7 deletions python/sglang/srt/layers/sampler.py
@@ -12,17 +12,27 @@
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import (
     crash_on_warnings,
+    enable_use_sgl_kernel_first,
     get_bool_env_var,
+    is_cuda_available,
     is_flashinfer_available,
 )

-if is_flashinfer_available():
-    from flashinfer.sampling import (
-        min_p_sampling_from_probs,
-        top_k_renorm_prob,
-        top_k_top_p_sampling_from_probs,
-        top_p_renorm_prob,
-    )
+if enable_use_sgl_kernel_first:
+    if is_cuda_available():
+        from sgl_kernel import (
+            min_p_sampling_from_probs,
+            top_k_renorm_prob,
+            top_k_top_p_sampling_from_probs,
+            top_p_renorm_prob,
+        )
+elif is_flashinfer_available():
+    from flashinfer.sampling import (
+        min_p_sampling_from_probs,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+    )


 logger = logging.getLogger(__name__)
13 changes: 11 additions & 2 deletions python/sglang/srt/models/deepseek_v2.py
@@ -60,8 +60,17 @@

 is_hip_ = is_hip()

-if is_flashinfer_available():
-    from flashinfer import bmm_fp8
+from sglang.srt.utils import (
+    enable_use_sgl_kernel_first,
+    is_cuda_available,
+    is_flashinfer_available,
+)
+
+if enable_use_sgl_kernel_first:
+    if is_cuda_available():
+        from sgl_kernel import bmm_fp8
+elif is_flashinfer_available():
+    from flashinfer import bmm_fp8


 class DeepseekV2MLP(nn.Module):
13 changes: 10 additions & 3 deletions python/sglang/srt/models/minicpm3.py
@@ -40,10 +40,17 @@
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import (
+    enable_use_sgl_kernel_first,
+    is_cuda_available,
+    is_flashinfer_available,
+)

-if is_flashinfer_available():
-    from flashinfer import bmm_fp8
+if enable_use_sgl_kernel_first:
+    if is_cuda_available():
+        from sgl_kernel import bmm_fp8
+elif is_flashinfer_available():
+    from flashinfer import bmm_fp8


 class MiniCPM3MLP(nn.Module):
2 changes: 2 additions & 0 deletions python/sglang/srt/utils.py
@@ -66,6 +66,8 @@
 show_time_cost = False
 time_infos = {}

+enable_use_sgl_kernel_first = bool(int(os.getenv("ENABLE_USE_SGL_KERNEL_FIRST", "1")))
+

 def is_hip() -> bool:
     """Return whether it is HIP on the AMD ROCm platform."""
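enable_use_sgl_kernel_first is evaluated once when sglang.srt.utils is imported, so the environment variable must be set before any sglang.srt module is loaded. A hedged usage sketch, assuming sglang[srt] is installed (the assert is only for illustration):

import os

# Disable the sgl-kernel preference for this process; must happen before
# importing sglang.srt.* so the module-level flag sees the override.
os.environ["ENABLE_USE_SGL_KERNEL_FIRST"] = "0"

import sglang.srt.utils as srt_utils

assert srt_utils.enable_use_sgl_kernel_first is False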
