NCEModule++
nicholas-leonard committed May 14, 2016
1 parent 3c87bc7 commit 990cf07
Showing 5 changed files with 176 additions and 98 deletions.
8 changes: 8 additions & 0 deletions LookupTableMaskZero.lua
@@ -6,6 +6,9 @@ end

function LookupTableMaskZero:updateOutput(input)
   self.weight[1]:zero()
   if self.__input and (torch.type(self.__input) ~= torch.type(input)) then
      self.__input = nil -- fixes old casting bug
   end
   self.__input = self.__input or input.new()
   self.__input:resizeAs(input):add(input, 1)
   return parent.updateOutput(self, self.__input)
Expand All @@ -14,3 +17,8 @@ end
function LookupTableMaskZero:accGradParameters(input, gradOutput, scale)
   parent.accGradParameters(self, self.__input, gradOutput, scale)
end

function LookupTableMaskZero:type(type, cache)
   self.__input = nil
   return parent.type(self, type, cache)
end
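
For context, a minimal usage sketch of `LookupTableMaskZero` (vocabulary and embedding sizes are arbitrary): index 0 is reserved for padding and forwards as a zero vector, which is why the internal `__input` buffer (the input shifted by 1) must be discarded whenever the module is cast to another type.

require 'rnn'

-- vocabulary of 10 words embedded in 4 dimensions; index 0 means padding
local lookup = nn.LookupTableMaskZero(10, 4)

-- a batch of 2 sequences of length 3, right-padded with zeros
local input = torch.LongTensor{{1, 5, 0}, {2, 0, 0}}
local output = lookup:forward(input) -- size 2x3x4

print(output[1][3]:sum(), output[2][3]:sum()) -- padded positions are zero vectors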
9 changes: 7 additions & 2 deletions README.md
@@ -26,7 +26,7 @@ Modules that `forward` entire sequences through a decorated `AbstractRecurrent`

Miscellaneous modules and criterions :
* [MaskZero](#rnn.MaskZero) : zeroes the `output` and `gradOutput` rows of the decorated module for commensurate `input` rows which are tensors of zeros;
* [TrimZero](#rnn.TrimZero) : is more computationally efficient than `MaskZero` when input length is variable to avoid calculating zero vectors while doing forward/backward;
* [TrimZero](#rnn.TrimZero) : same behavior as `MaskZero`, but more efficient when `input` contains lots of zero-masked rows;
* [LookupTableMaskZero](#rnn.LookupTableMaskZero) : extends `nn.LookupTable` to support zero indexes for padding. Zero indexes are forwarded as tensors of zeros;
* [MaskZeroCriterion](#rnn.MaskZeroCriterion) : zeros the `gradInput` and `err` rows of the decorated criterion for commensurate `input` rows which are tensors of zeros;
* [SeqReverseSequence](#rnn.SeqReverseSequence) : reverses an input sequence on a specific dimension;
@@ -941,13 +941,18 @@ This decorator makes it possible to pad sequences with different lengths in the

<a name='rnn.TrimZero'></a>
## TrimZero ##

WARNING : only use this module if your input contains lots of zeros.
In almost all cases, [`MaskZero`](#rnn.MaskZero) will be faster, especially with CUDA.

The usage is the same as that of `MaskZero`.

```lua
mz = nn.TrimZero(module, nInputDim)
```

The only difference from `MaskZero` is that `TrimZero` reduces the computational cost by removing zero-masked rows from the batch before forwarding it through the decorated module, which pays off when the input contains sequences of varying lengths.
Note that when the lengths are consistent, `MaskZero` will be faster, because `TrimZero` incurs its own trimming overhead.

In short, the result is the same as `MaskZero`'s; however, `TrimZero` is faster than `MaskZero` only when the sentence lengths vary enough for the trimming to pay off.
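
For instance, the following minimal sketch (the wrapped module and tensor sizes are arbitrary placeholders) produces the same output with either decorator:

```lua
require 'rnn'

-- a batch of 4 samples with 10 features each; samples 2 and 4 are zero-masked
local input = torch.randn(4, 10)
input[2]:zero()
input[4]:zero()

local inner = nn.Linear(10, 5)
local mz = nn.MaskZero(inner:clone(), 1) -- nInputDim = 1
local tz = nn.TrimZero(inner:clone(), 1)

-- both decorators zero the output rows commensurate with zero input rows
local a = mz:forward(input)
local b = tz:forward(input)
print((a - b):abs():max()) -- expected to be 0 (up to floating point)
```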

190 changes: 112 additions & 78 deletions examples/noise-contrastive-estimate.lua
@@ -1,17 +1,16 @@
require 'paths'
require 'rnn'
require 'nngraph'
local dl = require 'dataload'
assert(nn.NCEModule, "please update dpnn")

version = 3
assert(nn.NCEModule and nn.NCEModule.version and nn.NCEModule.version > 3, "update dpnn : luarocks install dpnn")

--[[ command line arguments ]]--
cmd = torch.CmdLine()
cmd:text()
cmd:text('Train a Language Model using stacked LSTM on Google Billion Words dataset')
cmd:text('Example:')
cmd:text('th recurrent-language-model.lua --cuda --device 2 --progress --cutoff 4 --seqlen 10')
cmd:text("th noise-contrastive-estimate.lua --progress --earlystop 50 --cuda --device 2 --seqlen 20 --hiddensize '{200,200}' --batchsize 20 --startlr 1 --uniform 0.1 --cutoff 5 --schedule '{[5]=0.5,[6]=0.25,[7]=0.125,[8]=0.0625,[9]=0.03125,[10]=0.015625,[11]=0.0078125,[12]=0.00390625}'")
cmd:text("th examples/noise-contrastive-estimate.lua --cuda --trainsize 400000 --validsize 40000 --cutoff 10 --batchsize 128 --seqlen 100 --hiddensize '{250,250}' --progress --device 2")
cmd:text("th scripts/evaluate-rnnlm.lua --xplogpath /data/save/rnnlm/ptb:atlas:1458081269:1.t7 --cuda")
cmd:text('Options:')
-- training
@@ -31,6 +30,8 @@ cmd:option('--progress', false, 'print progress bar')
cmd:option('--silent', false, 'don\'t print anything to stdout')
cmd:option('--uniform', 0.1, 'initialize parameters using uniform distribution between -uniform and uniform. -1 means default initialization')
cmd:option('--k', 25, 'how many noise samples to use for NCE')
cmd:option('--continue', '', 'path to model for which training should be continued. Note that current options (except for device, cuda and tiny) will be ignored.')
cmd:option('--Z', -1, 'normalization constant for NCE module (-1 approximates it from first batch).')
-- rnn layer
cmd:option('--seqlen', 5, 'sequence length : back-propagate through time (BPTT) for this many time-steps')
cmd:option('--hiddensize', '{200}', 'number of hidden units used at output of each recurrent layer. When more than one is specified, RNN/LSTMs/GRUs are stacked')
@@ -42,6 +43,7 @@ cmd:option('--validsize', -1, 'number of valid time-steps used for early stoppin
cmd:option('--savepath', paths.concat(dl.SAVE_PATH, 'rnnlm'), 'path to directory where experiment log (includes model) will be saved')
cmd:option('--id', '', 'id string of this experiment (used to name output file) (defaults to a unique id)')
cmd:option('--tiny', false, 'use train_tiny.th7 training file')
cmd:option('--dontsave', false, 'dont save the model')

cmd:text()
local opt = cmd:parse(arg or {})
@@ -51,12 +53,44 @@ if not opt.silent then
table.print(opt)
end
opt.id = opt.id == '' and ('gbw' .. ':' .. dl.uniqueid()) or opt.id
opt.version = 4

if opt.cuda then -- do this before building model to prevent segfault
require 'cunn'
cutorch.setDevice(opt.device)
end

local xplog, lm, criterion, targetmodule
if opt.continue ~= '' then
xplog = torch.load(opt.continue)
xplog.opt.cuda = opt.cuda
xplog.opt.device = opt.device
xplog.opt.tiny = opt.tiny
opt = xplog.opt
lm = xplog.model.module
-- prevent re-casting bug
for i,lookup in ipairs(lm:findModules('nn.LookupTableMaskZero')) do
lookup.__input = nil
end
-- backwards compatibility with old NCEModule
if not opt.version then
print"converting old NCEModule"
local nce
for i,ncem in ipairs(lm:findModules('nn.NCEModule')) do
ncem:fastNoise()
ncem.Z = torch.Tensor{-1}
ncem.noiseSample = nn.NCEModule.noiseSample
nce = ncem
end
nce:clearState()
lm.modules[#lm.modules] = nn.Sequencer(nn.MaskZero(nce, 1))
print"done"
end
criterion = xplog.criterion
targetmodule = xplog.targetmodule
assert(opt)
end

--[[ data set ]]--

local trainset, validset, testset = dl.loadGBW({opt.batchsize,opt.batchsize,opt.batchsize}, opt.tiny and 'train_tiny.th7' or nil)
@@ -67,47 +101,54 @@ end

--[[ language model ]]--

local lm = nn.Sequential()
if not lm then
lm = nn.Sequential()

-- input layer (i.e. word embedding space)
local lookup = nn.LookupTableMaskZero(#trainset.ivocab, opt.hiddensize[1])
lookup.maxnormout = -1 -- prevent weird maxnormout behaviour
lm:add(lookup) -- input is seqlen x batchsize
if opt.dropout > 0 then
lm:add(nn.Dropout(opt.dropout))
end

-- rnn layers
local inputsize = opt.hiddensize[1]
for i,hiddensize in ipairs(opt.hiddensize) do
-- this is a faster version of nn.Sequencer(nn.FastLSTM(inputsize, hiddensize))
local rnn = nn.SeqLSTM(inputsize, hiddensize)
rnn.maskzero = true
lm:add(rnn)
-- input layer (i.e. word embedding space)
local lookup = nn.LookupTableMaskZero(#trainset.ivocab, opt.hiddensize[1])
lookup.maxnormout = -1 -- prevent weird maxnormout behaviour
lm:add(lookup) -- input is seqlen x batchsize
if opt.dropout > 0 then
lm:add(nn.Dropout(opt.dropout))
end
inputsize = hiddensize
end

lm:add(nn.SplitTable(1))
-- rnn layers
local inputsize = opt.hiddensize[1]
for i,hiddensize in ipairs(opt.hiddensize) do
-- this is a faster version of nn.Sequencer(nn.FastLSTM(inputsize, hiddensize))
local rnn = nn.SeqLSTM(inputsize, hiddensize)
rnn.maskzero = true
lm:add(rnn)
if opt.dropout > 0 then
lm:add(nn.Dropout(opt.dropout))
end
inputsize = hiddensize
end

lm:add(nn.SplitTable(1))

-- output layer
local unigram = trainset.wordfreq:float()
local ncemodule = nn.NCEModule(inputsize, #trainset.ivocab, opt.k, unigram)
ncemodule:fastNoise()
-- output layer
local unigram = trainset.wordfreq:float()
local ncemodule = nn.NCEModule(inputsize, #trainset.ivocab, opt.k, unigram, opt.Z)

-- NCE requires {input, target} as inputs
lm = nn.Sequential()
:add(nn.ParallelTable()
:add(lm):add(nn.Identity()))
:add(nn.ZipTable()) -- {{x1,x2,...}, {t1,t2,...}} -> {{x1,t1},{x2,t2},...}
-- NCE requires {input, target} as inputs
lm = nn.Sequential()
:add(nn.ParallelTable()
:add(lm):add(nn.Identity()))
:add(nn.ZipTable()) -- {{x1,x2,...}, {t1,t2,...}} -> {{x1,t1},{x2,t2},...}

-- encapsulate stepmodule into a Sequencer
lm:add(nn.Sequencer(nn.TrimZero(ncemodule, 1)))
-- encapsulate stepmodule into a Sequencer
lm:add(nn.Sequencer(nn.MaskZero(ncemodule, 1)))

-- remember previous state between batches
lm:remember()
-- remember previous state between batches
lm:remember()

if opt.uniform > 0 then
for k,param in ipairs(lm:parameters()) do
param:uniform(-opt.uniform, opt.uniform)
end
end
end

if opt.profile then
lm:profile()
Expand All @@ -118,25 +159,21 @@ if not opt.silent then
print(lm)
end

if opt.uniform > 0 then
for k,param in ipairs(lm:parameters()) do
param:uniform(-opt.uniform, opt.uniform)
end
end

--[[ loss function ]]--
if not (criterion and targetmodule) then
--[[ loss function ]]--

local crit = nn.MaskZeroCriterion(nn.NCECriterion(), 0)
local crit = nn.MaskZeroCriterion(nn.NCECriterion(), 0)

-- target is also seqlen x batchsize.
local targetmodule = nn.SplitTable(1)
if opt.cuda then
targetmodule = nn.Sequential()
:add(nn.Convert())
:add(targetmodule)
-- target is also seqlen x batchsize.
targetmodule = nn.SplitTable(1)
if opt.cuda then
targetmodule = nn.Sequential()
:add(nn.Convert())
:add(targetmodule)
end

criterion = nn.SequencerCriterion(crit)
end

local criterion = nn.SequencerCriterion(crit)

--[[ CUDA ]]--

@@ -149,26 +186,28 @@ end
--[[ experiment log ]]--

-- is saved to file every time a new validation minima is found
local xplog = {}
xplog.opt = opt -- save all hyper-parameters and such
xplog.dataset = 'GoogleBillionWords'
xplog.vocab = trainset.vocab
-- will only serialize params
xplog.model = nn.Serial(lm)
xplog.model:mediumSerial()
xplog.criterion = criterion
xplog.targetmodule = targetmodule
-- keep a log of NLL for each epoch
xplog.trainnceloss = {}
xplog.valnceloss = {}
-- will be used for early-stopping
xplog.minvalnceloss = 99999999
xplog.epoch = 0
if not xplog then
xplog = {}
xplog.opt = opt -- save all hyper-parameters and such
xplog.dataset = 'GoogleBillionWords'
xplog.vocab = trainset.vocab
-- will only serialize params
xplog.model = nn.Serial(lm)
xplog.model:mediumSerial()
xplog.criterion = criterion
xplog.targetmodule = targetmodule
-- keep a log of NLL for each epoch
xplog.trainnceloss = {}
xplog.valnceloss = {}
-- will be used for early-stopping
xplog.minvalnceloss = 99999999
xplog.epoch = 0
paths.mkdir(opt.savepath)
end
local ntrial = 0
paths.mkdir(opt.savepath)

local epoch = 1
opt.lr = opt.startlr
local epoch = xplog.epoch+1
opt.lr = opt.lr or opt.startlr
opt.trainsize = opt.trainsize == -1 and trainset:size() or opt.trainsize
opt.validsize = opt.validsize == -1 and validset:size() or opt.validsize
while opt.maxepoch <= 0 or epoch <= opt.maxepoch do
@@ -181,20 +220,14 @@ while opt.maxepoch <= 0 or epoch <= opt.maxepoch do
lm:training()
local sumErr = 0
for i, inputs, targets in trainset:subiter(opt.seqlen, opt.trainsize) do
local _ = require 'moses'
assert(not _.isNaN(targets:sum()))
assert(not _.isNaN(inputs:sum()))
targets = targetmodule:forward(targets)
inputs = {inputs, targets}
-- forward
local outputs = lm:forward(inputs)
local err = criterion:forward(outputs, targets)
assert(not _.isNaN(err))
sumErr = sumErr + err
-- backward
local gradOutputs = criterion:backward(outputs, targets)
assert(not _.isNaN(gradOutputs[1][1]:sum()))
assert(not _.isNaN(gradOutputs[1][2]:sum()))
local a = torch.Timer()
lm:zeroGradParameters()
lm:backward(inputs, gradOutputs)
@@ -270,7 +303,9 @@ while opt.maxepoch <= 0 or epoch <= opt.maxepoch do
xplog.epoch = epoch
local filename = paths.concat(opt.savepath, opt.id..'.t7')
print("Found new minima. Saving to "..filename)
torch.save(filename, xplog)
if not opt.dontsave then
torch.save(filename, xplog)
end
ntrial = 0
elseif ntrial >= opt.earlystop then
print("No new minima found after "..ntrial.." epochs.")
@@ -282,4 +317,3 @@ while opt.maxepoch <= 0 or epoch <= opt.maxepoch do
collectgarbage()
epoch = epoch + 1
end
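
For reference, a hypothetical invocation of the updated script, showing how the new `--continue` option might be used to resume training from a saved experiment log (the path below is made up):

th examples/noise-contrastive-estimate.lua --continue /data/save/rnnlm/gbw:myhost:1463200000:1.t7 --cuda --device 2

When continuing, the saved `xplog.opt` replaces the current command-line options (only `--cuda`, `--device` and `--tiny` are carried over). The separate `--dontsave` flag skips writing the model to disk when a new validation minimum is found.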
