
Using parfor

Brian Lau edited this page Jun 5, 2017 · 6 revisions

Currently this works on the parfor branch.

A simple example showing how to use parfor to run Stan on several datasets in parallel using a parallel pool:

model_code = {
'data {'
'    int<lower=0> N;'
'    int<lower=0,upper=1> y[N];'
'}'
'parameters {'
'    real<lower=0,upper=1> theta;'
'}'
'model {'
'    for (n in 1:N)'
'        y[n] ~ bernoulli(theta);'
'}'
};

% True values for theta
theta = 0.1:.2:1;

% Set this to the number of cores on your machine to avoid warnings
% (it also caps the number of chains requested for each fit below)
ncores = 4;

tic
ticBytes(gcp);
parfor i = 1:numel(theta)
   % Generate some fake data
   data = struct('N',10,'y',double(rand(1,10)<theta(i)));
   
   % Instantiate the model. Passing MATLAB's 'tempdir' as 'working_dir'
   % sends each worker's output files to that worker's own temporary
   % directory (these directories may differ between workers)
   sm = StanModel('model_code',model_code,...
      'model_name','bernoulli',...
      'working_dir',tempdir);
   
   % Call 'stan' with the model. The 'chains' and 'iter' values are only
   % examples, but 'block' must be true; otherwise the worker returns the
   % result before sampling has finished.
   % Note that the model is compiled once on each worker (compilation may
   % happen a couple of times if the parallel pool is local).
   fit(i) = stan('fit',sm,'data',data,...
      'chains',min(i,ncores),'iter',150000,...
      'verbose',true,'block',true);
   
   % Print a check that the number of chains was set properly
   fprintf('Model: %s, id: %s, #chains=%g, seed=%g\n',...
      fit(i).model.model_name,fit(i).model.id,...
      fit(i).model.chains,fit(i).model.seed);
end
tocBytes(gcp)
toc
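In the example above, each dataset is a temporary variable created on a worker, so it is not available on the client after the loop, and calling rng on the client does not control the random numbers drawn on the workers. One alternative, sketched below under the assumption that the datasets are small enough to send to the workers, is to build every dataset on the client before the parfor loop. The names datasets, N and k are illustrative and not part of MatlabStan.

```matlab
% Hypothetical variant of the data-generation step: build every dataset on
% the client before the parfor loop, so the data are reproducible (fixed
% seed) and still available afterwards for comparison with the estimates.
rng(0);                          % seed the client's random number stream
theta = 0.1:.2:1;                % true values, as in the example above
N = 10;                          % observations per dataset
datasets = cell(1, numel(theta));
for i = 1:numel(theta)
   % Bernoulli(theta(i)) draws coded as 0/1 doubles
   datasets{i} = struct('N', N, 'y', double(rand(1, N) < theta(i)));
end

% Number of successes in each dataset, kept on the client
k = cellfun(@(d) sum(d.y), datasets);
```

Inside the parfor loop, the 'data' argument to stan would then be datasets{i}. Because datasets is indexed by the loop variable, it is a sliced input variable, so each worker receives only the datasets for its own iterations.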

This should produce output similar to the following (model ids, seeds and timings will differ between runs):

Starting parallel pool (parpool) using the 'local' profile ... connected to 2 workers.
Stan is sampling with 2 chains...
Stan is sampling with 4 chains...
Model: bernoulli, id: Rf9IULUgVhbbzIyvHYoGU, #chains=2, seed=209978
Model: bernoulli, id: 4GxOFExHVU45BRZSU2AFr, #chains=4, seed=209977
Stan is sampling with 1 chains...
Model: bernoulli, id: DwHvz6PAc0L2IRcQKRvjQe, #chains=1, seed=210954
Stan is sampling with 3 chains...
Model: bernoulli, id: FMN6ovJRgkslCvIpRiwwJC, #chains=3, seed=211006
Stan is sampling with 4 chains...
Model: bernoulli, id: BdjgP4v521XVArOe2NFvnJ, #chains=4, seed=211672
             BytesSentToWorkers    BytesReceivedFromWorkers
             __________________    ________________________

    1        13796                 3.3822e+07              
    2        11672                 3.3755e+07              
    Total    25468                 6.7578e+07              

Elapsed time is 50.871062 seconds.
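Nearly all of the traffic reported by tocBytes flows from the workers back to the client (about 25 KB sent versus about 68 MB received), because each returned fit carries its saved draws. A back-of-envelope estimate is sketched below. It rests on assumptions about MatlabStan and CmdStan defaults that are not stated on this page: warmup equal to iter/2, warmup draws not returned, and 8 columns per draw stored as 8-byte doubles.

```matlab
% Rough estimate of BytesReceivedFromWorkers for the example run.
% ASSUMPTIONS (not taken from the MatlabStan docs): warmup defaults to
% iter/2 and warmup draws are not returned; each saved draw has 8 columns
% (lp__, 6 sampler diagnostics, theta) stored as 8-byte doubles.
iter   = 150000;
warmup = iter/2;                  % assumed default
ncores = 4;
chains = min(1:5, ncores);        % chains used for i = 1..5: [1 2 3 4 4]
ncols  = 8;                       % assumed columns per draw
bytesPerDouble = 8;

totalDraws = sum(chains) * (iter - warmup);   % 14 chains x 75000 draws
estBytes   = totalDraws * ncols * bytesPerDouble;
fprintf('Estimated bytes returned: %g\n', estBytes);
```

Under these assumptions the estimate is 6.72e+07 bytes, close to the 6.7578e+07 total in the table; the small remainder presumably covers the model objects and other per-fit metadata. Lowering 'iter' would shrink this transfer roughly in proportion.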

Comparing the posterior means with the true values of theta:

theta
    0.1000    0.3000    0.5000    0.7000    0.9000

arrayfun(@(x) mean(x.extract.theta),fit)
    0.1659    0.2502    0.3332    0.6663    0.8338
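The estimates differ from the true values mainly because each dataset has only N = 10 observations. The model declares no prior for theta, so Stan uses a uniform prior on [0,1]; with k successes the posterior is then Beta(k+1, N-k+1), whose mean is (k+1)/(N+2). The simulated y values were not saved, so the sketch below only checks which closed-form means the printed estimates sit closest to. The vector est is copied from the output above, and the implied counts are an inference, not observed data.

```matlab
% Posterior means printed above by arrayfun(@(x) mean(x.extract.theta),fit)
est = [0.1659 0.2502 0.3332 0.6663 0.8338];
N = 10;                          % observations per dataset in the example

% Closed-form posterior mean (k+1)/(N+2) for every possible success count
% k = 0..N, assuming the implicit uniform prior on theta over [0,1]
k = 0:N;
exact = (k + 1) / (N + 2);

% For each Stan estimate, find the nearest closed-form value.
% est(:) - exact relies on implicit expansion (R2016b or later) and
% yields a 5-by-11 matrix of differences.
[gap, idx] = min(abs(est(:) - exact), [], 2);
k_implied = k(idx);

for j = 1:numel(est)
   fprintf('estimate %.4f  closest to (%d+1)/%d = %.4f\n', ...
      est(j), k_implied(j), N + 2, exact(idx(j)));
end
```

Each estimate lies within 0.001 of one of the values (k+1)/12, with implied counts k = 1, 2, 3, 7, 9. So the sampler appears to recover the exact posterior means, and the gap between the estimates and theta comes from the small datasets rather than from sampling error.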