From 76093871655f03dcdcdbb49815e0cdc145e65cc3 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Sat, 22 Oct 2022 13:21:22 +0000
Subject: [PATCH 1/5] softmax

---
 pyg_lib/ops/__init__.py  |  3 +++
 pyg_lib/ops/softmax.py   |  5 +++++
 test/ops/test_softmax.py | 18 ++++++++++++++++++
 3 files changed, 26 insertions(+)
 create mode 100644 pyg_lib/ops/softmax.py
 create mode 100644 test/ops/test_softmax.py

diff --git a/pyg_lib/ops/__init__.py b/pyg_lib/ops/__init__.py
index ef35d1c1b..99a36a40f 100644
--- a/pyg_lib/ops/__init__.py
+++ b/pyg_lib/ops/__init__.py
@@ -3,6 +3,8 @@
 import torch
 from torch import Tensor
 
+from .softmax import softmax
+
 
 def grouped_matmul(inputs: List[Tensor], others: List[Tensor]) -> List[Tensor]:
     r"""Performs dense-dense matrix multiplication according to groups,
@@ -72,4 +74,5 @@ def segment_matmul(inputs: Tensor, ptr: Tensor, other: Tensor) -> Tensor:
 __all__ = [
     'grouped_matmul',
     'segment_matmul',
+    'softmax',
 ]
diff --git a/pyg_lib/ops/softmax.py b/pyg_lib/ops/softmax.py
new file mode 100644
index 000000000..81bb667b0
--- /dev/null
+++ b/pyg_lib/ops/softmax.py
@@ -0,0 +1,5 @@
+from torch import Tensor
+
+
+def softmax(inputs: Tensor, ptr: Tensor) -> Tensor:
+    pass
diff --git a/test/ops/test_softmax.py b/test/ops/test_softmax.py
new file mode 100644
index 000000000..14ea448c7
--- /dev/null
+++ b/test/ops/test_softmax.py
@@ -0,0 +1,18 @@
+import torch
+
+from pyg_lib.ops import softmax
+from pyg_lib.testing import onlyCUDA
+
+
+@onlyCUDA
+def test_softmax():
+    inputs = torch.randn(8, 16, device='cuda')
+    ptr = torch.tensor([0, 3, 8], device='cuda')
+
+    out = softmax(inputs, ptr)
+    print(out[0:3])
+    print(torch.softmax(inputs[0:3], dim=0))
+    assert torch.allclose(out[0:3], torch.softmax(inputs[0:3], dim=0))
+    print(out[3:8])
+    print(torch.softmax(inputs[3:8], dim=0))
+    assert torch.allclose(out[3:8], torch.softmax(inputs[3:8], dim=0))

From 66d7cf5acd9bc50b1018634e00b9c0d0790e9d44 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Sat, 22 Oct 2022 13:25:14 +0000
Subject: [PATCH 2/5] update

---
 pyg_lib/ops/softmax.py | 34 +++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/pyg_lib/ops/softmax.py b/pyg_lib/ops/softmax.py
index 81bb667b0..4203ef1e3 100644
--- a/pyg_lib/ops/softmax.py
+++ b/pyg_lib/ops/softmax.py
@@ -1,5 +1,37 @@
+from typing import Optional
+
+import torch
+import triton
+# import triton.language as tl
 from torch import Tensor
 
 
-def softmax(inputs: Tensor, ptr: Tensor) -> Tensor:
+@triton.jit
+def softmax_kernel(x_ptr, ptr, out_ptr, numel, **meta):
+    # pid = tl.program_id(axis=0)
+    # block_start = pid * meta['BLOCK_SIZE']
+
+    # offsets = block_start + tl.arange(0, meta['BLOCK_SIZE'])
+    # mask = offsets < numel
+
+    # x = tl.load(x_ptr + offsets, mask=mask)
+    # y = tl.load(y_ptr + offsets, mask=mask)
+
+    # output = x + y
+
+    # tl.store(out_ptr + offsets, output, mask=mask)
     pass
+
+
+def softmax(
+    inputs: Tensor,
+    ptr: Tensor,
+    out: Optional[Tensor] = None,
+) -> Tensor:
+    if out is None:
+        out = torch.empty_like(inputs)
+    assert inputs.is_cuda and ptr.is_cuda and out.is_cuda
+
+    grid = lambda meta: (triton.cdiv(inputs.numel(), meta['BLOCK_SIZE']), )
+    softmax_kernel[grid](inputs, ptr, out, inputs.numel(), BLOCK_SIZE=1024)
+    return out
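The operation this series builds toward is a softmax over the rows of each segment [ptr[i], ptr[i + 1]) of a two-dimensional input. As a reference for what the Triton kernel is meant to compute, a minimal pure-PyTorch sketch follows (illustrative only, not part of the patches; the name segment_softmax_reference is made up here):

    import torch
    from torch import Tensor


    def segment_softmax_reference(inputs: Tensor, ptr: Tensor) -> Tensor:
        # Numerically stable softmax over the rows of each segment
        # [ptr[i], ptr[i + 1]), matching torch.softmax(..., dim=0) per segment.
        out = torch.empty_like(inputs)
        for start, end in zip(ptr[:-1].tolist(), ptr[1:].tolist()):
            x = inputs[start:end]
            x = (x - x.max(dim=0, keepdim=True)[0]).exp()
            out[start:end] = x / x.sum(dim=0, keepdim=True)
        return out

This is what the tests below assert against; the kernel exists to fuse the max, exp, and sum passes into a single GPU launch over blocks of segments.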
From d8cb55cdca3f0926cad5583c1b64fdb0a362883f Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 24 Oct 2022 11:57:32 +0000
Subject: [PATCH 3/5] update

---
 pyg_lib/ops/softmax.py   | 59 ++++++++++++++++++++++++++++----------
 pyg_lib/testing.py       |  2 +-
 setup.cfg                |  2 +-
 test/ops/test_softmax.py | 29 ++++++++++++++++++++----
 test/test_triton.py      |  4 +--
 5 files changed, 71 insertions(+), 25 deletions(-)

diff --git a/pyg_lib/ops/softmax.py b/pyg_lib/ops/softmax.py
index 4203ef1e3..e6627cd9f 100644
--- a/pyg_lib/ops/softmax.py
+++ b/pyg_lib/ops/softmax.py
@@ -1,26 +1,42 @@
 from typing import Optional
 
 import torch
-import triton
-# import triton.language as tl
 from torch import Tensor
 
+from pyg_lib._triton import tl, triton
+
 
 @triton.jit
-def softmax_kernel(x_ptr, ptr, out_ptr, numel, **meta):
-    # pid = tl.program_id(axis=0)
-    # block_start = pid * meta['BLOCK_SIZE']
+def softmax_kernel(x_ptr, ptr_ptr, out_ptr, M, N, num_segments, **meta):
+    ptr_block_start = tl.program_id(axis=0) * meta['SEGMENT_BLOCK_SIZE']
+    ptr_offset = ptr_block_start + tl.arange(0, meta['SEGMENT_BLOCK_SIZE'])
+    ptr_mask = ptr_offset < num_segments
+
+    ptr1 = tl.load(ptr_ptr + ptr_offset, mask=ptr_mask)
+    ptr2 = tl.load(ptr_ptr + ptr_offset + 1, mask=ptr_mask)
+    count = ptr2 - ptr1
+    # max_count = tl.max(ptr2 - ptr1, axis=0)
+    # max_count = tl.multiple_of(max_count, 8)
+    max_count = 10  # TODO
+    M_offset = tl.arange(0, max_count)
+
+    N_block_start = tl.program_id(axis=1) * meta['BLOCK_SIZE_N']
+    N_offset = N_block_start + tl.arange(0, meta['BLOCK_SIZE_N'])
 
-    # offsets = block_start + tl.arange(0, meta['BLOCK_SIZE'])
-    # mask = offsets < numel
+    # M_mask = M_offset[None, :] < count[:, None]
 
-    # x = tl.load(x_ptr + offsets, mask=mask)
-    # y = tl.load(y_ptr + offsets, mask=mask)
+    x_offset = (N * ptr1[:, None, None] + N * M_offset[None, :, None] +
+                N_offset[None, None, :])
+    x_mask = ((ptr1[:, None, None] < M) &
+              (M_offset[None, :, None] < count[:, None, None]) &
+              (N_offset[None, None, :] < N))
 
-    # output = x + y
+    x = tl.load(x_ptr + x_offset, mask=x_mask, other=float('-inf'))
+    x = x - tl.max(x, axis=1)[:, None, :]
+    x = tl.exp(x)
+    out = x / tl.sum(x, axis=1)[:, None, :]
 
-    # tl.store(out_ptr + offsets, output, mask=mask)
-    pass
+    tl.store(out_ptr + x_offset, out, mask=x_mask)
 
 
 def softmax(
@@ -30,8 +46,21 @@ def softmax(
 ) -> Tensor:
     if out is None:
         out = torch.empty_like(inputs)
-    assert inputs.is_cuda and ptr.is_cuda and out.is_cuda
+    out.resize_(inputs.size())
+
+    out.fill_(-1)
+
+    assert inputs.dim() == 2 and inputs.is_cuda and inputs.is_contiguous()
+    assert ptr.dim() == 1 and ptr.is_cuda and ptr.is_contiguous()
+    assert out.dim() == 2 and out.is_cuda and out.is_contiguous()
+
+    (M, N), num_segments = inputs.size(), ptr.numel() - 1
+    print('M', M, 'N', N, 'num_segments', num_segments)
 
-    grid = lambda meta: (triton.cdiv(inputs.numel(), meta['BLOCK_SIZE']), )
-    softmax_kernel[grid](inputs, ptr, out, inputs.numel(), BLOCK_SIZE=1024)
+    grid = lambda meta: (
+        triton.cdiv(num_segments, meta['SEGMENT_BLOCK_SIZE']),
+        triton.cdiv(N, meta['BLOCK_SIZE_N']),
+    )
+    softmax_kernel[grid](inputs, ptr, out, M, N, num_segments,
+                         SEGMENT_BLOCK_SIZE=2, BLOCK_SIZE_N=10)
     return out
diff --git a/pyg_lib/testing.py b/pyg_lib/testing.py
index b95ae6b78..66e673079 100644
--- a/pyg_lib/testing.py
+++ b/pyg_lib/testing.py
@@ -27,7 +27,7 @@ def onlyCUDA(func: Callable) -> Callable:
     )(func)
 
 
-def withTriton(func: Callable) -> Callable:
+def onlyTriton(func: Callable) -> Callable:
     import pytest
 
     return pytest.mark.skipif(
diff --git a/setup.cfg b/setup.cfg
index ac8f38070..11d518c9b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,7 +19,7 @@ test=pytest
 addopts=--capture=no
 
 [flake8]
-ignore=E731
+ignore=E731,W504
 
 [isort]
 multi_line_output=3
diff --git a/test/ops/test_softmax.py b/test/ops/test_softmax.py
index 14ea448c7..7f3a53761 100644
--- a/test/ops/test_softmax.py
+++ b/test/ops/test_softmax.py
@@ -1,18 +1,35 @@
 import torch
 
 from pyg_lib.ops import softmax
-from pyg_lib.testing import onlyCUDA
+from pyg_lib.testing import onlyCUDA, onlyTriton
 
 
 @onlyCUDA
+@onlyTriton
 def test_softmax():
-    inputs = torch.randn(8, 16, device='cuda')
+    inputs = torch.randn(8, 5, device='cuda')
     ptr = torch.tensor([0, 3, 8], device='cuda')
 
+    print()
+    a = inputs[0:3]
+    a = (a - a.max(dim=0, keepdim=True)[0]).exp()
+    a = a / a.sum(dim=0, keepdim=True)
+    print(a)
+    b = inputs[3:8]
+    b = (b - b.max(dim=0, keepdim=True)[0]).exp()
+    b = b / b.sum(dim=0, keepdim=True)
+    print(b)
+
     out = softmax(inputs, ptr)
-    print(out[0:3])
+    print()
+    # print((inputs[0:3] - inputs[0:3].max(dim=0, keepdim=True)[0]).exp())
+    # print((inputs[3:8] - inputs[3:8].max(dim=0, keepdim=True)[0]).exp())
+    print()
+    print(out)
+    print()
+    # print(out[0:3])
     print(torch.softmax(inputs[0:3], dim=0))
-    assert torch.allclose(out[0:3], torch.softmax(inputs[0:3], dim=0))
-    print(out[3:8])
+    # assert torch.allclose(out[0:3], torch.softmax(inputs[0:3], dim=0))
+    # print(out[3:8])
     print(torch.softmax(inputs[3:8], dim=0))
-    assert torch.allclose(out[3:8], torch.softmax(inputs[3:8], dim=0))
+    # assert torch.allclose(out[3:8], torch.softmax(inputs[3:8], dim=0))
diff --git a/test/test_triton.py b/test/test_triton.py
index 29397fe02..1615ba330 100644
--- a/test/test_triton.py
+++ b/test/test_triton.py
@@ -2,7 +2,7 @@
 from torch import Tensor
 
 from pyg_lib._triton import tl, triton
-from pyg_lib.testing import onlyCUDA, withTriton
+from pyg_lib.testing import onlyCUDA, onlyTriton
 
 
 @triton.jit
@@ -30,7 +30,7 @@ def add(x: Tensor, y: Tensor) -> Tensor:
 
 
 @onlyCUDA
-@withTriton
+@onlyTriton
 def test_triton():
     x = torch.rand(100, device='cuda')
     y = torch.rand(100, device='cuda')
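The kernel above views the contiguous (M, N) input as a three-dimensional block of shape (segments, rows, columns). One way to sanity-check the offset arithmetic on the host is the following sketch (sizes made up here, mirroring the broadcasting in softmax_kernel):

    import torch

    N = 5                        # number of columns, the kernel's `N`
    ptr1 = torch.tensor([0, 3])  # first row of each segment in the block
    M_offset = torch.arange(10)  # row offset within a segment (`max_count`)
    N_offset = torch.arange(N)   # column offset

    # (segments, 1, 1) + (1, rows, 1) + (1, 1, cols) -> (segments, rows, cols)
    x_offset = (N * ptr1[:, None, None] + N * M_offset[None, :, None] +
                N_offset[None, None, :])
    print(x_offset.shape)  # torch.Size([2, 10, 5])
    # x_offset[s, m, n] is the flat index of element (ptr1[s] + m, n); rows
    # beyond a segment's `count` are disabled via `x_mask` in the kernel.

The reductions over axis=1 (tl.max, tl.sum) then act across the rows of each segment while keeping the segment and column axes intact, which is what produces the per-segment softmax in a single pass.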
From 14aed9fc1860c34822839f7714175743f1eeeced Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 24 Oct 2022 12:52:18 +0000
Subject: [PATCH 4/5] changelog

---
 CHANGELOG.md             |  1 +
 pyg_lib/ops/softmax.py   |  6 +++---
 test/ops/test_softmax.py | 37 ++++++++++++++++++++++---------------
 3 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7a59a406c..68a8432eb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [Unreleased]
 ### Added
+- Added fused softmax kernel implementation ([#135](https://github.com/pyg-team/pyg-lib/pull/135))
 - Added `triton` dependency ([#133](https://github.com/pyg-team/pyg-lib/pull/133), [#134](https://github.com/pyg-team/pyg-lib/pull/134))
 - Enable `pytest` testing ([#132](https://github.com/pyg-team/pyg-lib/pull/132))
 - Added C++-based autograd and TorchScript support for `segment_matmul` ([#120](https://github.com/pyg-team/pyg-lib/pull/120), [#122](https://github.com/pyg-team/pyg-lib/pull/122))
diff --git a/pyg_lib/ops/softmax.py b/pyg_lib/ops/softmax.py
index e6627cd9f..050198de3 100644
--- a/pyg_lib/ops/softmax.py
+++ b/pyg_lib/ops/softmax.py
@@ -12,8 +12,8 @@ def softmax_kernel(x_ptr, ptr_ptr, out_ptr, M, N, num_segments, **meta):
     ptr_offset = ptr_block_start + tl.arange(0, meta['SEGMENT_BLOCK_SIZE'])
     ptr_mask = ptr_offset < num_segments
 
-    ptr1 = tl.load(ptr_ptr + ptr_offset, mask=ptr_mask)
-    ptr2 = tl.load(ptr_ptr + ptr_offset + 1, mask=ptr_mask)
+    ptr1 = tl.load(ptr_ptr + ptr_offset, mask=ptr_mask, other=1000000)
+    ptr2 = tl.load(ptr_ptr + ptr_offset + 1, mask=ptr_mask, other=1000000)
     count = ptr2 - ptr1
     # max_count = tl.max(ptr2 - ptr1, axis=0)
     # max_count = tl.multiple_of(max_count, 8)
@@ -62,5 +62,5 @@ def softmax(
         triton.cdiv(N, meta['BLOCK_SIZE_N']),
     )
     softmax_kernel[grid](inputs, ptr, out, M, N, num_segments,
-                         SEGMENT_BLOCK_SIZE=2, BLOCK_SIZE_N=10)
+                         SEGMENT_BLOCK_SIZE=1, BLOCK_SIZE_N=1)
     return out
diff --git a/test/ops/test_softmax.py b/test/ops/test_softmax.py
index 7f3a53761..dcc16a510 100644
--- a/test/ops/test_softmax.py
+++ b/test/ops/test_softmax.py
@@ -7,29 +7,36 @@
 @onlyCUDA
 @onlyTriton
 def test_softmax():
-    inputs = torch.randn(8, 5, device='cuda')
-    ptr = torch.tensor([0, 3, 8], device='cuda')
+    inputs = torch.randn(16, 5, device='cuda')
+    ptr = torch.tensor([0, 3, 8, 11, 16], device='cuda')
 
-    print()
-    a = inputs[0:3]
-    a = (a - a.max(dim=0, keepdim=True)[0]).exp()
-    a = a / a.sum(dim=0, keepdim=True)
-    print(a)
-    b = inputs[3:8]
-    b = (b - b.max(dim=0, keepdim=True)[0]).exp()
-    b = b / b.sum(dim=0, keepdim=True)
-    print(b)
+    # print()
+    # a = inputs[0:3]
+    # a = (a - a.max(dim=0, keepdim=True)[0]).exp()
+    # a = a / a.sum(dim=0, keepdim=True)
+    # print(a)
+    # b = inputs[3:8]
+    # b = (b - b.max(dim=0, keepdim=True)[0]).exp()
+    # b = b / b.sum(dim=0, keepdim=True)
+    # print(b)
 
     out = softmax(inputs, ptr)
     print()
-    # print((inputs[0:3] - inputs[0:3].max(dim=0, keepdim=True)[0]).exp())
-    # print((inputs[3:8] - inputs[3:8].max(dim=0, keepdim=True)[0]).exp())
+    a = (inputs[0:3] - inputs[0:3].max(dim=0, keepdim=True)[0]).exp()
+    # print(a + 2)
+    print(a.sum(dim=0, keepdim=True))
+    a = (inputs[3:8] - inputs[3:8].max(dim=0, keepdim=True)[0]).exp()
+    print(a.sum(dim=0, keepdim=True))
+    a = (inputs[8:11] - inputs[8:11].max(dim=0, keepdim=True)[0]).exp()
+    print(a.sum(dim=0, keepdim=True))
+    a = (inputs[11:16] - inputs[11:16].max(dim=0, keepdim=True)[0]).exp()
+    print(a.sum(dim=0, keepdim=True))
     print()
     print(out)
     print()
     # print(out[0:3])
     print(torch.softmax(inputs[0:3], dim=0))
-    # assert torch.allclose(out[0:3], torch.softmax(inputs[0:3], dim=0))
-    # print(out[3:8])
     print(torch.softmax(inputs[3:8], dim=0))
+    print(torch.softmax(inputs[8:11], dim=0))
+    print(torch.softmax(inputs[11:16], dim=0))
     # assert torch.allclose(out[3:8], torch.softmax(inputs[3:8], dim=0))
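Subtracting the row-wise maximum before exponentiation, as the kernel does with x - tl.max(x, axis=1), is the standard trick for a numerically stable softmax, and it is what the exp-sum prints in the test above are probing. A small worked example in plain PyTorch:

    import torch

    x = torch.tensor([1000.0, 1001.0, 1002.0])
    print(torch.exp(x))  # tensor([inf, inf, inf]): exp overflows float32

    x = x - x.max()      # shift to [-2., -1., 0.]; softmax is shift-invariant
    print(torch.exp(x) / torch.exp(x).sum())
    # tensor([0.0900, 0.2447, 0.6652]), the same as torch.softmax(x, dim=0)

In the same spirit, other=float('-inf') on the masked input load makes padded rows contribute exp(-inf) = 0 to each segment's sum, and the other=1000000 sentinel on the ptr loads keeps masked-off segments from producing a positive count.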
From 2f0d616c6c562093131e446664e875955ec62cd3 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 24 Oct 2022 12:53:37 +0000
Subject: [PATCH 5/5] update

---
 pyg_lib/ops/softmax.py   |  5 +----
 test/ops/test_softmax.py | 35 ++++-------------------------------
 2 files changed, 5 insertions(+), 35 deletions(-)

diff --git a/pyg_lib/ops/softmax.py b/pyg_lib/ops/softmax.py
index 050198de3..bf111a91d 100644
--- a/pyg_lib/ops/softmax.py
+++ b/pyg_lib/ops/softmax.py
@@ -23,8 +23,6 @@ def softmax_kernel(x_ptr, ptr_ptr, out_ptr, M, N, num_segments, **meta):
     N_block_start = tl.program_id(axis=1) * meta['BLOCK_SIZE_N']
     N_offset = N_block_start + tl.arange(0, meta['BLOCK_SIZE_N'])
 
-    # M_mask = M_offset[None, :] < count[:, None]
-
     x_offset = (N * ptr1[:, None, None] + N * M_offset[None, :, None] +
                 N_offset[None, None, :])
     x_mask = ((ptr1[:, None, None] < M) &
@@ -55,12 +53,11 @@ def softmax(
     assert out.dim() == 2 and out.is_cuda and out.is_contiguous()
 
     (M, N), num_segments = inputs.size(), ptr.numel() - 1
-    print('M', M, 'N', N, 'num_segments', num_segments)
 
     grid = lambda meta: (
         triton.cdiv(num_segments, meta['SEGMENT_BLOCK_SIZE']),
         triton.cdiv(N, meta['BLOCK_SIZE_N']),
     )
     softmax_kernel[grid](inputs, ptr, out, M, N, num_segments,
-                         SEGMENT_BLOCK_SIZE=1, BLOCK_SIZE_N=1)
+                         SEGMENT_BLOCK_SIZE=8, BLOCK_SIZE_N=128)
     return out
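The block sizes SEGMENT_BLOCK_SIZE=8 and BLOCK_SIZE_N=128 are fixed here rather than derived. If one wanted to choose them empirically, a rough timing sketch could look as follows (assuming Triton's triton.testing.do_bench helper, whose exact signature and return value vary across Triton versions; the sizes below are made up):

    import torch
    import triton

    from pyg_lib.ops import softmax

    inputs = torch.randn(100_000, 128, device='cuda')
    ptr = torch.arange(0, 100_001, 10, device='cuda')  # 10 rows per segment

    # Median runtime in milliseconds over repeated, warmed-up runs.
    ms = triton.testing.do_bench(lambda: softmax(inputs, ptr))
    print(f'segment softmax: {ms} ms')

Note that the hardcoded max_count = 10 in the kernel is still marked TODO, so segments longer than ten rows are not yet handled.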
diff --git a/test/ops/test_softmax.py b/test/ops/test_softmax.py
index dcc16a510..854e3d8c6 100644
--- a/test/ops/test_softmax.py
+++ b/test/ops/test_softmax.py
@@ -7,36 +7,9 @@
 @onlyCUDA
 @onlyTriton
 def test_softmax():
-    inputs = torch.randn(16, 5, device='cuda')
-    ptr = torch.tensor([0, 3, 8, 11, 16], device='cuda')
-
-    # print()
-    # a = inputs[0:3]
-    # a = (a - a.max(dim=0, keepdim=True)[0]).exp()
-    # a = a / a.sum(dim=0, keepdim=True)
-    # print(a)
-    # b = inputs[3:8]
-    # b = (b - b.max(dim=0, keepdim=True)[0]).exp()
-    # b = b / b.sum(dim=0, keepdim=True)
-    # print(b)
+    inputs = torch.randn(8, 5, device='cuda')
+    ptr = torch.tensor([0, 3, 8], device='cuda')
 
     out = softmax(inputs, ptr)
-    print()
-    a = (inputs[0:3] - inputs[0:3].max(dim=0, keepdim=True)[0]).exp()
-    # print(a + 2)
-    print(a.sum(dim=0, keepdim=True))
-    a = (inputs[3:8] - inputs[3:8].max(dim=0, keepdim=True)[0]).exp()
-    print(a.sum(dim=0, keepdim=True))
-    a = (inputs[8:11] - inputs[8:11].max(dim=0, keepdim=True)[0]).exp()
-    print(a.sum(dim=0, keepdim=True))
-    a = (inputs[11:16] - inputs[11:16].max(dim=0, keepdim=True)[0]).exp()
-    print(a.sum(dim=0, keepdim=True))
-    print()
-    print(out)
-    print()
-    # print(out[0:3])
-    print(torch.softmax(inputs[0:3], dim=0))
-    print(torch.softmax(inputs[3:8], dim=0))
-    print(torch.softmax(inputs[8:11], dim=0))
-    print(torch.softmax(inputs[11:16], dim=0))
-    # assert torch.allclose(out[3:8], torch.softmax(inputs[3:8], dim=0))
+    assert torch.allclose(out[0:3], torch.softmax(inputs[0:3], dim=0))
+    assert torch.allclose(out[3:8], torch.softmax(inputs[3:8], dim=0))
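For reference, the intended end-user call mirrors the final test: segment boundaries arrive as a monotonically increasing ptr vector on the same CUDA device, and each segment matches a plain softmax over its own rows:

    import torch

    from pyg_lib.ops import softmax

    inputs = torch.randn(8, 5, device='cuda')
    ptr = torch.tensor([0, 3, 8], device='cuda')  # segments: rows 0-2, 3-7

    out = softmax(inputs, ptr)
    # Per segment, this is equivalent to
    #   torch.softmax(inputs[0:3], dim=0) and torch.softmax(inputs[3:8], dim=0)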