diff --git a/SeqReverseSequence.lua b/SeqReverseSequence.lua index 2588ff7..842be61 100644 --- a/SeqReverseSequence.lua +++ b/SeqReverseSequence.lua @@ -17,7 +17,7 @@ end function SeqReverseSequence:reverseOutput(input) self.output:resizeAs(input) - self.outputIndices = self.outputIndices or ((torch.type(input) == 'torch.CudaTensor') and torch.CudaTensor() or torch.LongTensor()) + self.outputIndices = self.outputIndices or ((torch.type(input) == 'torch.CudaTensor') and torch.CudaTensor() or (torch.type(input) == 'torch.ClTensor') and torch.ClTensor() or torch.LongTensor()) self.outputIndices:resize(input:size()) local T = input:size(1) for x = 1, T do @@ -45,7 +45,7 @@ end function SeqReverseSequence:reverseGradOutput(gradOutput) self.gradInput:resizeAs(gradOutput) - self.gradIndices = self.gradIndices or ((torch.type(gradOutput) == 'torch.CudaTensor') and torch.CudaTensor() or torch.LongTensor()) + self.gradIndices = self.gradIndices or ((torch.type(gradOutput) == 'torch.CudaTensor') and torch.CudaTensor() or (torch.type(gradOutput) == 'torch.ClTensor') and torch.ClTensor() or torch.LongTensor()) self.gradIndices:resize(gradOutput:size()) local T = gradOutput:size(1) for x = 1, T do