Skip to content

Commit

Permalink
Mask changed to cuda byte tensor for cutorch api
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanNaren committed Aug 9, 2016
1 parent ce4f0e7 commit bbec157
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion MaskZero.lua
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ function MaskZero:updateOutput(input)
local vectorDim = rmi:dim()
self._zeroMask = self._zeroMask or rmi.new()
self._zeroMask:norm(rmi, 2, vectorDim)
self.zeroMask = self.zeroMask or ((torch.type(rmi) == 'torch.CudaTensor') and torch.CudaTensor() or torch.ByteTensor())
self.zeroMask = self.zeroMask or ((torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor())
self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)

-- forward through decorated module
Expand Down

0 comments on commit bbec157

Please sign in to comment.