Skip to content

Commit

Permalink
Implementation of Alias Multinomial for faster Multinomial sampling (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
amartya18x authored and soumith committed Jul 11, 2017
1 parent 01bcef9 commit e17f93a
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 10 deletions.
32 changes: 23 additions & 9 deletions TensorMath.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1269,16 +1269,30 @@ static void THTensor_random1__(THTensor *self, THGenerator *gen, long b)
-- Binding spec for "multinomial", tied to the C symbol produced by
-- cname("multinomial"). The argument list mirrors the C prototype
-- THTensor_(multinomial)(self, _generator, prob_dist, n_sample, with_replacement).
wrap("multinomial",
cname("multinomial"),
-- result indices: optional, and handed back to the caller
{{name="IndexTensor", default=true, returned=true, method={default='nil'}},
-- random generator: optional
{name='Generator', default=true},
-- probability distribution to sample from
{name=Tensor},
-- number of samples to draw
{name="int"},
-- sample with replacement; defaults to false
{name="boolean", default=false}})

{name='Generator', default=true},
{name=Tensor},
{name="int"},
{name="boolean", default=false}})

-- Binding spec for "multinomialAliasSetup_", tied to the C symbol produced by
-- cname("multinomialAliasSetup"). The argument list mirrors the C prototype
-- THTensor_(multinomialAliasSetup)(probs, J, q).
-- The trailing underscore marks this as the low-level entry point; init.lua
-- wraps it in torch.multinomialAliasSetup, which packs both results in a table.
wrap("multinomialAliasSetup_",
cname("multinomialAliasSetup"),
-- probabilities to build the tables from
{{name=Tensor},
-- alias table (J): optional, and handed back to the caller
{name="IndexTensor", default=true, returned=true, method={default='nil'}},
-- probability table (q): optional, and handed back to the caller
{name=Tensor, default=true, returned=true, method={default='nil'}}})

-- Binding spec for "multinomialAlias_", tied to the C symbol produced by
-- cname("multinomialAliasDraw"). The argument list mirrors the C prototype
-- THTensor_(multinomialAliasDraw)(self, _generator, J, q).
-- init.lua wraps this in torch.multinomialAlias(output, state).
wrap("multinomialAlias_",
cname("multinomialAliasDraw"),
-- sampled indices: optional, and handed back to the caller
{{name="IndexTensor", default=true, returned=true, method={default='nil'}},
-- random generator: optional
{name='Generator', default=true},
-- alias table (J) produced by the setup call
{name="IndexTensor"},
-- probability table (q) produced by the setup call
{name=Tensor}
})

for _,f in ipairs({{name='uniform', a=0, b=1},
{name='normal', a=0, b=1},
{name='cauchy', a=0, b=1},
{name='logNormal', a=1, b=2}}) do

{name='normal', a=0, b=1},
{name='cauchy', a=0, b=1},
{name='logNormal', a=1, b=2}}) do
wrap(f.name,
string.format("THRandom_%s", f.name),
{{name='Generator', default=true},
Expand Down
56 changes: 56 additions & 0 deletions doc/maths.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,62 @@ p.multinomial(res, p, n, replacement) -- p.multinomial instead of torch.multinom

This is due to the fact that the result here is of a `LongTensor` type, and we do not define a `torch.multinomial` over long `Tensor`s.

<a name="torch.multinomialAlias()"></a>
### [state] torch.multinomialAliasSetup(probs) ###
### [res] torch.multinomialAlias(output, state) ###
`state = torch.multinomialAliasSetup(probs)` returns a table `state` holding two tensors: an `alias table` (`state[1]`, a `LongTensor`) and a `probability table` (`state[2]`). The setup is required only once for each `probs` vector. We can then sample from the multinomial distribution as many times as needed by passing this `state` table to `torch.multinomialAlias`.

`torch.multinomialAlias(output, state)` returns `output` filled with indices drawn from the multinomial distribution `probs`. `output` itself is filled with the indices and it is not necessary to get the return value of the statement.

The sampling uses the alias method, which is explained in an accessible way in this blog post about [The Alias Method](https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/). The paper that describes this technique is available [here](http://www.tandfonline.com/doi/abs/10.1080/00031305.1979.10482697). This method can only sample with replacement.

The `output` `Tensor` that is fed into the `multinomialAlias` method need not be contiguous, but it must be a 1D tensor. If you need to fill an N-dimensional tensor, pass a 1D view of that tensor instead (for example `output:view(-1)`). This method is considerably faster than `torch.multinomial` when you draw a large number of samples from the same distribution, or sample from the same distribution many times. `torch.multinomial` is faster when you draw only a few samples from a distribution once, because the one-off cost of `multinomialAliasSetup` dominates in that case. To compare the speed of the two methods, run `th test/test_aliasMultinomial.lua`.

```lua
th> state = torch.multinomialAliasSetup(probs)
th> state
{
1 : LongTensor - size: 4
2 : DoubleTensor - size: 4
}
th> output = torch.LongTensor(2,3)
th> torch.multinomialAlias(output:view(-1), state)
4
1
2
3
2
2
[torch.LongTensor of size 6]
th> output
4 1 2
3 2 2
[torch.LongTensor of size 2x3]
```

You can also allocate memory and reuse it for the state table.

```
th> state = {torch.LongTensor(), torch.DoubleTensor()}
th> probs = torch.DoubleTensor({0.2, 0.3, 0.5})
th> state = torch.multinomialAliasSetup(probs, state)
th> state
{
1 : LongTensor - size: 3
2 : DoubleTensor - size: 3
}
th> output = torch.LongTensor(7)
th> torch.multinomialAlias(output, state)
2
2
3
1
2
2
2
[torch.LongTensor of size 7]
```

<a name="torch.ones"></a>
### [res] torch.ones([res,] m [,n...]) ###
<a name="torch.ones"></a>
Expand Down
16 changes: 15 additions & 1 deletion init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ require('torch.FFInterface')
require('torch.Tester')
require('torch.TestSuite')
require('torch.test')

function torch.totable(obj)
if torch.isTensor(obj) or torch.isStorage(obj) then
return obj:totable()
Expand Down Expand Up @@ -189,4 +188,19 @@ torch.Tensor.isTensor = torch.isTensor
-- remove this line to disable automatic heap-tracking for garbage collection
torch.setheaptracking(true)

--- Build the alias-method sampling state for a probability vector.
-- @param probs tensor of probabilities to sample from
-- @param state optional table; when given, state[1] and state[2] are passed
--   to the low-level setup call and the table is updated in place
-- @return the state table: [1] = alias table, [2] = probability table
function torch.multinomialAliasSetup(probs, state)
   if torch.type(state) ~= 'table' then
      -- No reusable state supplied: let the low-level call allocate both tables.
      local alias, prob = torch.multinomialAliasSetup_(probs)
      state = {}
      state[1], state[2] = alias, prob
      return state
   end

   -- Reuse the caller's tensors as the destination of the setup.
   local alias, prob = torch.multinomialAliasSetup_(probs, state[1], state[2])
   state[1], state[2] = alias, prob
   return state
end

--- Fill `output` with indices drawn from a prepared alias-method state.
-- @param output 1D LongTensor that receives the sampled indices
-- @param state table returned by torch.multinomialAliasSetup:
--   [1] = alias table, [2] = probability table
-- @return `output`, filled in place
function torch.multinomialAlias(output, state)
   local alias, prob = state[1], state[2]
   -- Dispatch on the probability table's own tensor type instead of
   -- hard-coding torch.DoubleTensor, so that a state built from a tensor of
   -- another wrapped type (e.g. FloatTensor) can also be sampled from.
   -- For a DoubleTensor table this resolves to the same function as before.
   prob.multinomialAlias_(output, alias, prob)
   return output
end

return torch
110 changes: 110 additions & 0 deletions lib/TH/generic/THTensorRandom.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,116 @@ void THTensor_(logNormal)(THTensor *self, THGenerator *_generator, double mean,
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_logNormal(_generator, mean, stdv););
}


/*
 * Build the two alias-method tables for the discrete distribution `probs`.
 *
 *   probs : input probabilities, read as a 1D tensor of nElement entries.
 *   J     : output alias table, resized to nElement; J[i] is the 0-based
 *           outcome to pick when bucket i rejects its own outcome.
 *   q     : output probability table, resized to nElement; q[i] is the
 *           probability of keeping outcome i once bucket i has been chosen.
 *
 * Raises an argument error when `probs` is empty or when the smallest scaled
 * probability is not strictly positive.
 */
void THTensor_(multinomialAliasSetup)(THTensor *probs, THLongTensor *J, THTensor *q)
{
  long inputsize = THTensor_(nElement)(probs);
  long i = 0;
  long idx;
  long small_c = 0;
  long large_c = 0;
  long large, small;
  real q_large, q_min, q_max, q_temp;
  THLongTensor *smaller;
  THLongTensor *larger;

  /* An empty distribution would make the q[inputsize-1] read below invalid. */
  THArgCheck(inputsize > 0, 1, "probs must contain at least one element");

  /* Work stacks holding the outcomes whose scaled probability is below 1
     (smaller) or at least 1 (larger). */
  smaller = THLongTensor_newWithSize1d(inputsize);
  larger = THLongTensor_newWithSize1d(inputsize);
  THLongTensor_resize1d(J, inputsize);
  THTensor_(resize1d)(q, inputsize);

  for(i = 0; i < inputsize; i++)
    {
      THTensor_fastSet1d(J, i, 0L);
      real val = THTensor_fastGet1d(probs, i);
      THTensor_fastSet1d(q, i, inputsize*val);

      if (inputsize * val < 1.0)
        {
          THTensor_fastSet1d(smaller, small_c, i);
          small_c += 1;
        }
      else
        {
          THTensor_fastSet1d(larger, large_c, i);
          large_c += 1;
        }
    }

  // Loop through and create little binary mixtures that
  // appropriately allocate the larger outcomes over the
  // overall uniform mixture.
  while(small_c > 0 && large_c > 0)
    {
      large = THTensor_fastGet1d(larger, large_c-1);
      small = THTensor_fastGet1d(smaller, small_c-1);

      THTensor_fastSet1d(J, small, large);
      /* All accesses go through the stride-aware accessors so that a
         non-contiguous q is read and written consistently. */
      q_large = THTensor_fastGet1d(q, large) - (1.0 - THTensor_fastGet1d(q, small));
      THTensor_fastSet1d(q, large, q_large);

      if(q_large < 1.0)
        {
          /* `large` dropped below 1: it replaces `small` on the smaller stack. */
          THTensor_fastSet1d(smaller, small_c-1, large);
          large_c -= 1;
        }
      else
        {
          /* `large` still has mass to give away; only `small` is finished. */
          small_c -= 1;
        }
    }

  q_min = THTensor_fastGet1d(q, inputsize-1);
  q_max = q_min;
  for(i=0; i < inputsize; i++)
    {
      q_temp = THTensor_fastGet1d(q, i);
      if(q_temp < q_min)
        q_min = q_temp;
      else if(q_temp > q_max)
        q_max = q_temp;
    }
  /* Note: this also rejects a table whose minimum is exactly 0. */
  THArgCheckWithCleanup((q_min > 0),
                        THCleanup(THLongTensor_free(smaller); THLongTensor_free(larger);), 2,
                        "q_min is less than 0");

  if(q_max > 1)
    {
      for(i=0; i < inputsize; i++)
        {
          q_temp = THTensor_fastGet1d(q, i);
          THTensor_fastSet1d(q, i, q_temp / q_max);
        }
    }

  /* Outcomes still sitting on either stack were never given an alias.
     Make their keep-probability 1 so that J is never consulted for them.
     This is done from the stacks rather than by testing J[i] <= 0, because
     0 is a legitimate (0-based) alias: the J-based test also overwrote the
     probability of every outcome aliased to index 0. */
  for(i = 0; i < large_c; i++)
    {
      idx = THTensor_fastGet1d(larger, i);
      THTensor_fastSet1d(q, idx, 1.0);
    }
  for(i = 0; i < small_c; i++)
    {
      idx = THTensor_fastGet1d(smaller, i);
      THTensor_fastSet1d(q, idx, 1.0);
    }

  THLongTensor_free(smaller);
  THLongTensor_free(larger);
}
/*
 * Fill `self` with samples drawn using alias-method tables.
 *
 *   self       : 1D output tensor; every element receives one sampled index.
 *   _generator : random number generator used for both draws per sample.
 *   J          : alias table; entries are consumed as 1-based indices
 *                (the value written to `self` is J[k] - 1 when the alias is taken).
 *   q          : probability table; q[k] is the probability of keeping bucket k.
 *
 * Raises an argument error when samples are requested but the tables are
 * empty or differ in size.
 */
void THTensor_(multinomialAliasDraw)(THLongTensor *self, THGenerator *_generator, THLongTensor *J, THTensor *q)
{
  long K = THLongTensor_nElement(J);
  long output_nelem = THLongTensor_nElement(self);
  /* long, not int: the counter is compared against a long element count. */
  long i;
  int _mask;
  real _q;
  long rand_ind, sample_idx, J_sample;

  /* Nothing to fill: keep the original no-op behaviour. */
  if(output_nelem == 0)
    return;

  /* Guard the table reads below against out-of-bounds access. */
  THArgCheck(K > 0, 3, "alias table must not be empty");
  THArgCheck(THTensor_(nElement)(q) == K, 4,
             "probability table and alias table must have the same number of elements");

  for(i=0; i< output_nelem; i++)
    {
      /* Pick a bucket uniformly, then keep it with probability q[bucket]. */
      rand_ind = (long)THRandom_uniform(_generator, 0, K) ;
      _q = THTensor_fastGet1d(q, rand_ind);

      _mask = THRandom_bernoulli(_generator, _q);

      J_sample = THTensor_fastGet1d(J, rand_ind);

      /* _mask == 1 keeps the bucket itself, _mask == 0 takes its alias. */
      sample_idx = J_sample*(1 -_mask) + (rand_ind+1L) * _mask;

      THTensor_fastSet1d(self, i, sample_idx-1L);
    }
}
void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTensor *prob_dist, int n_sample, int with_replacement)
{
int start_dim = THTensor_(nDimension)(prob_dist);
Expand Down
2 changes: 2 additions & 0 deletions lib/TH/generic/THTensorRandom.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ TH_API void THTensor_(exponential)(THTensor *self, THGenerator *_generator, doub
TH_API void THTensor_(cauchy)(THTensor *self, THGenerator *_generator, double median, double sigma);
TH_API void THTensor_(logNormal)(THTensor *self, THGenerator *_generator, double mean, double stdv);
TH_API void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTensor *prob_dist, int n_sample, int with_replacement);
/* Alias-method sampling: builds the alias table J and the probability table q
   from prob_dist; both outputs are resized to the number of elements of prob_dist. */
TH_API void THTensor_(multinomialAliasSetup)(THTensor *prob_dist, THLongTensor *J, THTensor *q);
/* Alias-method sampling: fills every element of self with an index drawn
   using the tables J and q produced by multinomialAliasSetup. */
TH_API void THTensor_(multinomialAliasDraw)(THLongTensor *self, THGenerator *_generator, THLongTensor *J, THTensor *q);
#endif

#if defined(TH_REAL_IS_BYTE)
Expand Down
23 changes: 23 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1811,6 +1811,29 @@ function torchtest.multinomialwithoutreplacement()
end
end
end
-- Checks that the alias-method setup and draw only produce indices in
-- [1, n_class], both with a freshly allocated state and with a reused one.
function torchtest.aliasMultinomial()
   for i = 1, 5 do
      local n_class = 5
      local t = os.time()
      torch.manualSeed(t)
      local probs = torch.Tensor(n_class):uniform(0,1)
      probs:div(probs:sum())
      -- Same sample count as a 1000 x 10000 tensor, computed directly instead
      -- of allocating a throwaway LongTensor just to read its element count.
      local n_samples = 1000 * 10000

      -- First setup: the state table is allocated by the call.
      local prob_state = torch.multinomialAliasSetup(probs)
      mytester:assert(prob_state[1]:min() > 0, "Index ="..prob_state[1]:min().."alias indices has an index below or equal to 0")
      mytester:assert(prob_state[1]:max() <= n_class, prob_state[1]:max().." alias indices has an index exceeding num_class")

      -- Second setup: the existing state table is reused in place.
      prob_state = torch.multinomialAliasSetup(probs, prob_state)
      mytester:assert(prob_state[1]:min() > 0, "Index ="..prob_state[1]:min().."alias indices has an index below or equal to 0(cold)")
      mytester:assert(prob_state[1]:max() <= n_class, prob_state[1]:max()..","..prob_state[1]:min().." alias indices has an index exceeding num_class(cold)")

      local output = torch.LongTensor(n_samples)
      output = torch.multinomialAlias(output, prob_state)
      mytester:assert(output:nElement() == n_samples, "wrong number of samples")
      mytester:assert(output:min() > 0, "sampled indices has an index below or equal to 0")
      mytester:assert(output:max() <= n_class, "indices has an index exceeding num_class")
   end

end
function torchtest.multinomialvector()
local n_col = 4
local t=os.time()
Expand Down
40 changes: 40 additions & 0 deletions test/test_aliasMultinomial.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
-- Benchmark and sanity test for alias-method multinomial sampling.
-- Times the setup and the draw, compares against torch.multinomial, and
-- checks that the empirical frequencies match the input probabilities.
local tester = torch.Tester()


local function aliasMultinomial()
   local n_class = 10000
   local probs = torch.Tensor(n_class):uniform(0,1)
   probs:div(probs:sum())
   local a = torch.Timer()
   local state = torch.multinomialAliasSetup(probs)
   print("AliasMultinomial setup in "..a:time().real.." seconds(hot)")
   a:reset()
   state = torch.multinomialAliasSetup(probs, state)
   print("AliasMultinomial setup in "..a:time().real.." seconds(cold)")

   -- Alias indices are 1-based on the Lua side, so 0 is already out of range
   -- (this matches the assertion message and the check in test/test.lua).
   tester:assert(state[1]:min() > 0, "Index ="..state[1]:min().."alias indices has an index below or equal to 0")
   tester:assert(state[1]:max() <= n_class, state[1]:max().." alias indices has an index exceeding num_class")

   local output = torch.LongTensor(1000000)
   local n_samples = output:nElement()
   -- Restart the timer right before the draw so only the draw is measured.
   a:reset()
   torch.multinomialAlias(output, state)
   print("AliasMultinomial draw "..n_samples.." elements from "..n_class.." classes ".."in "..a:time().real.." seconds")

   -- Restart the timer so the torch.multinomial figure does not include the
   -- alias draw measured above. The result is only needed for timing, and is
   -- kept local so that it does not leak a global.
   a:reset()
   local mult_output = torch.multinomial(probs, n_samples, true)
   print("Multinomial draw "..n_samples.." elements from "..n_class.." classes ".." in "..a:time().real.." seconds")

   tester:assert(output:min() > 0, "sampled indices has an index below or equal to 0")
   tester:assert(output:max() <= n_class, "indices has an index exceeding num_class")

   -- Empirical frequency of each class among the alias-method samples.
   local counts = torch.Tensor(n_class):zero()
   output:apply(function(x)
      counts[x] = counts[x] + 1
   end)
   counts:div(counts:sum())

   -- Drawing must not have modified the state tables.
   tester:assert(state[1]:min() > 0, "Index ="..state[1]:min().."alias indices has an index below or equal to 0")
   tester:assert(state[1]:max() <= n_class, state[1]:max().." alias indices has an index exceeding num_class")
   tester:eq(probs, counts, 0.001, "probs and counts should be approximately equal")
end

tester:add(aliasMultinomial)
tester:run()

0 comments on commit e17f93a

Please sign in to comment.