diff --git a/setup.py b/setup.py index b37ec2a7..635c3a4b 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ ['cuda/scatter.cpp', 'cuda/scatter_kernel.cu']) ] -__version__ = '1.3.0' +__version__ = '1.3.1' url = 'https://github.com/rusty1s/pytorch_scatter' install_requires = [] diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index 562261ab..64cc7bfe 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -7,7 +7,7 @@ from .max import scatter_max from .min import scatter_min -__version__ = '1.3.0' +__version__ = '1.3.1' __all__ = [ 'scatter_add', diff --git a/torch_scatter/max.py b/torch_scatter/max.py index a894cf68..bf51f514 100644 --- a/torch_scatter/max.py +++ b/torch_scatter/max.py @@ -24,8 +24,11 @@ def backward(ctx, grad_out, grad_arg): grad_src = None if ctx.needs_input_grad[1]: - grad_src = grad_out.new_zeros(index.size()) - grad_src.scatter_(ctx.dim, arg.detach(), grad_out) + size = list(index.size()) + size[ctx.dim] += 1 + grad_src = grad_out.new_zeros(size) + grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out) + grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim)) return None, grad_src, None, None diff --git a/torch_scatter/min.py b/torch_scatter/min.py index 5c48edf8..227063e5 100644 --- a/torch_scatter/min.py +++ b/torch_scatter/min.py @@ -24,8 +24,11 @@ def backward(ctx, grad_out, grad_arg): grad_src = None if ctx.needs_input_grad[1]: - grad_src = grad_out.new_zeros(index.size()) - grad_src.scatter_(ctx.dim, arg.detach(), grad_out) + size = list(index.size()) + size[ctx.dim] += 1 + grad_src = grad_out.new_zeros(size) + grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out) + grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim)) return None, grad_src, None, None