Skip to content

Commit

Permalink
fix backward for max min
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jul 5, 2019
1 parent e6821e3 commit 0a9f541
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
]

__version__ = '1.3.0'
__version__ = '1.3.1'
url = 'https://github.com/rusty1s/pytorch_scatter'

install_requires = []
Expand Down
2 changes: 1 addition & 1 deletion torch_scatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .max import scatter_max
from .min import scatter_min

__version__ = '1.3.0'
__version__ = '1.3.1'

__all__ = [
'scatter_add',
Expand Down
7 changes: 5 additions & 2 deletions torch_scatter/max.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ def backward(ctx, grad_out, grad_arg):

grad_src = None
if ctx.needs_input_grad[1]:
grad_src = grad_out.new_zeros(index.size())
grad_src.scatter_(ctx.dim, arg.detach(), grad_out)
size = list(index.size())
size[ctx.dim] += 1
grad_src = grad_out.new_zeros(size)
grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out)
grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim))

return None, grad_src, None, None

Expand Down
7 changes: 5 additions & 2 deletions torch_scatter/min.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ def backward(ctx, grad_out, grad_arg):

grad_src = None
if ctx.needs_input_grad[1]:
grad_src = grad_out.new_zeros(index.size())
grad_src.scatter_(ctx.dim, arg.detach(), grad_out)
size = list(index.size())
size[ctx.dim] += 1
grad_src = grad_out.new_zeros(size)
grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out)
grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim))

return None, grad_src, None, None

Expand Down

0 comments on commit 0a9f541

Please sign in to comment.