Skip to content

Commit

Permalink
Merge pull request #157 from mcyc/153-refit-marginal-pix
Browse files Browse the repository at this point in the history
Added a refit functional to refit pixels with marginally good models
  • Loading branch information
mcyc authored Oct 15, 2024
2 parents 0588065 + aa2017b commit a691f60
Showing 1 changed file with 106 additions and 22 deletions.
128 changes: 106 additions & 22 deletions mufasa/master_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import multiprocessing
from spectral_cube import SpectralCube
from astropy import units as u
from skimage.morphology import binary_dilation, square, disk
from skimage.morphology import binary_dilation, remove_small_holes, remove_small_objects, square, disk
import astropy.io.fits as fits
from copy import copy, deepcopy
import gc
Expand Down Expand Up @@ -186,7 +186,7 @@ def get_fits(reg, ncomp, **kwargs):
# functions specific to 2-component fits


def master_2comp_fit(reg, snr_min=0.0, recover_wide=True, planemask=None, updateCnvFits=True, refit_bad_pix=True,
def master_2comp_fit(reg, snr_min=0.0, recover_wide=True, planemask=None, updateCnvFits=True, refit_bad_pix=True, refit_marg=True,
multicore=True):
'''
note: planemask supercedes snr-based mask
Expand All @@ -204,11 +204,16 @@ def master_2comp_fit(reg, snr_min=0.0, recover_wide=True, planemask=None, update
if recover_wide:
refit_2comp_wide(reg, snr_min=recover_snr_min, multicore=multicore)

if refit_marg:
refit_marginal(reg, ncomp=2, lnk_thresh=5, holes_only=False, multicore=True,
method='best_neighbour')

save_best_2comp_fit(reg, multicore=multicore)

return reg



def iter_2comp_fit(reg, snr_min=3.0, updateCnvFits=True, planemask=None, multicore=True, use_cnv_lnk=False,
save_para=True):
proc_name = 'iter_2comp_fit'
Expand Down Expand Up @@ -250,6 +255,7 @@ def iter_2comp_fit(reg, snr_min=3.0, updateCnvFits=True, planemask=None, multico
reg.log_progress(process_name=proc_name, mark_start=False)



def refit_bad_2comp(reg, snr_min=3, lnk_thresh=-5, multicore=True, save_para=True, method='best_neighbour'):
'''
refit pixels where 2 component fits are substantially worse than good one components
Expand Down Expand Up @@ -278,29 +284,56 @@ def refit_bad_2comp(reg, snr_min=3, lnk_thresh=-5, multicore=True, save_para=Tru
logger.info("No pixel was used in attempt to recover bad 2-comp. fits")
return

guesses = copy(ucube.pcubes['2'].parcube)
guesses[guesses == 0] = np.nan
# remove the bad pixels from the fitted parameters
guesses[:, mask] = np.nan
guesses, mask = get_refit_guesses(ucube, mask, ncomp=2, method='best_neighbour', refmap=lnk20)
# re-fit and save the updated model
replace_bad_pix(ucube, mask, snr_min, guesses, None, simpfit=True, multicore=multicore)

if method == 'convolved':
# use astropy convolution to interpolate guesses (we assume bad fits are usually well surrounded by good fits)
kernel = Gaussian2DKernel(2.5 / 2.355)
for i, gmap in enumerate(guesses):
gmap[mask] = np.nan
guesses[i] = convolve(gmap, kernel, boundary='extend')
if save_para:
save_updated_paramaps(reg.ucube, ncomps=[2, 1])

elif method == 'best_neighbour':
# use the nearest neighbour with the highest lnk20 value for guesses
# neighbours.square_neighbour(1) gives the 8 closest neighbours
maxref_coords = neighbours.maxref_neighbor_coords(mask=mask, ref=lnk20, fill_coord=(0, 0),
structure=neighbours.square_neighbour(1))
ys, xs = zip(*maxref_coords)
guesses[:, mask] = guesses[:, ys, xs]
mask = np.logical_and(mask, np.all(np.isfinite(guesses), axis=0))
reg.log_progress(process_name=proc_name, mark_start=False)



def refit_marginal(reg, ncomp, lnk_thresh=5, holes_only=False, multicore=True, save_para=True, method='best_neighbour', **kwargs_marg):
# refit pixels that seem marginaly okay
ucube = reg.ucube

proc_name = f'refit_marginal_{ncomp}_comp'
reg.log_progress(process_name=proc_name, mark_start=True)
ucube = reg.ucube

multicore = validate_n_cores(multicore)
logger.info(f"Begin re-fitting marginal pixels using {multicore} cores")

lnk_maps = reg.ucube.get_all_lnk_maps(ncomp_max=ncomp, rest_model_mask=False, multicore=multicore)
if ncomp == 1:
lnkmap = lnk_maps #lnk10
refmap = lnkmap
elif ncomp == 2:
lnkmap = lnk_maps[2] #lnk21 for thresholding
refmap = lnk_maps[1] #lnk20 for best neighbour (in case lnk21 is high simply the one component fit is poor)

mask = get_marginal_pix(lnkmap, lnk_thresh=lnk_thresh, holes_only=holes_only, **kwargs_marg)
if mask.sum() < 1:
logger.info(f"No pixel was used in attempt to recover marginal {ncomp}-comp. fits")
return

guesses, mask = get_refit_guesses(ucube, mask=mask, ncomp=ncomp, method=method, refmap=lnkmap)

# ensure the mask doesn't extend beyond the original fit and has guesses
mask = np.logical_and(mask, np.isfinite(guesses).all(axis=0))
mask = np.logical_and(mask, reg.ucube.pcubes[str(ncomp)].has_fit)

mask_size = np.sum(mask)
if mask_size > 0:
logger.info(f"Attempting to refit over {mask_size} pixels to recover marginal {ncomp}-comp. fits")
else:
logger.info(f"No pixel was used in attempt to recover marginal {ncomp}-comp. fits")
return

gc.collect()
# re-fit and save the updated model
snr_min=0
replace_bad_pix(ucube, mask, snr_min, guesses, None, simpfit=True, multicore=multicore)

if save_para:
Expand All @@ -309,6 +342,7 @@ def refit_bad_2comp(reg, snr_min=3, lnk_thresh=-5, multicore=True, save_para=Tru
reg.log_progress(process_name=proc_name, mark_start=False)



def refit_swap_2comp(reg, snr_min=3):
ncomp = [1, 2]

Expand Down Expand Up @@ -396,7 +430,7 @@ def refit_2comp_wide(reg, snr_min=3, method='residual', planemask=None, multicor
try:
wide_comp_guess = get_2comp_wide_guesses(reg, window_hwidth=3.5, snr_min=snr_min, savefit=True, planemask=mask)
except SNRMaskError as e:
msg = "Unable to recovere second component from residual. " + e.__str__()
msg = e.__str__() + " No second component recovered from the residual cube."
logger.warning(msg)
return
except StartFitError as e:
Expand Down Expand Up @@ -485,6 +519,56 @@ def replace_rss(ucube, ucube_ref, ncomp, mask):
else:
ucube.get_AICc(ncomp=ncomp, update=True, planemask=mask)

def get_refit_guesses(ucube, mask, ncomp, method='best_neighbour', refmap=None):
#get refit guesses from the surrounding pixels
guesses = copy(ucube.pcubes[str(ncomp)].parcube)
guesses[guesses == 0] = np.nan
# remove the bad pixels from the fitted parameters
guesses[:, mask] = np.nan

if method == 'best_neighbour':
if refmap is None:
raise ValueError("refmap must be provided for the best_neighbour method.")
if not isinstance(refmap, np.ndarray):
raise TypeError(f"{type(refmap)} is the incorrect type for refmap.")
# use the nearest neighbour with the highest lnk20 value for guesses
# neighbours.square_neighbour(1) gives the 8 closest neighbours
maxref_coords = neighbours.maxref_neighbor_coords(mask=mask, ref=refmap, fill_coord=(0, 0),
structure=neighbours.square_neighbour(1))
ys, xs = zip(*maxref_coords)
guesses[:, mask] = guesses[:, ys, xs]
mask = np.logical_and(mask, np.all(np.isfinite(guesses), axis=0))

elif method == 'convolved':
# use astropy convolution to interpolate guesses (we assume bad fits are usually well surrounded by good fits)
kernel = Gaussian2DKernel(2.5 / 2.355)
for i, gmap in enumerate(guesses):
gmap[mask] = np.nan
guesses[i] = convolve(gmap, kernel, boundary='extend')

return guesses, mask


def get_marginal_pix(lnkmap, lnk_thresh=5, holes_only=False, smallest_struct_size=9):
'''
Retrun pixels next to the edge of the structures with >lnk_thresh, or pixels less than lnk_thresh enclosed within the structures
:param lnkmap: the relative log-likelihood map
:param lnk_thresh: the relative log-likelihood thereshold
:param holes_only: return only holes
:param smallest_struct_size: the minimum size of the connected pixels to be considred a good reference structure
:return:
'''

mask = remove_small_objects(lnkmap > lnk_thresh, smallest_struct_size)
mask_nosml = remove_small_holes(mask)

if holes_only:
# returns holes surrounded by pixels with lnk > lnk_thresh
return np.logical_xor(mask_nosml, mask)

else:
mask_nosml = binary_dilation(mask_nosml)
return np.logical_xor(mask_nosml, mask)

def standard_2comp_fit(reg, planemask=None, snr_min=3):
# two compnent fitting method using the moment map guesses method
Expand Down

0 comments on commit a691f60

Please sign in to comment.