Using parfor
Brian Lau edited this page Jun 5, 2017 · 6 revisions
Currently working on branch parfor
The following simple example uses parfor to run Stan on a parallel pool, fitting a different dataset in each iteration:
model_code = {
'data {'
' int<lower=0> N;'
' int<lower=0,upper=1> y[N];'
'}'
'parameters {'
' real<lower=0,upper=1> theta;'
'}'
'model {'
'for (n in 1:N)'
' y[n] ~ bernoulli(theta);'
'}'
};
% True values for theta
theta = 0.1:.2:1;
% Set this to the number of cores of your machine to avoid warnings
ncores = 4;
tic
ticBytes(gcp);
parfor i = 1:numel(theta)
    % Generate some fake data
    data = struct('N',10,'y',double(rand(1,10)<theta(i)));

    % Instantiate the model. Using the Matlab function 'tempdir' sends the
    % output files of each worker to that worker's own temporary directory
    % (which may differ between workers)
    sm = StanModel('model_code',model_code,...
                   'model_name','bernoulli',...
                   'working_dir',tempdir);

    % Call 'stan' with the model. The 'chains' and 'iter' values are just
    % examples, but 'block' must be true, or else the worker will pass back
    % the result before sampling is completed.
    % Note that compilation will happen once on each worker (it may happen a
    % couple of times if the parallel pool is local)
    fit(i) = stan('fit',sm,'data',data,...
                  'chains',min(i,ncores),'iter',150000,...
                  'verbose',true,'block',true);

    % Print a check that chains were set properly
    fprintf('Model: %s, id: %s, #chains=%g, seed=%g\n',...
            fit(i).model.model_name,fit(i).model.id,...
            fit(i).model.chains,fit(i).model.seed);
end
tocBytes(gcp)
toc
This should produce something similar to the following:
Starting parallel pool (parpool) using the 'local' profile ... connected to 2 workers.
Stan is sampling with 2 chains...
Stan is sampling with 4 chains...
Model: bernoulli, id: Rf9IULUgVhbbzIyvHYoGU, #chains=2, seed=209978
Model: bernoulli, id: 4GxOFExHVU45BRZSU2AFr, #chains=4, seed=209977
Stan is sampling with 1 chains...
Model: bernoulli, id: DwHvz6PAc0L2IRcQKRvjQe, #chains=1, seed=210954
Stan is sampling with 3 chains...
Model: bernoulli, id: FMN6ovJRgkslCvIpRiwwJC, #chains=3, seed=211006
Stan is sampling with 4 chains...
Model: bernoulli, id: BdjgP4v521XVArOe2NFvnJ, #chains=4, seed=211672
             BytesSentToWorkers    BytesReceivedFromWorkers
             __________________    ________________________
    1        13796                 3.3822e+07
    2        11672                 3.3755e+07
    Total    25468                 6.7578e+07
Elapsed time is 50.871062 seconds.
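The first line of the output shows that the pool was started implicitly by gcp using the default profile, here with 2 workers. If you would rather control the pool size yourself, something along the following lines may work. This is only a sketch, assuming the Parallel Computing Toolbox is installed; the variable names c, nworkers and pool are illustrative and not part of the example above.

```matlab
% Sketch: start a pool explicitly instead of relying on gcp's default.
ncores = 4;                             % same meaning as in the example above
c = parcluster();                       % default cluster profile
nworkers = min(ncores, c.NumWorkers);   % do not ask for more than the profile allows
pool = gcp('nocreate');                 % returns [] if no pool is running
if isempty(pool)
    pool = parpool(c, nworkers);        % start a pool with the requested size
end
fprintf('Pool has %d workers\n', pool.NumWorkers);
```

Running this before the tic/ticBytes calls keeps the pool startup time out of the reported elapsed time.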
Comparing the true values of theta with the posterior means from each fit:
theta
0.1000 0.3000 0.5000 0.7000 0.9000
arrayfun(@(x) mean(x.extract.theta),fit)
0.1659 0.2502 0.3332 0.6663 0.8338
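The estimates are not expected to match the true values closely, since each dataset has only N = 10 observations. They do fall very near multiples of 1/12, which is consistent with a uniform prior on theta (the model block declares no prior, and theta is bounded to [0,1]): the posterior is then Beta(k+1, N-k+1) with mean (k+1)/(N+2), where k is the number of ones in y. A minimal sketch of that check follows. The success counts k were not recorded in the run above, so the values used here are assumed purely for illustration.

```matlab
% Sketch: posterior means of theta under a uniform prior.
N = 10;                           % observations per dataset, as in the example
k = [1 2 3 7 9];                  % ASSUMED success counts (hypothetical, not from the run)
post_mean = (k + 1) ./ (N + 2);   % mean of Beta(k+1, N-k+1)
fprintf('%.4f ', post_mean);      % prints 0.1667 0.2500 0.3333 0.6667 0.8333
fprintf('\n');
```

With these assumed counts the analytic means agree with the sampled means above to about three decimal places.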