fix Element-Research#170 using mono option to Lazy Dropout
jnhwkim committed Mar 23, 2016
1 parent 164f237 commit 873b9ec
Showing 4 changed files with 117 additions and 9 deletions.
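
The commit threads a new mono flag from nn.GRU down into the lazy nn.Dropout layers it builds, so that a GRU constructed with dropout can be wrapped by trimZero. With mono set, a single noise mask is sampled lazily and shared across all samples in a batch. A minimal usage sketch, modeled on the issue170 test added below (the layer sizes are illustrative, not prescribed by the commit):

require 'rnn'

local p = 0.25                               -- dropout probability inside the GRU
local gru = nn.GRU(10, 8, nil, p, true)      -- last argument is the new mono flag
gru:trimZero(1)                              -- asserts unless mono is set when p ~= 0

local model = nn.Sequential()
   :add(nn.LookupTableMaskZero(7, 10))       -- zero-padded token ids -> embeddings
   :add(nn.SplitTable(2))
   :add(nn.Sequencer(gru))
   :add(nn.SelectTable(-1))
   :add(nn.Linear(8, 10))
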
AbstractRecurrent.lua: 3 additions & 0 deletions
@@ -39,6 +39,9 @@ function AbstractRecurrent:maskZero(nInputDim)
end

function AbstractRecurrent:trimZero(nInputDim)
+if torch.typename(self)=='nn.GRU' and self.p ~= 0 then
+   assert(self.mono, "TrimZero for BGRU needs `mono` option.")
+end
self.recurrentModule = nn.TrimZero(self.recurrentModule, nInputDim, true)
self.sharedClones = {self.recurrentModule}
return self
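
The new guard only fires for a GRU whose dropout probability is non-zero. TrimZero feeds the wrapped module a batch that can shrink from step to step, which is presumably why only a batch-shared mask (expanded on demand, see the Dropout.lua change below) stays compatible, while a per-sample lazy mask does not. A hedged sketch of the two constructions (sizes illustrative):

-- errors with "TrimZero for BGRU needs `mono` option.":
-- nn.GRU(10, 8, nil, 0.25):trimZero(1)

-- accepted: the internal Dropouts are built with mono = true
nn.GRU(10, 8, nil, 0.25, true):trimZero(1)
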
Dropout.lua: 8 additions & 2 deletions
@@ -9,12 +9,13 @@
------------------------------------------------------------------------
local Dropout, Parent = nn.Dropout, nn.Module

-function Dropout:__init(p,v1,inplace,lazy)
+function Dropout:__init(p,v1,inplace,lazy,mono)
Parent.__init(self)
self.p = p or 0.5
self.train = true
self.inplace = inplace
self.lazy = lazy or false
+self.mono = mono or false -- used by trimZero, single sample for a batch
self.flag = true -- used by lazy noise
-- version 2 scales output during training instead of evaluation
self.v2 = not v1
@@ -33,13 +34,18 @@ function Dropout:updateOutput(input)
if self.p > 0 then
if self.train then
if not self.lazy or self.flag then
-self.noise:resizeAs(input)
+local noiseSize = input:size()
+if self.mono then noiseSize[1] = 1 end
+self.noise:resize(noiseSize)
self.noise:bernoulli(1-self.p)
if self.v2 then
self.noise:div(1-self.p)
end
self.flag = false
end
+if self.mono and self.noise:size(1) ~= input:size(1) then
+   self.noise = self.noise:expandAs(input)
+end
self.output:cmul(self.noise)
elseif not self.v2 then
self.output:mul(1-self.p)
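
In mono mode the lazy noise tensor is allocated with a batch dimension of 1, sampled once, and then expanded rather than copied to the incoming batch, so every sample is multiplied by the same mask row. A small sketch of the expected behaviour, assuming this patched Dropout is the one loaded by require 'rnn':

require 'rnn'

local d = nn.Dropout(0.5, false, false, true, true)   -- p, v1, inplace, lazy, mono
d:training()
local out = d:forward(torch.ones(4, 6))
-- all four rows saw the same mask row, expanded across the batch
print((out[1] - out[2]):abs():max())                   -- 0

Because the expansion is a stride-0 view rather than a copy, the mask occupies a single row of memory regardless of the batch size it is applied to.
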
GRU.lua: 8 additions & 7 deletions
@@ -16,12 +16,13 @@
assert(not nn.GRU, "update nnx package : luarocks install nnx")
local GRU, parent = torch.class('nn.GRU', 'nn.AbstractRecurrent')

-function GRU:__init(inputSize, outputSize, rho, p)
+function GRU:__init(inputSize, outputSize, rho, p, mono)
parent.__init(self, rho or 9999)
self.p = p or 0
if p and p ~= 0 then
assert(nn.Dropout(p,false,false,true).lazy, 'only work with Lazy Dropout!')
end
+self.mono = mono or false -- used by trimZero
self.inputSize = inputSize
self.outputSize = outputSize
-- build the model
@@ -46,16 +47,16 @@ function GRU:buildModel()
if self.p ~= 0 then
self.i2g = nn.Sequential()
:add(nn.ConcatTable()
-:add(nn.Dropout(self.p,false,false,true))
-:add(nn.Dropout(self.p,false,false,true)))
+:add(nn.Dropout(self.p,false,false,true,self.mono))
+:add(nn.Dropout(self.p,false,false,true,self.mono)))
:add(nn.ParallelTable()
:add(nn.Linear(self.inputSize, self.outputSize))
:add(nn.Linear(self.inputSize, self.outputSize)))
:add(nn.JoinTable(2))
self.o2g = nn.Sequential()
:add(nn.ConcatTable()
-:add(nn.Dropout(self.p,false,false,true))
-:add(nn.Dropout(self.p,false,false,true)))
+:add(nn.Dropout(self.p,false,false,true,self.mono))
+:add(nn.Dropout(self.p,false,false,true,self.mono)))
:add(nn.ParallelTable()
:add(nn.LinearNoBias(self.outputSize, self.outputSize))
:add(nn.LinearNoBias(self.outputSize, self.outputSize)))
@@ -96,8 +97,8 @@ function GRU:buildModel()
local t2 = nn.Sequential()
t2:add(nn.NarrowTable(2,2)):add(nn.CMulTable())
if self.p ~= 0 then
-t1:add(nn.Dropout(self.p,false,false,true))
-t2:add(nn.Dropout(self.p,false,false,true))
+t1:add(nn.Dropout(self.p,false,false,true,self.mono))
+t2:add(nn.Dropout(self.p,false,false,true,self.mono))
end
t1:add(nn.Linear(self.inputSize, self.outputSize))
t2:add(nn.LinearNoBias(self.outputSize, self.outputSize))
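
GRU only stores the flag and forwards it to each of the six lazy Dropouts it builds when p ~= 0: two in i2g, two in o2g, and one in each of the candidate branches t1 and t2. A quick sketch of that invariant (sizes illustrative, and relying on recurrentModule being the step module built by buildModel):

local gru = nn.GRU(10, 8, nil, 0.25, true)
local n = 0
gru.recurrentModule:apply(function(m)
   if torch.typename(m) == 'nn.Dropout' then
      assert(m.lazy and m.mono)   -- built as nn.Dropout(p, false, false, true, true)
      n = n + 1
   end
end)
print(n)                           -- expected: 6
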
test/test.lua: 98 additions & 0 deletions
@@ -3826,6 +3826,56 @@ function rnntest.TrimZero()
end
end
end

-- check to have the same loss
rnn_size = 8
vocabSize = 7
word_embedding_size = 10

x = torch.Tensor{{{1,2,3},{0,4,5},{0,0,7}},
{{1,2,3},{2,4,5},{0,0,7}},
{{1,2,3},{2,4,5},{3,0,7}}}
t = torch.ceil(torch.rand(x:size(2)))

rnns = {'FastLSTM','GRU'}
methods = {'maskZero', 'trimZero'}
loss = torch.Tensor(#rnns, #methods, 3)

for ir,arch in pairs(rnns) do
local rnn = nn[arch](word_embedding_size, rnn_size)
local model = nn.Sequential()
:add(nn.LookupTableMaskZero(vocabSize, word_embedding_size))
:add(nn.SplitTable(2))
:add(nn.Sequencer(rnn))
:add(nn.SelectTable(-1))
:add(nn.Linear(rnn_size, 10))
model:getParameters():uniform(-0.1, 0.1)
criterion = nn.CrossEntropyCriterion()
local models = {}
for j=1,#methods do
table.insert(models, model:clone())
end
for im,method in pairs(methods) do
-- print('-- '..arch..' with '..method)
model = models[im]
rnn = model:get(3).module
rnn[method](rnn, 1)
sys.tic()
for i=1,loss:size(3) do
model:zeroGradParameters()
y = model:forward(x[i])
loss[ir][im][i] = criterion:forward(y,t)
-- print('loss:', loss[ir][im][i])
dy = criterion:backward(y,t)
model:backward(x[i], dy)
w,dw = model:parameters()
model:updateParameters(.5)
end
elapse = sys.toc()
-- print('elapse time:', elapse)
end
end
mytester:assertTensorEq(loss:select(2,1), loss:select(2,2), 0.0000001, "loss check")
end

function rnntest.AbstractRecurrent_maskZero()
@@ -4192,6 +4242,54 @@ function rnntest.issue129()
mytester:assertTensorEq(output, output2, 0.0002, "issue 129 err")
end

function rnntest.issue170()
torch.manualSeed(123)

rnn_size = 8
vocabSize = 7
word_embedding_size = 10
rnn_dropout = .00000001 -- dropout ignores manualSeed()
mono = true

x = torch.Tensor{{1,2,3},{0,4,5},{0,0,7}}
t = torch.ceil(torch.rand(x:size(2)))

rnns = {'GRU'}
methods = {'maskZero', 'trimZero'}
loss = torch.Tensor(#rnns, #methods,1)

for ir,arch in pairs(rnns) do
local rnn = nn[arch](word_embedding_size, rnn_size, nil, rnn_dropout, mono)
local model = nn.Sequential()
:add(nn.LookupTableMaskZero(vocabSize, word_embedding_size))
:add(nn.SplitTable(2))
:add(nn.Sequencer(rnn))
:add(nn.SelectTable(-1))
:add(nn.Linear(rnn_size, 10))
model:getParameters():uniform(-0.1, 0.1)
criterion = nn.CrossEntropyCriterion()
local models = {}
for j=1,#methods do
table.insert(models, model:clone())
end
for im,method in pairs(methods) do
model = models[im]
rnn = model:get(3).module
rnn[method](rnn, 1)
for i=1,loss:size(3) do
model:zeroGradParameters()
y = model:forward(x)
loss[ir][im][i] = criterion:forward(y,t)
dy = criterion:backward(y,t)
model:backward(x, dy)
w,dw = model:parameters()
model:updateParameters(.5)
end
end
end
mytester:assertTensorEq(loss:select(2,1), loss:select(2,2), 0.0000001, "loss check")
end

function rnntest.encoderdecoder()
torch.manualSeed(123)

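
Both new test blocks build the same masked-zero model twice, run one copy with maskZero and the other with trimZero, and require the recorded losses to agree to within 1e-7; the block added to rnntest.TrimZero checks plain FastLSTM and GRU, while rnntest.issue170 repeats the check for a GRU built with dropout and mono. Assuming the package's usual torch.Tester entry point rnn.test() accepts a list of test names, the regression can be run on its own with:

require 'rnn'
rnn.test{'TrimZero', 'issue170'}
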
