Skip to content

Commit

Permalink
Added Open Cl Support
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahmad Rashid committed Aug 22, 2016
1 parent 88cbe99 commit b6aa780
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions SeqReverseSequence.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end

function SeqReverseSequence:reverseOutput(input)
self.output:resizeAs(input)
self.outputIndices = self.outputIndices or ((torch.type(input) == 'torch.CudaTensor') and torch.CudaTensor() or torch.LongTensor())
self.outputIndices = self.outputIndices or ((torch.type(input) == 'torch.CudaTensor') and torch.CudaTensor() or (torch.type(input) == 'torch.ClTensor') and torch.ClTensor() or torch.LongTensor())
self.outputIndices:resize(input:size())
local T = input:size(1)
for x = 1, T do
Expand Down Expand Up @@ -45,7 +45,7 @@ end

function SeqReverseSequence:reverseGradOutput(gradOutput)
self.gradInput:resizeAs(gradOutput)
self.gradIndices = self.gradIndices or ((torch.type(gradOutput) == 'torch.CudaTensor') and torch.CudaTensor() or torch.LongTensor())
self.gradIndices = self.gradIndices or ((torch.type(gradOutput) == 'torch.CudaTensor') and torch.CudaTensor() or (torch.type(gradOutput) == 'torch.ClTensor') and torch.ClTensor() or torch.LongTensor())
self.gradIndices:resize(gradOutput:size())
local T = gradOutput:size(1)
for x = 1, T do
Expand Down

0 comments on commit b6aa780

Please sign in to comment.