From 873b9ec70e28572266bec0ac16e334f18274976d Mon Sep 17 00:00:00 2001
From: Jin-Hwa Kim
Date: Wed, 23 Mar 2016 16:29:48 +0900
Subject: [PATCH] fix #170 using `mono` option to Lazy Dropout

---
 AbstractRecurrent.lua |  3 ++
 Dropout.lua           | 10 ++++-
 GRU.lua               | 15 +++----
 test/test.lua         | 98 +++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 117 insertions(+), 9 deletions(-)

diff --git a/AbstractRecurrent.lua b/AbstractRecurrent.lua
index fc3690a..ba68f67 100644
--- a/AbstractRecurrent.lua
+++ b/AbstractRecurrent.lua
@@ -39,6 +39,9 @@ function AbstractRecurrent:maskZero(nInputDim)
 end
 
 function AbstractRecurrent:trimZero(nInputDim)
+   if torch.typename(self)=='nn.GRU' and self.p ~= 0 then
+      assert(self.mono, "TrimZero for BGRU needs `mono` option.")
+   end
    self.recurrentModule = nn.TrimZero(self.recurrentModule, nInputDim, true)
    self.sharedClones = {self.recurrentModule}
    return self
diff --git a/Dropout.lua b/Dropout.lua
index ee6a65a..ede1f37 100644
--- a/Dropout.lua
+++ b/Dropout.lua
@@ -9,12 +9,13 @@
 ------------------------------------------------------------------------
 local Dropout, Parent = nn.Dropout, nn.Module
 
-function Dropout:__init(p,v1,inplace,lazy)
+function Dropout:__init(p,v1,inplace,lazy,mono)
    Parent.__init(self)
    self.p = p or 0.5
    self.train = true
    self.inplace = inplace
    self.lazy = lazy or false
+   self.mono = mono or false -- used by trimZero, single sample for a batch
    self.flag = true -- used by lazy noise
    -- version 2 scales output during training instead of evaluation
    self.v2 = not v1
@@ -33,13 +34,18 @@ function Dropout:updateOutput(input)
    if self.p > 0 then
       if self.train then
          if not self.lazy or self.flag then
-            self.noise:resizeAs(input)
+            local noiseSize = input:size()
+            if self.mono then noiseSize[1] = 1 end
+            self.noise:resize(noiseSize)
             self.noise:bernoulli(1-self.p)
             if self.v2 then
                self.noise:div(1-self.p)
             end
             self.flag = false
          end
+         if self.mono and self.noise:size(1) ~= input:size(1) then
+            self.noise = self.noise:expandAs(input)
+         end
          self.output:cmul(self.noise)
       elseif not self.v2 then
          self.output:mul(1-self.p)
diff --git a/GRU.lua b/GRU.lua
index 61735ee..e22c3d3 100644
--- a/GRU.lua
+++ b/GRU.lua
@@ -16,12 +16,13 @@ assert(not nn.GRU, "update nnx package : luarocks install nnx")
 
 local GRU, parent = torch.class('nn.GRU', 'nn.AbstractRecurrent')
 
-function GRU:__init(inputSize, outputSize, rho, p)
+function GRU:__init(inputSize, outputSize, rho, p, mono)
    parent.__init(self, rho or 9999)
    self.p = p or 0
    if p and p ~= 0 then
      assert(nn.Dropout(p,false,false,true).lazy, 'only work with Lazy Dropout!')
   end
+   self.mono = mono or false -- used by trimZero
   self.inputSize = inputSize
   self.outputSize = outputSize
   -- build the model
@@ -46,16 +47,16 @@ function GRU:buildModel()
   if self.p ~= 0 then
      self.i2g = nn.Sequential()
                  :add(nn.ConcatTable()
-                    :add(nn.Dropout(self.p,false,false,true))
-                    :add(nn.Dropout(self.p,false,false,true)))
+                    :add(nn.Dropout(self.p,false,false,true,self.mono))
+                    :add(nn.Dropout(self.p,false,false,true,self.mono)))
                  :add(nn.ParallelTable()
                     :add(nn.Linear(self.inputSize, self.outputSize))
                     :add(nn.Linear(self.inputSize, self.outputSize)))
                  :add(nn.JoinTable(2))
      self.o2g = nn.Sequential()
                  :add(nn.ConcatTable()
-                    :add(nn.Dropout(self.p,false,false,true))
-                    :add(nn.Dropout(self.p,false,false,true)))
+                    :add(nn.Dropout(self.p,false,false,true,self.mono))
+                    :add(nn.Dropout(self.p,false,false,true,self.mono)))
                  :add(nn.ParallelTable()
                     :add(nn.LinearNoBias(self.outputSize, self.outputSize))
                     :add(nn.LinearNoBias(self.outputSize, self.outputSize)))
@@ -96,8 +97,8 @@ function GRU:buildModel()
   local t2 = nn.Sequential()
   t2:add(nn.NarrowTable(2,2)):add(nn.CMulTable())
   if self.p ~= 0 then
-     t1:add(nn.Dropout(self.p,false,false,true))
-     t2:add(nn.Dropout(self.p,false,false,true))
+     t1:add(nn.Dropout(self.p,false,false,true,self.mono))
+     t2:add(nn.Dropout(self.p,false,false,true,self.mono))
   end
   t1:add(nn.Linear(self.inputSize, self.outputSize))
   t2:add(nn.LinearNoBias(self.outputSize, self.outputSize))
diff --git a/test/test.lua b/test/test.lua
index a21e68e..95afd3f 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -3826,6 +3826,56 @@ function rnntest.TrimZero()
          end
       end
    end
+
+   -- check to have the same loss
+   rnn_size = 8
+   vocabSize = 7
+   word_embedding_size = 10
+
+   x = torch.Tensor{{{1,2,3},{0,4,5},{0,0,7}},
+                    {{1,2,3},{2,4,5},{0,0,7}},
+                    {{1,2,3},{2,4,5},{3,0,7}}}
+   t = torch.ceil(torch.rand(x:size(2)))
+
+   rnns = {'FastLSTM','GRU'}
+   methods = {'maskZero', 'trimZero'}
+   loss = torch.Tensor(#rnns, #methods, 3)
+
+   for ir,arch in pairs(rnns) do
+      local rnn = nn[arch](word_embedding_size, rnn_size)
+      local model = nn.Sequential()
+                     :add(nn.LookupTableMaskZero(vocabSize, word_embedding_size))
+                     :add(nn.SplitTable(2))
+                     :add(nn.Sequencer(rnn))
+                     :add(nn.SelectTable(-1))
+                     :add(nn.Linear(rnn_size, 10))
+      model:getParameters():uniform(-0.1, 0.1)
+      criterion = nn.CrossEntropyCriterion()
+      local models = {}
+      for j=1,#methods do
+         table.insert(models, model:clone())
+      end
+      for im,method in pairs(methods) do
+         -- print('-- '..arch..' with '..method)
+         model = models[im]
+         rnn = model:get(3).module
+         rnn[method](rnn, 1)
+         sys.tic()
+         for i=1,loss:size(3) do
+            model:zeroGradParameters()
+            y = model:forward(x[i])
+            loss[ir][im][i] = criterion:forward(y,t)
+            -- print('loss:', loss[ir][im][i])
+            dy = criterion:backward(y,t)
+            model:backward(x[i], dy)
+            w,dw = model:parameters()
+            model:updateParameters(.5)
+         end
+         elapse = sys.toc()
+         -- print('elapse time:', elapse)
+      end
+   end
+   mytester:assertTensorEq(loss:select(2,1), loss:select(2,2), 0.0000001, "loss check")
 end
 
 function rnntest.AbstractRecurrent_maskZero()
@@ -4192,6 +4242,54 @@ function rnntest.issue129()
    mytester:assertTensorEq(output, output2, 0.0002, "issue 129 err")
 end
 
+function rnntest.issue170()
+   torch.manualSeed(123)
+
+   rnn_size = 8
+   vocabSize = 7
+   word_embedding_size = 10
+   rnn_dropout = .00000001 -- dropout ignores manualSeed()
+   mono = true
+
+   x = torch.Tensor{{1,2,3},{0,4,5},{0,0,7}}
+   t = torch.ceil(torch.rand(x:size(2)))
+
+   rnns = {'GRU'}
+   methods = {'maskZero', 'trimZero'}
+   loss = torch.Tensor(#rnns, #methods,1)
+
+   for ir,arch in pairs(rnns) do
+      local rnn = nn[arch](word_embedding_size, rnn_size, nil, rnn_dropout, mono)
+      local model = nn.Sequential()
+                     :add(nn.LookupTableMaskZero(vocabSize, word_embedding_size))
+                     :add(nn.SplitTable(2))
+                     :add(nn.Sequencer(rnn))
+                     :add(nn.SelectTable(-1))
+                     :add(nn.Linear(rnn_size, 10))
+      model:getParameters():uniform(-0.1, 0.1)
+      criterion = nn.CrossEntropyCriterion()
+      local models = {}
+      for j=1,#methods do
+         table.insert(models, model:clone())
+      end
+      for im,method in pairs(methods) do
+         model = models[im]
+         rnn = model:get(3).module
+         rnn[method](rnn, 1)
+         for i=1,loss:size(3) do
+            model:zeroGradParameters()
+            y = model:forward(x)
+            loss[ir][im][i] = criterion:forward(y,t)
+            dy = criterion:backward(y,t)
+            model:backward(x, dy)
+            w,dw = model:parameters()
+            model:updateParameters(.5)
+         end
+      end
+   end
+   mytester:assertTensorEq(loss:select(2,1), loss:select(2,2), 0.0000001, "loss check")
+end
+
 function rnntest.encoderdecoder()
    torch.manualSeed(123)