Skip to content

[WIP] Fused softmax kernel via Triton #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
- Added fused softmax kernel implementation ([#135](https://github.com/pyg-team/pyg-lib/pull/135))
- Added PyTorch 1.13 support ([#145](https://github.com/pyg-team/pyg-lib/pull/145))
- Added native PyTorch support for `grouped_matmul` ([#137](https://github.com/pyg-team/pyg-lib/pull/137))
- Added `fused_scatter_reduce` operation for multiple reductions ([#141](https://github.com/pyg-team/pyg-lib/pull/141), [#142](https://github.com/pyg-team/pyg-lib/pull/142))
Expand Down
2 changes: 2 additions & 0 deletions pyg_lib/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from torch import Tensor

from .softmax import softmax
from .scatter_reduce import fused_scatter_reduce


Expand Down Expand Up @@ -83,5 +84,6 @@ def segment_matmul(inputs: Tensor, ptr: Tensor, other: Tensor) -> Tensor:
__all__ = [
'grouped_matmul',
'segment_matmul',
'softmax',
'fused_scatter_reduce',
]
63 changes: 63 additions & 0 deletions pyg_lib/ops/softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Optional

import torch
from torch import Tensor

from pyg_lib._triton import tl, triton


@triton.jit
def softmax_kernel(x_ptr, ptr_ptr, out_ptr, M, N, num_segments, **meta):
ptr_block_start = tl.program_id(axis=0) * meta['SEGMENT_BLOCK_SIZE']
ptr_offset = ptr_block_start + tl.arange(0, meta['SEGMENT_BLOCK_SIZE'])
ptr_mask = ptr_offset < num_segments

ptr1 = tl.load(ptr_ptr + ptr_offset, mask=ptr_mask, other=1000000)
ptr2 = tl.load(ptr_ptr + ptr_offset + 1, mask=ptr_mask, other=1000000)
count = ptr2 - ptr1
# max_count = tl.max(ptr2 - ptr1, axis=0)
# max_count = tl.multiple_of(max_count, 8)
max_count = 10 # TODO
M_offset = tl.arange(0, max_count)

N_block_start = tl.program_id(axis=1) * meta['BLOCK_SIZE_N']
N_offset = N_block_start + tl.arange(0, meta['BLOCK_SIZE_N'])

x_offset = (N * ptr1[:, None, None] + N * M_offset[None, :, None] +
N_offset[None, None, :])
x_mask = ((ptr1[:, None, None] < M) &
(M_offset[None, :, None] < count[:, None, None]) &
(N_offset[None, None, :] < N))

x = tl.load(x_ptr + x_offset, mask=x_mask, other=float('-inf'))
x = x - tl.max(x, axis=1)[:, None, :]
x = tl.exp(x)
out = x / tl.sum(x, axis=1)[:, None, :]

tl.store(out_ptr + x_offset, out, mask=x_mask)


def softmax(
inputs: Tensor,
ptr: Tensor,
out: Optional[Tensor] = None,
) -> Tensor:
if out is None:
out = torch.empty_like(inputs)
out.resize_(inputs.size())

out.fill_(-1)

assert inputs.dim() == 2 and inputs.is_cuda and inputs.is_contiguous()
assert ptr.dim() == 1 and ptr.is_cuda and ptr.is_contiguous()
assert out.dim() == 2 and out.is_cuda and out.is_contiguous()

(M, N), num_segments = inputs.size(), ptr.numel() - 1

grid = lambda meta: (
triton.cdiv(num_segments, meta['SEGMENT_BLOCK_SIZE']),
triton.cdiv(N, meta['BLOCK_SIZE_N']),
)
softmax_kernel[grid](inputs, ptr, out, M, N, num_segments,
SEGMENT_BLOCK_SIZE=8, BLOCK_SIZE_N=128)
return out
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ test=pytest
addopts=--capture=no

[flake8]
ignore=E731
ignore=E731,W504

[isort]
multi_line_output=3
Expand Down
15 changes: 15 additions & 0 deletions test/ops/test_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

from pyg_lib.ops import softmax
from pyg_lib.testing import onlyCUDA, onlyTriton


@onlyCUDA
@onlyTriton
def test_softmax():
inputs = torch.randn(8, 5, device='cuda')
ptr = torch.tensor([0, 3, 8], device='cuda')

out = softmax(inputs, ptr)
assert torch.allclose(out[0:3], torch.softmax(inputs[3:8], dim=0))
assert torch.allclose(out[3:8], torch.softmax(inputs[3:8], dim=0))