Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modern multiprocessing for MCMC NEGFC, small fix to tutorial 05A #658

Merged
merged 12 commits into from
Feb 5, 2025
Merged
8 changes: 3 additions & 5 deletions docs/source/tutorials/05A_fm_planets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
"outputs": [],
"source": [
"%matplotlib inline\n",
"from hciplot import plot_frames, plot_cubes\n",
"from hciplot import plot_frames\n",
"from matplotlib.pyplot import *\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
Expand Down Expand Up @@ -937,7 +937,7 @@
],
"source": [
"from vip_hci.preproc import frame_crop\n",
"cropped_frame1 = frame_crop(final_ann_opt, cenxy=xy_test, size=15)"
"cropped_frame1 = frame_crop(final_ann_opt, xy=xy_test, size=15)"
]
},
{
Expand All @@ -953,9 +953,7 @@
]
}
],
"source": [
"cropped_frame2 = frame_crop(fr_pca_emp, cenxy=xy_test, size=15)"
]
"source": "cropped_frame2 = frame_crop(fr_pca_emp, xy=xy_test, size=15)"
},
{
"cell_type": "markdown",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ astropy
photutils
scikit-learn
scikit-image
emcee==2.2.1
emcee
nestle
corner
pandas
Expand Down
2 changes: 1 addition & 1 deletion tests/pre_3_10/test_preproc_recentering.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
mpl.use("Agg")

try:
from IPython.core.display import display, HTML
from IPython.display import display, HTML

def html(s):
display(HTML(s))
Expand Down
275 changes: 141 additions & 134 deletions vip_hci/fm/negfc_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
import numpy as np
import os
import emcee
from multiprocessing import cpu_count
import multiprocessing
import inspect
import datetime
import corner
Expand Down Expand Up @@ -699,7 +699,7 @@ def mcmc_negfc_sampling(cube, angs, psfn, initial_state, algo=pca_annulus,
conv_test: str, optional {'gb','ac'}
Method to check for convergence:
- 'gb' for gelman-rubin test
(http://digitalassets.lib.berkeley.edu/sdtr/ucb/text/305.pdf)
(https://digitalassets.lib.berkeley.edu/sdtr/ucb/text/305.pdf)
- 'ac' for autocorrelation analysis
(https://emcee.readthedocs.io/en/stable/tutorials/autocorr/)
ac_c: float, optional
Expand All @@ -725,7 +725,7 @@ def mcmc_negfc_sampling(cube, angs, psfn, initial_state, algo=pca_annulus,
Maximum number of steps per walker between two Gelman-Rubin test.
nproc: int or None, optional
The number of processes to use for parallelization. If None, will be set
automatically to half the number of CPUs available.
automatically to the number of CPUs available.
output_dir: str, optional
The name of the output directory which contains the output files in the
case ``save`` is True.
Expand Down Expand Up @@ -814,8 +814,8 @@ def mcmc_negfc_sampling(cube, angs, psfn, initial_state, algo=pca_annulus,
else:
raise TypeError("Interpolation not recognized.")

if nproc is None:
nproc = cpu_count() // 2 # Hyper-threading doubles the # of cores
if nproc is None: # if the user has not provided nproc, determine the number of processes for Pool to use
nproc = multiprocessing.cpu_count()

# #########################################################################
# Initialization of the variables
Expand Down Expand Up @@ -914,141 +914,148 @@ def mcmc_negfc_sampling(cube, angs, psfn, initial_state, algo=pca_annulus,
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

sampler = emcee.EnsembleSampler(nwalkers, dim, lnprob, a=a,
args=([bounds, cube, angs, psfn,
fwhm, annulus_width, ncomp,
aperture_radius, initial_state,
cube_ref, svd_mode, scaling, algo,
delta_rot, fmerit, imlib,
interpolation, collapse,
algo_options, weights, transmission,
mu_sigma, sigma, force_rPA]),
threads=nproc)

if verbosity > 0:
print('emcee Ensemble sampler successful')
start = datetime.datetime.now()
avail_methods = multiprocessing.get_all_start_methods()
if "forkserver" in avail_methods:
multiprocessing.set_start_method("forkserver", force=True) # faster, better
else:
multiprocessing.set_start_method("spawn", force=True) # slower, but available on all platforms

with multiprocessing.Pool(processes=nproc) as pool:
sampler = emcee.EnsembleSampler(nwalkers, dim, lnprob,
pool=pool, moves=emcee.moves.StretchMove(a=a),
args=([bounds, cube, angs, psfn,
fwhm, annulus_width, ncomp,
aperture_radius, initial_state,
cube_ref, svd_mode, scaling, algo,
delta_rot, fmerit, imlib,
interpolation, collapse,
algo_options, weights, transmission,
mu_sigma, sigma, force_rPA]))

# #########################################################################
# Affine Invariant MCMC run
# #########################################################################
if verbosity > 1:
print('\nStart of the MCMC run ...')
print('Step | Duration/step (sec) | Remaining Estimated Time (sec)')
if verbosity > 0:
print('emcee Ensemble sampler successful')
start = datetime.datetime.now()

for k, res in enumerate(sampler.sample(pos, iterations=nIterations)):
elapsed = (datetime.datetime.now()-start).total_seconds()
# #########################################################################
# Affine Invariant MCMC run
# #########################################################################
if verbosity > 1:
if k == 0:
q = 0.5
else:
q = 1
print('{}\t\t{:.5f}\t\t\t{:.5f}'.format(k, elapsed * q,
elapsed * (limit-k-1) * q),
flush=True)

start = datetime.datetime.now()
print('\nStart of the MCMC run ...')
print('Step | Duration/step (sec) | Remaining Estimated Time (sec)')

# ---------------------------------------------------------------------
# Store the state manually in order to handle with dynamical sized chain
# ---------------------------------------------------------------------
# Check if the size of the chain is long enough.
s = chain.shape[1]
if k+1 > s: # if not, one doubles the chain length
empty = np.zeros([nwalkers, 2*s, dim])
chain = np.concatenate((chain, empty), axis=1)
# Store the state of the chain
chain[:, k] = res[0]

# ---------------------------------------------------------------------
# If k meets the criterion, one tests the non-convergence.
# ---------------------------------------------------------------------
criterion = int(np.amin([np.ceil(itermin*(1+fraction)**geom),
lastcheck+np.floor(maxgap)]))
if k == criterion:
for k, res in enumerate(sampler.sample(pos, iterations=nIterations)):
elapsed = (datetime.datetime.now()-start).total_seconds()
if verbosity > 1:
print('\n {} convergence test in progress...'.format(conv_test))

geom += 1
lastcheck = k
if display:
show_walk_plot(chain, labels=labels)

if save and verbosity == 3:
fname = '{d}/{f}_temp_k{k}'.format(d=output_dir,
f=output_file_tmp, k=k)
data = {'chain': sampler.chain,
'lnprob': sampler.lnprobability,
'AR': sampler.acceptance_fraction}
with open(fname, 'wb') as fileSave:
pickle.dump(data, fileSave)

# We only test the rhat if we have reached the min # of steps
if (k+1) >= itermin and konvergence == np.inf:
if conv_test == 'gb':
thr0 = int(np.floor(burnin*k))
thr1 = int(np.floor((1-burnin)*k*0.25))

# We calculate the rhat for each model parameter.
for j in range(dim):
part1 = chain[:, thr0:thr0 + thr1, j].reshape(-1)
part2 = chain[:, thr0 + 3 * thr1:thr0 + 4 * thr1, j
].reshape(-1)
series = np.vstack((part1, part2))
rhat[j] = gelman_rubin(series)
if verbosity > 0:
print(' r_hat = {}'.format(rhat))
cond = rhat <= rhat_threshold
print(' r_hat <= threshold = {} \n'.format(cond), flush=True)
# We test the rhat.
if (rhat <= rhat_threshold).all():
rhat_count += 1
if rhat_count < rhat_count_threshold:
if verbosity > 0:
msg = "Gelman-Rubin test OK {}/{}"
print(msg.format(rhat_count,
rhat_count_threshold))
elif rhat_count >= rhat_count_threshold:
if k == 0:
q = 0.5
else:
q = 1
print('{}\t\t{:.5f}\t\t\t{:.5f}'.format(k, elapsed * q,
elapsed * (limit-k-1) * q),
flush=True)

start = datetime.datetime.now()

# ---------------------------------------------------------------------
# Store the state manually in order to handle with dynamical sized chain
# ---------------------------------------------------------------------
# Check if the size of the chain is long enough.
s = chain.shape[1]
if k+1 > s: # if not, one doubles the chain length
empty = np.zeros([nwalkers, 2*s, dim])
chain = np.concatenate((chain, empty), axis=1)
# Store the state of the chain
chain[:, k] = res[0]

# ---------------------------------------------------------------------
# If k meets the criterion, one tests the non-convergence.
# ---------------------------------------------------------------------
criterion = int(np.amin([np.ceil(itermin*(1+fraction)**geom),
lastcheck+np.floor(maxgap)]))
if k == criterion:
if verbosity > 1:
print('\n {} convergence test in progress...'.format(conv_test))

geom += 1
lastcheck = k
if display:
show_walk_plot(chain, labels=labels)

if save and verbosity == 3:
fname = '{d}/{f}_temp_k{k}'.format(d=output_dir,
f=output_file_tmp, k=k)
data = {'chain': sampler.chain,
'lnprob': sampler.get_log_prob(),
'AR': sampler.acceptance_fraction}
with open(fname, 'wb') as fileSave:
pickle.dump(data, fileSave)

# We only test the rhat if we have reached the min # of steps
if (k+1) >= itermin and konvergence == np.inf:
if conv_test == 'gb':
thr0 = int(np.floor(burnin*k))
thr1 = int(np.floor((1-burnin)*k*0.25))

# We calculate the rhat for each model parameter.
for j in range(dim):
part1 = chain[:, thr0:thr0 + thr1, j].reshape(-1)
part2 = chain[:, thr0 + 3 * thr1:thr0 + 4 * thr1, j
].reshape(-1)
series = np.vstack((part1, part2))
rhat[j] = gelman_rubin(series)
if verbosity > 0:
print(' r_hat = {}'.format(rhat))
cond = rhat <= rhat_threshold
print(' r_hat <= threshold = {} \n'.format(cond), flush=True)
# We test the rhat.
if (rhat <= rhat_threshold).all():
rhat_count += 1
if rhat_count < rhat_count_threshold:
if verbosity > 0:
msg = "Gelman-Rubin test OK {}/{}"
print(msg.format(rhat_count,
rhat_count_threshold))
elif rhat_count >= rhat_count_threshold:
if verbosity > 0:
print('... ==> convergence reached')
konvergence = k
stop = konvergence + supp
else:
rhat_count = 0
elif conv_test == 'ac':
# We calculate the auto-corr test for each model parameter.
if save:
chain_name = "TMP_test_chain{:.0f}.fits".format(k)
write_fits(output_dir+'/'+chain_name, chain[:, :k])
for j in range(dim):
rhat[j] = autocorr_test(chain[:, :k, j])
thr = 1./ac_c
if verbosity > 0:
print('Auto-corr tau/N = {}'.format(rhat))
print('tau/N <= {} = {} \n'.format(thr, rhat < thr), flush=True)
if (rhat <= thr).all():
ac_count += 1
if verbosity > 0:
print('... ==> convergence reached')
konvergence = k
stop = konvergence + supp
msg = "Auto-correlation test passed for all params!"
msg += "{}/{}".format(ac_count, ac_count_thr)
print(msg)
if ac_count >= ac_count_thr:
msg = '\n ... ==> convergence reached'
print(msg)
stop = k
else:
ac_count = 0
else:
rhat_count = 0
elif conv_test == 'ac':
# We calculate the auto-corr test for each model parameter.
raise ValueError('conv_test value not recognized')
# append the autocorrelation factor to file for easy reading
if save:
chain_name = "TMP_test_chain{:.0f}.fits".format(k)
write_fits(output_dir+'/'+chain_name, chain[:, :k])
for j in range(dim):
rhat[j] = autocorr_test(chain[:, :k, j])
thr = 1./ac_c
if verbosity > 0:
print('Auto-corr tau/N = {}'.format(rhat))
print('tau/N <= {} = {} \n'.format(thr, rhat < thr), flush=True)
if (rhat <= thr).all():
ac_count += 1
if verbosity > 0:
msg = "Auto-correlation test passed for all params!"
msg += "{}/{}".format(ac_count, ac_count_thr)
print(msg)
if ac_count >= ac_count_thr:
msg = '\n ... ==> convergence reached'
print(msg)
stop = k
else:
ac_count = 0
else:
raise ValueError('conv_test value not recognized')
# append the autocorrelation factor to file for easy reading
if save:
with open(output_dir + '/MCMC_results_tau.txt', 'a') as f:
f.write(str(rhat) + '\n')
# We have reached the maximum number of steps for our Markov chain.
if k+1 >= stop:
if verbosity > 0:
print('We break the loop because we have reached convergence')
break
with open(output_dir + '/MCMC_results_tau.txt', 'a') as f:
f.write(str(rhat) + '\n')
# We have reached the maximum number of steps for our Markov chain.
if k+1 >= stop:
if verbosity > 0:
print('We break the loop because we have reached convergence')
break

if k == nIterations-1:
if verbosity > 0:
Expand All @@ -1062,7 +1069,7 @@ def mcmc_negfc_sampling(cube, angs, psfn, initial_state, algo=pca_annulus,
output = {'chain': chain_zero_truncated(chain),
'input_parameters': input_parameters,
'AR': sampler.acceptance_fraction,
'lnprobability': sampler.lnprobability}
'lnprobability': sampler.get_log_prob()}

if output_file is None:
output_file = 'MCMC_results'
Expand All @@ -1076,7 +1083,7 @@ def mcmc_negfc_sampling(cube, angs, psfn, initial_state, algo=pca_annulus,
timing(start_time)

# reactivate multithreading
ncpus = cpu_count()
ncpus = multiprocessing.cpu_count()
os.environ["MKL_NUM_THREADS"] = str(ncpus)
os.environ["NUMEXPR_NUM_THREADS"] = str(ncpus)
os.environ["OMP_NUM_THREADS"] = str(ncpus)
Expand Down