diff --git a/.gitignore b/.gitignore index 2dc53ca..874361d 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ +data +results \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in index 6bb36c1..1b85645 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,6 @@ recursive-exclude test recursive-exclude assets -recursive-exclude .github \ No newline at end of file +recursive-exclude .github +recursive-exclude utils +recursive-exclude experiments +recursive-exclude examples \ No newline at end of file diff --git a/ecut/annealing.py b/ecut/annealing.py index f7cff15..4cc7a1f 100644 --- a/ecut/annealing.py +++ b/ecut/annealing.py @@ -18,8 +18,8 @@ class MorphAnneal: - def __init__(self, swc: ListNeuron, min_gap=10., radius_gap=0, min_step=5., min_step_ratio=.5, step_size=.5, - epsilon=1e-7, res=(1,1,1), drop_len=40.): + def __init__(self, swc: ListNeuron, min_gap=10., radius_gap=.5, min_step=5., min_step_ratio=.5, step_size=.5, + epsilon=1e-7, res=(1,1,1), drop_len=20.): self.morph = Morphology(swc) self._min_gap = min_gap self._eps = epsilon diff --git a/ecut/base_types.py b/ecut/base_types.py index d6bea05..1f37b91 100644 --- a/ecut/base_types.py +++ b/ecut/base_types.py @@ -20,9 +20,7 @@ def __init__(self): self.end1_adj = set() # multiple other nodes, connected or close self.end2_adj = set() self.traversed = set() # for the simplification of the problem - - self.source = None # the neuron it belongs to - self.likelihood = 0 # the likelihood of this fragment belonging to this neuron + self.source = {} # the neuron it belongs to, source id -> likelihood class BaseNode: @@ -49,13 +47,24 @@ def update(self, **kwargs): class BaseCut: - def __init__(self, swc: ListNeuron, soma: list[int], verbose=False): + def __init__(self, swc: ListNeuron, soma: 
list[int], res, likelihood_thr=None, verbose=False): + """ + + :param swc: swc tree + :param soma: list of soma id + :param res: resolution in x, y, z + :param likelihood_thr: the minimum likelihood allowed for a fragment to be attached to a neuron, left as None + to attach it to just the biggest. When multiple sources share a common or big enough likelihood, all of them will be considered. + :param verbose: + """ self._verbose = verbose self._swc = dict([(t[0], t) for t in swc]) + self.res = np.array(res) self._soma = soma self._fragment: dict[int, BaseFragment] = {} self._fragment_trees: dict[int, dict[int, BaseNode]] = {} self._problem: pulp.LpProblem | None = None + self._likelihood_thr: float = likelihood_thr @property def swc(self): @@ -77,36 +86,59 @@ def export_swc(self, partition=True): :return: an swc or a dict of swc """ if not partition: - tree = [list(t) for t in self._swc.values()] + tree = dict([(t[0], list(t)) for t in self._swc.values()]) + tag = dict(zip(self._soma, range(len(self._soma)))) for frag in self._fragment.values(): for i in frag.nodes: - tree[i][1] = frag.source - tree = [tuple(t) for t in tree] + a = list(frag.source.values()) + b = list(frag.source.keys()) + a = np.argmax(a) + tree[i][1] = tag[b[a]] + tree = [tuple(t) for t in tree.values()] return tree trees = dict([(i, {(-1, 1): (1, *self._swc[i][1:6], -1)}) for i in self._soma]) for frag_id, frag in self._fragment.items(): - frag_node = self._fragment_trees[frag.source][frag_id] - nodes = self._fragment[frag_id].nodes - if not frag_node.reverse: - nodes = nodes[::-1] - par_frag_id = frag_node.parent - if par_frag_id == -1: - last_id = -1, 1 - else: - par_frag_node = self._fragment_trees[frag.source][par_frag_id] - par_nodes = self._fragment[par_frag_id].nodes - if par_frag_node.reverse: - last_id = par_frag_id, par_nodes[-1] + candid = [] + a = list(frag.source.values()) + b = list(frag.source.keys()) + if self._likelihood_thr is None: # max only mode + m = None + for i in 
np.argsort(a)[::-1]: + if m is not None and m > a[i]: + break + candid.append(b[i]) + m = a[i] + else: # thresholding mode, bigger than this will all be considered + for i in np.argsort(a)[::-1]: + if a[i] < self._likelihood_thr: + break + candid.append(b[i]) + + # for each candid source, append the frag nodes + for src in candid: + frag_node = self._fragment_trees[src][frag_id] + nodes = self._fragment[frag_id].nodes + if not frag_node.reverse: + nodes = nodes[::-1] + par_frag_id = frag_node.parent + if par_frag_id == -1: + last_id = -1, 1 else: - last_id = par_frag_id, par_nodes[0] - tree = trees[frag.source] - for i in nodes: - n = list(self._swc[i]) - n[6] = last_id - n[0] = len(tree) + 1 - tree[(frag_id, i)] = tuple(n) - last_id = frag_id, i + par_frag_node = self._fragment_trees[src][par_frag_id] + par_nodes = self._fragment[par_frag_id].nodes + if par_frag_node.reverse: + last_id = par_frag_id, par_nodes[-1] + else: + last_id = par_frag_id, par_nodes[0] + + tree = trees[src] + for i in nodes: + n = list(self._swc[i]) + n[6] = last_id + n[0] = len(tree) + 1 + tree[(frag_id, i)] = tuple(n) + last_id = frag_id, i for s, t in trees.items(): for k, v in t.items(): @@ -126,9 +158,17 @@ def _linear_programming(self): # finding variables for fragment/soma pairs that require solving scores = {} # var_i_s, i: fragment id, s: soma id for i, frag in self._fragment.items(): - scores[i] = {} - for s in frag.traversed: - scores[i][s] = pulp.LpVariable(f'Score_{i}_{s}', 0) # non-negative + if len(frag.traversed) > 1: # mixed sources + scores[i] = {} + for s in frag.traversed: + scores[i][s] = pulp.LpVariable(f'Score_{i}_{s}', 0) # non-negative + elif len(frag.traversed) == 1: + scores[i] = {} + for s in frag.traversed: + scores[i][s] = pulp.LpVariable(f'Score_{i}_{s}', 1, 1) # const + else: + pass + # raise ValueError('') # objective func: cost * score self._problem += pulp.lpSum( @@ -136,7 +176,6 @@ def _linear_programming(self): self._fragment_trees[s][i].cost * score for 
s, score in frag_vars.items() ) for i, frag_vars in scores.items() ), "Global Penalty" - # constraints for i, frag_vars in scores.items(): self._problem += (pulp.lpSum(score for score in frag_vars.values()) == 1, @@ -147,20 +186,17 @@ def _linear_programming(self): self._problem += score <= scores[p][s], \ f"Tree Topology Enforcement for Score_{i}_{s}" - self._problem.solve() + self._problem.solve(pulp.PULP_CBC_CMD(msg=0)) + + for frag in self._fragment.values(): + frag.source = dict.fromkeys(frag.traversed, 1) for variable in self._problem.variables(): frag_id, src = variable.name.split('_')[1:] frag_id, src = int(frag_id), int(src) frag = self._fragment[frag_id] - if frag.source is None or frag.likelihood < variable.varValue: - frag.source = src - frag.likelihood = variable.varValue - - for frag in self._fragment.values(): - if frag.source is None: - frag.source = list(frag.traversed)[0] - frag.likelihood = 1 + assert src in frag.source + frag.source[src] = variable.varValue if self._verbose: print("Finished linear programming.") diff --git a/ecut/error_prune.py b/ecut/error_prune.py index 6714abc..8facabd 100644 --- a/ecut/error_prune.py +++ b/ecut/error_prune.py @@ -3,28 +3,25 @@ from scipy.spatial import distance_matrix from scipy.interpolate import interp1d from .morphology import Morphology +from sklearn.decomposition import PCA class ErrorPruning: - def __init__(self, res=(.25, .25, 1.), soma_radius=10., anchor_reach=(2., 10.), gap_thr_ratio=1., epsilon=1e-7): + def __init__(self, res=(.25, .25, 1.), soma_radius=10., anchor_dist=5., epsilon=1e-7): """ :param res: image resolution in micrometers, (x, y, z) :param soma_radius: expected soma radius in micrometers, within which errors are not counted - :param anchor_reach: the distances of the anchor to the branch node, the anchor is meant to accurately - estimate angles. (near, far). 
near: the near end of the anchor, far: the far end of the anchor, - :param gap_thr_ratio: + :param anchor_dist: the distances of the anchor to the branch node :param epsilon: value lower than this will be regarded as 0 """ self._res = res self._soma_radius = soma_radius - self._near_anchor = anchor_reach[0] - self._far_anchor = anchor_reach[1] - self._gap_thr_ratio = gap_thr_ratio + self._far_anchor = anchor_dist self._eps = epsilon - def _length(self, p1, p2=(0, 0, 0), axis=None): + def _length(self, p1: list | np.ndarray, p2=(0, 0, 0), axis=None): if not isinstance(p1, np.ndarray): p1 = np.array(p1) return np.linalg.norm((p1 - p2) * self._res, axis=axis) @@ -41,8 +38,7 @@ def _vector_angles(self, p, ch): out = [*map(lambda x: math.acos(max(min(x, 1), -1)) * 180 / math.pi, cos_ch)] return np.array(out) - def _find_point(self, morph: Morphology, pt: np.ndarray, idx: np.ndarray, is_parent: bool, - dist_thr: float, pt_rad: float, return_center_point: bool): + def _find_point(self, morph: Morphology, ct: np.ndarray, idx: np.ndarray, is_parent: bool, dist_thr: float, ct_rad: float): """ Find the point of exact `dist` to the start pt on tree structure. 
args are: - pt: the start point, [coordinate] @@ -51,13 +47,21 @@ def _find_point(self, morph: Morphology, pt: np.ndarray, idx: np.ndarray, is_par if a furcation points encounted, then break - morph: Morphology object for current tree - dist: distance threshold - - return_center_point: whether to return the point with exact distance or - geometric point of all traced nodes """ + # init d = 0 - pts = [pt] - rad = [pt_rad] + if is_parent and morph.pos_dict[idx][6] != -1: + pts = [np.array(morph.pos_dict[idx][2:5])] + rad = [np.array(morph.pos_dict[idx][5])] + idx = morph.pos_dict[idx][6] + elif not is_parent and idx in morph.unifurcation: + pts = [np.array(morph.pos_dict[idx][2:5])] + rad = [np.array(morph.pos_dict[idx][5])] + idx = morph.child_dict[idx][0] + else: + pts = [ct] + rad = [ct_rad] while True: new_p = np.array(morph.pos_dict[idx][2:5]) new_r = morph.pos_dict[idx][5] @@ -76,24 +80,9 @@ def _find_point(self, morph: Morphology, pt: np.ndarray, idx: np.ndarray, is_par else: idx = morph.child_dict[idx][0] - # interpolate to find the exact point - dd = d - dist_thr - if dd < 0: - pt_a = new_p - else: # extrapolate - dcur = self._length(new_p, pts[-1]) - ratio = (dcur - dd) / (dcur + self._eps) - pt_a = pts[-1] + (new_p - pts[-1]) * ratio - r_a = rad[-1] + (new_r - rad[-1]) * ratio - pts.append(pt_a) - rad.append(r_a) - - if return_center_point: - pt_a = np.mean(pts, axis=0) - - return pt_a, pts, rad + return pts, rad, idx - def _get_anchors(self, morph: Morphology, ind: list[int] | int, dist_thr: float, step_size=0.5): + def _get_anchors(self, morph: Morphology, ind: list[int] | int, dist_thr: float): """ get anchors for a set of swc nodes to calculate angles, suppose they are one, their center is their mean coordinate, @@ -129,43 +118,40 @@ def _get_anchors(self, morph: Morphology, ind: list[int] | int, dist_thr: float, protrude = np.array(list(protrude)) # com_node == center can cause problem for spline # for finding anchor_p, you must input sth different 
from the center to get the right pt list - if self._length(center, morph.pos_dict[com_node][2:5]) <= self._eps: - p = morph.pos_dict[com_node][6] - # p can be -1 if the com_node is root - # but when this happens, com_node can hardly == center - # this case is dispelled when finding crossings - else: - p = com_node - anchor_p, pts_p, rad_p = self._find_point(morph, center, p, True, dist_thr, center_radius, False) - res = [self._find_point(morph, center, i, False, dist_thr, center_radius, False) for i in protrude] - anchor_ch, pts_ch, rad_ch = [i[0] for i in res], [i[1] for i in res], [i[2] for i in res] - gap_thr = np.mean(rad_p) * self._gap_thr_ratio - interp_ch = [] - for pts in pts_ch: - pp = [pts[0]] - j = 1 - dist_cum = [0] - while j < len(pts): - new_d = self._length(pts[j], pp[-1]) - if new_d > self._eps and new_d + dist_cum[-1] != dist_cum[-1]: - dist_cum.append(dist_cum[-1] + new_d) - pp.append(pts[j]) - j += 1 - if len(pp) > 1: - f = interp1d(dist_cum, pp, 'quadratic' if len(pp) > 2 else 'linear', 0, fill_value='extrapolate') - interp_ch.append(f) - step = step_size - while step <= dist_thr: - pts = [i(step) for i in interp_ch] - if len(pts) < 2: - break - gap = distance_matrix(pts, pts) - gap = np.median(gap[np.triu_indices_from(gap, 1)]) - if gap > gap_thr: - break - center = np.mean(pts, axis=0) - step += step_size - return center, anchor_p, anchor_ch, protrude, rad_p, rad_ch + pts_p, rad_p, last_p = self._find_point(morph, center, com_node, True, dist_thr, center_radius) + res = [self._find_point(morph, center, i, False, dist_thr, center_radius) for i in protrude] + pts_ch, rad_ch, last_ch = [i[0] for i in res], [i[1] for i in res], [i[2] for i in res] + return com_node, pts_p, pts_ch, protrude, rad_p, rad_ch, last_p, last_ch + + @staticmethod + def line_fit_pca(pts_list: list[np.ndarray]) -> np.ndarray: + """ + fit 3D points to a straight line. 
+ :param pts_list: a list of 3D connected points + :return: a 3D vector fitted to the list + """ + pca = PCA(n_components=1) + pca.fit(pts_list) + line_direction = pca.components_[0] + temp = pts_list[-1] - pts_list[0] + if temp.dot(line_direction) < 0: + line_direction = -line_direction + return line_direction + + def get_angle(self, pts_list1: list[np.ndarray], pts_list2: list[np.ndarray]): + """ + The angle between 2 vectors (fitted from 2 point lists), but supplementary. + the vectors share the start point, but to make it fit for scoring, its supplementary is returned. + so a smaller angle means a more straight connection. + + :param pts_list1: a list of 3D points for one branch + :param pts_list2: a list of 3D points for another branch + :return: an angle in arc + """ + vec1 = self.line_fit_pca(pts_list1) * self._res + vec2 = self.line_fit_pca(pts_list2) * self._res + cos = vec1.dot(vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) + return math.acos(max(min(cos, 1), -1)) * 180 / math.pi def branch_prune(self, morph, angle_thr=80, radius_amp=1.5): """ @@ -182,19 +168,10 @@ def branch_prune(self, morph, angle_thr=80, radius_amp=1.5): if self._length(cs, morph.pos_dict[n][2:5]) <= self._soma_radius: continue # branch, no soma, away from soma - center, anchor_p, anchor_ch, protrude, rad_p, rad_ch = self._get_anchors(morph, n, self._far_anchor) - _, near_p, near_ch, _, _, _ = self._get_anchors(morph, n, self._near_anchor) - vec_p = np.array(anchor_p) - near_p - vec_ch = np.array(anchor_ch) - near_ch - if self._length(vec_p) < self._eps: - vec_p = np.array(anchor_p) - center - mask = self._length(vec_ch, axis=-1) < self._eps - vec_ch[mask] = vec_ch[mask] - center - - # strange anchor distance can't be considered for pruning - angles = self._vector_angles(vec_p, vec_ch) - radius_p = np.median(rad_p) - radius_ch = np.array([np.median(rad) for rad in rad_ch]) + _, pts_p, pts_ch, protrude, rad_p, rad_ch, _, _ = self._get_anchors(morph, n, self._far_anchor) + angles 
= np.array([self.get_angle(pts_p, c) for c in pts_ch]) + radius_p = np.mean(rad_p) + radius_ch = np.array([np.mean(rad) for rad in rad_ch]) rm_ind |= set(protrude[(angles < angle_thr) | (radius_ch > radius_p * radius_amp)]) return rm_ind @@ -231,57 +208,96 @@ def _find_mega_crossing(self, morph: Morphology, dist_thr): # merge return [set.union(*[set(t) for t in chains if t[-1] == head]) for head in np.unique([i[-1] for i in chains])] - def crossover_prune(self, morph: Morphology, dist_thr=2., angle_thr=120, check_bif=False): + def crossover_prune(self, morph: Morphology, dist_thr=2., angle_thr1=60, angle_thr2=90, check_bif=False, + no_multi=True, short_tips_thr=5.): """ Prune crossovers by angle. :param morph: morphology wrapped swc tree :param dist_thr: the max distance between nearby branch nodes in a mega crossover. - :param angle_thr: branches less than this angle will only be removed when aligned with another branch + :param angle_thr1: branches less than this angle will take away a best fit branch + :param angle_thr2: branches less than this angle will take away a best fit branch than parent :param check_bif: if checking bifurcation, in this mode only bifurcation will be checked and pruning is not forced. + :param no_multi: ensure no multifurcation, start removing from tips + :param short_tips_thr: drop short tips below this threshold before pruning. :return: nodes to prune. 
""" crossings = self._find_mega_crossing(morph, dist_thr) cs = np.array(morph.pos_dict[morph.idx_soma][2:5]) if check_bif: - to_check = [i for i in morph.bifurcation - set.union(*crossings) - if self._length(cs, morph.pos_dict[i][2:5]) > self._soma_radius] + if len(crossings) > 0: + crossings = set.union(*crossings) + else: + crossings = set() + to_check = [i for i in morph.bifurcation - crossings if self._length(cs, morph.pos_dict[i][2:5]) > self._soma_radius] else: to_check = crossings rm_ind = set() for x in to_check: # angle - center, anchor_p, anchor_ch, protrude, _, _ = self._get_anchors(morph, x, self._far_anchor) - _, near_p, near_ch, _, _, _ = self._get_anchors(morph, x, self._near_anchor) - vec_p = np.array(anchor_p) - near_p - vec_ch = np.array(anchor_ch) - near_ch - if self._length(vec_p) < self._eps: - vec_p = np.array(anchor_p) - center - mask = self._length(vec_ch, axis=-1) < self._eps - vec_ch[mask] = vec_ch[mask] - center - - angles = self._vector_angles(vec_p, vec_ch) + com_node, pts_p, pts_ch, protrude, rad_p, rad_ch, last_p, last_ch = self._get_anchors(morph, x, self._far_anchor) + rad_p = np.mean(rad_p) + rad_ch = [np.mean(c) for c in rad_ch] + rad_diff = [abs(c - rad_p) for c in rad_ch] + angles = np.array([self.get_angle(pts_p, c) for c in pts_ch]) rm = set() + + for i, c, pts in zip(protrude, last_ch, pts_ch): + if i in rm: + continue + if c in morph.tips: + l = self._length(pts[1:], pts[:-1], axis=1).sum() + if l < short_tips_thr: + rm.add(i) + order = np.argsort(angles) for i in order: # starting from the worst angle if protrude[i] in rm: continue - if angles[i] <= angle_thr: - new_angles = self._vector_angles(vec_ch[i], vec_ch) - for k in np.flip(np.argsort(new_angles ** 2 / angles)): - if new_angles[k] > angles[k] and protrude[k] not in rm: - rm |= set(protrude[[i, k]]) + if angles[i] < angle_thr1: + new_angles = np.array([self.get_angle(pts_ch[i], c) for c in pts_ch]) + for k in np.argsort(angles - new_angles): + new_rad_diff = 
abs(rad_ch[i] - rad_ch[k]) + if new_rad_diff < rad_diff[k] and protrude[k] not in rm: + rm.add(protrude[i]) + rm.add(protrude[k]) + break + elif angles[i] < angle_thr2: + new_angles = np.array([self.get_angle(pts_ch[i], c) for c in pts_ch]) + for k in np.argsort(angles - new_angles): + new_rad_diff = abs(rad_ch[i] - rad_ch[k]) + if new_angles[k] > angles[k] and new_rad_diff < rad_diff[k] and protrude[k] not in rm: + rm.add(protrude[i]) + rm.add(protrude[k]) + break else: break - left = len(angles) - len(rm) - if left > 2: # ensure bifurcation - for i in order: - if protrude[i] in rm: - continue - rm.add(protrude[i]) - left -= 1 - if left <= 2: - break + if no_multi: + left = len(angles) - len(rm) + if left > 2: # ensure bifurcation + # first, remove short tips + for i, c in zip(protrude, last_ch): + if i in rm: + continue + if c in morph.tips: + rm.add(i) + left -= 1 + if left <= 2: + break + else: + for i in np.argsort(rad_diff)[::-1]: + if protrude[i] in rm: + continue + rm.add(protrude[i]) + left -= 1 + if left <= 2: + break + # remove upstream of protrudes + for i in list(rm): + i = morph.pos_dict[i][6] + while i != -1 and (i in morph.unifurcation or set(morph.child_dict[i]).issubset(rm)): + rm.add(i) + i = morph.pos_dict[i][6] rm_ind |= rm return rm_ind diff --git a/ecut/gcut_utils/distribution.py b/ecut/gcut_utils/distribution.py index 778333f..14605f5 100644 --- a/ecut/gcut_utils/distribution.py +++ b/ecut/gcut_utils/distribution.py @@ -69,8 +69,8 @@ def create_cdf(self): cdf_coefficients = pdf_integral.c cdf_coefficients[-1] = constant self.cdf = np.poly1d(cdf_coefficients) - print('cdf: cdf(0) = {}, cdf(pi) = {}'.format(self.cdf(self.scale(0)), - self.cdf(self.scale(np.pi)))) + # print('cdf: cdf(0) = {}, cdf(pi) = {}'.format(self.cdf(self.scale(0)), + # self.cdf(self.scale(np.pi)))) def scale(self, val): return (val - self._mean) / self._std diff --git a/ecut/graph_cut.py b/ecut/graph_cut.py index 6c5523b..29546ac 100644 --- a/ecut/graph_cut.py +++ 
b/ecut/graph_cut.py @@ -3,7 +3,7 @@ from .swc_handler import get_child_dict from sklearn.neighbors import KDTree from ._queue import PriorityQueue -from .base_types import BaseCut +from .base_types import BaseCut, ListNeuron from .graph_metrics import EnsembleMetric, EnsembleNode, EnsembleFragment @@ -18,8 +18,8 @@ class ECut(BaseCut): """ - def __init__(self, swc: list[tuple], soma: list[int], children: dict[set] = None, - adjacency: dict[int, set] | float = 5., metric=EnsembleMetric(), *args, **kwargs): + def __init__(self, swc: ListNeuron, soma: list[int], children: dict[set] = None, res=(1., 1., 1.), + adjacency: dict[int, set] | tuple[float] = (5., 10), metric=EnsembleMetric(), *args, **kwargs): """ :param swc: swc tree, whose id should match the line number @@ -28,13 +28,13 @@ def __init__(self, swc: list[tuple], soma: list[int], children: dict[set] = None :param adjacency: close non-connecting neighbours :param metric: the metric to compute the cost on each fragment """ - super().__init__(swc, soma, *args, **kwargs) + super().__init__(swc, soma, res, *args, **kwargs) self._metric = metric self._children = self._get_children() if children is None else children if isinstance(adjacency, dict): self._adjacency = adjacency - else: - self._adjacency = self._get_adjacency(adjacency) + elif isinstance(adjacency, tuple): + self._adjacency = self._get_adjacency(*adjacency) self._end2frag: dict[int, set[int]] | None = None def _get_children(self) -> dict[int, set]: @@ -49,19 +49,26 @@ def _get_children(self) -> dict[int, set]: children[t[0]] = set() return children - def _get_adjacency(self, dist: float) -> dict[int, set]: + def _get_adjacency(self, dist1: float, dist2: float) -> dict[int, set]: """ Generate an adjacency map from the current tree. Parent and children are excluded. - :param dist: the distance threshold to consider connection between 2 nodes. + :param dist1: the distance threshold to consider connection between 2 nodes. 
+ :param dist2: the distance threshold to consider connection between 2 critical nodes. :return: the adjacency dictionary """ - kd = KDTree([t[2:5] for t in self._swc.values()]) + kd = KDTree([np.array(t[2:5]) * self.res for t in self._swc.values()]) + crits = [t for t in self._swc.values() if len(self._children[t[0]]) != 1] + kd_c = KDTree([np.array(t[2:5]) * self.res for t in crits]) keys = np.array(list(self._swc.keys())) - inds, dists = kd.query_radius([t[2:5] for t in self._swc.values()], dist, return_distance=True) + keys_c = np.array([t[0] for t in crits]) + inds = kd.query_radius([np.array(t[2:5]) * self.res for t in self._swc.values()], dist1) + inds_c = kd_c.query_radius([np.array(t[2:5]) * self.res for t in crits], dist2) adjacency = {} - for k, i, d in zip(self._swc.values(), inds, dists): - adjacency[k[0]] = set(keys[i[d < dist]]) - {k[0], k[6]} - self._children[k[0]] + for k, i in zip(self._swc.values(), inds): + adjacency[k[0]] = set(keys[i]) - {k[0], k[6]} - self._children[k[0]] + for k, i in zip(crits, inds_c): + adjacency[k[0]] |= set(keys_c[i]) - {k[0]} # ensure it's undirected graph for k, v in adjacency.items(): for i in v: @@ -124,7 +131,7 @@ def _extract_fragment(self): # and non-connecting but close fragment ends will also be considered for k, v in self._fragment.items(): end1 = v.nodes[0] - v.end1_adj = self._end2frag[end1] - {k} # omit self + v.end1_adj = self._end2frag[end1] - {k} # omit self for i in self._adjacency[end1]: # find adjacent nodes v.end1_adj |= self._end2frag[i] # add any frag related end2 = v.nodes[-1] diff --git a/ecut/graph_metrics.py b/ecut/graph_metrics.py index 067c7cc..3f639d6 100644 --- a/ecut/graph_metrics.py +++ b/ecut/graph_metrics.py @@ -23,8 +23,8 @@ def __init__(self, id): class EnsembleMetric(BaseMetric): - def __init__(self, gof_weight=1., angle_weight=1., radius_weight=1., anchor_dist=20., avg_branch_len=100., - distribution=Distribution(), epsilon=1e-10): + def __init__(self, gof_weight=1., 
angle_weight=4., radius_weight=2., anchor_dist=20., avg_branch_len=50., + distribution=Distribution(), epsilon=1e-7, soma_radius=20.): """ :param gof_weight: the weight of the global gof metric :param angle_weight: the weight of the local angle metric @@ -41,34 +41,43 @@ def __init__(self, gof_weight=1., angle_weight=1., radius_weight=1., anchor_dist self._anchor_dist = anchor_dist self.avg_branch_len = avg_branch_len self._epsilon = epsilon + self._soma_radius = soma_radius distribution.load_distribution() def init_fragment(self, cut, frag: EnsembleFragment): # calculate path length pts_list = np.array([cut.swc[i][2:5] for i in frag.nodes]) - frag.path_len = np.linalg.norm(pts_list[1:] - pts_list[:-1], axis=1).sum() + frag.path_len = np.linalg.norm((pts_list[1:] - pts_list[:-1]) * cut.res, axis=1).sum() + + def _get_len(self, pts_list, res): + pts_list = np.array(pts_list) + diff = (pts_list[1:] - pts_list[:-1]) * res + return np.linalg.norm(diff, axis=1).sum() def __call__(self, cut, soma, frag_par, frag_ch, reverse): # the angle calculation is based on the pca pc1 of the two point lists # there's a case where pc1 DNE, it returns a vector of (1,0,0) pts_par, radius_par = self._path_upstream(cut, soma, frag_par) pts_ch, radius_ch = self._path_within(cut, frag_ch, not reverse) - angle = self.get_angle(pts_par, pts_ch) / np.pi * self._angle_weight - radius = np.mean(radius_ch) / (np.mean(radius_ch) + np.mean(radius_par)) * self._radius_weight + par_len = self._get_len(pts_par, cut.res) + ch_len = self._get_len(pts_ch, cut.res) + conf = np.sqrt(par_len * ch_len) / self._anchor_dist + angle = self.get_angle(pts_par, pts_ch, cut.res) / np.pi + radius = max(np.mean(radius_ch) - np.mean(radius_par), 0) / np.mean(radius_ch) pts_list = [cut.swc[i][2:5] for i in cut.fragment[frag_ch].nodes] if not reverse: pts_list = pts_list[::-1] frag_node: EnsembleNode = cut.fragment_trees[soma][frag_par] frag: EnsembleFragment = cut.fragment[frag_ch] - ret = {'path_dist': 
frag_node.path_dist + frag.path_len} - frag_gof = self.get_gof(pts_list, cut.swc[soma][2:5]) - pseudo_order = self.avg_branch_len / ret['path_dist'] # farther branches will be more even in probability - frag_gof_prob = self._distribution.probability(frag_gof) * min(1, np.log(1 + pseudo_order)) + frag_gof = self.get_gof(pts_list, cut.swc[soma][2:5], cut.res) + pseudo_order = frag_node.path_dist / self.avg_branch_len # farther branches will be more even in probability + frag_gof_prob = self._distribution.probability(frag_gof) * min(1, np.log(1 + 1 / (pseudo_order + self._epsilon))) # no suppressing short branches + ret = {'path_dist': frag_node.path_dist + frag.path_len} ret['gof_cost'] = frag_node.gof_cost + (1 - frag_gof_prob) * frag.path_len - ret['cost'] = angle * self._angle_weight + radius * self._radius_weight + \ - ret['gof_cost'] / ret['path_dist'] * self._gof_weight # equals avg gof along the path + ret['cost'] = (angle * self._angle_weight * + radius * self._radius_weight) * conf + \ + ret['gof_cost'] / max(self._soma_radius, ret['path_dist']) * self._gof_weight # equals avg gof along the path return ret def _path_upstream(self, cut, soma: int, frag_id: int): @@ -109,7 +118,7 @@ def _path_within(self, cut, frag_id: int, reverse: bool, path_dist=0., return_di pts_list.append(np.array(cut.swc[i][2:5])) radius_list.append(cut.swc[i][5]) if len(pts_list) > 1: - path_dist += np.linalg.norm(pts_list[-2] - pts_list[-1]) + path_dist += np.linalg.norm((pts_list[-2] - pts_list[-1]) * cut.res) if path_dist > self._anchor_dist: break # stop when exceeding the anchor dist if return_distance: @@ -127,9 +136,12 @@ def line_fit_pca(pts_list: list[np.ndarray]) -> np.ndarray: pca = PCA(n_components=1) pca.fit(pts_list) line_direction = pca.components_[0] + temp = pts_list[-1] - pts_list[0] + if temp.dot(line_direction) < 0: + line_direction = -line_direction return line_direction - def get_angle(self, pts_list1: list[np.ndarray], pts_list2: list[np.ndarray]): + def 
get_angle(self, pts_list1: list[np.ndarray], pts_list2: list[np.ndarray], res): """ The angle between 2 vectors (fitted from 2 point lists), but supplementary. the vectors share the start point, but to make it fit for scoring, its supplementary is returned. @@ -139,9 +151,9 @@ def get_angle(self, pts_list1: list[np.ndarray], pts_list2: list[np.ndarray]): :param pts_list2: a list of 3D points for another branch :return: an angle in arc """ - vec1 = -EnsembleMetric.line_fit_pca(pts_list1) - vec2 = EnsembleMetric.line_fit_pca(pts_list2) - vec3 = pts_list2[0] - pts_list1[0] + vec1 = -self.line_fit_pca(pts_list1) * res + vec2 = self.line_fit_pca(pts_list2) * res + vec3 = (pts_list2[0] - pts_list1[0]) * res if np.linalg.norm(vec3) > self._epsilon: cos1 = vec1.dot(vec3) / (np.linalg.norm(vec1) * np.linalg.norm(vec3)) cos2 = vec2.dot(vec3) / (np.linalg.norm(vec2) * np.linalg.norm(vec3)) @@ -150,12 +162,12 @@ def get_angle(self, pts_list1: list[np.ndarray], pts_list2: list[np.ndarray]): cos = vec1.dot(vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) return np.arccos(np.clip(cos, -1, 1)) - def get_gof(self, pts_list: list[tuple[float, float, float]], soma: np.ndarray) -> float: + def get_gof(self, pts_list: list[tuple[float, float, float]], soma: np.ndarray, res) -> float: pts_list = np.array(pts_list) - tan = pts_list[1:] - pts_list[:-1] + tan = (pts_list[1:] - pts_list[:-1]) * res tan_norm = np.linalg.norm(tan, axis=1, keepdims=True) tan /= tan_norm + self._epsilon - ps = pts_list[:-1] - soma + ps = (pts_list[:-1] - soma) * res ps_norm = np.linalg.norm(ps, axis=1, keepdims=True) ps /= ps_norm + self._epsilon proj = np.clip(np.sum(tan * ps, axis=1), -1, 1) diff --git a/ecut/soma_detection.py b/ecut/soma_detection.py index c72e439..ef321e8 100644 --- a/ecut/soma_detection.py +++ b/ecut/soma_detection.py @@ -97,7 +97,7 @@ def predict(self, img, res: list[float], thr=None) -> list[np.ndarray]: class DetectTiledImage: def __init__(self, tile_size=(256, 256, 64), 
omit_border=(16, 16, 4), merge_dist=15, - base_detector=DetectImage(), nproc=1): + base_detector=DetectImage(), nproc=None): """ :param tile_size: The size of a single tile, indexed by x, y, z. @@ -133,19 +133,28 @@ def predict(self, img: np.ndarray, res: list[float]) -> list[np.ndarray]: x = np.linspace(hf[2], img.shape[2] - hf[2], steps[2], dtype=int) jobs = [] - with Pool(self._nproc) as p: + prefilter = [] + if self._nproc is not None: + with Pool(self._nproc) as p: + for zz in z: + for yy in y: + for xx in x: + s = (zz, yy, xx) - hf + e = (zz, yy, xx) + hf + tile = img[s[0]: e[0], s[1]: e[1], s[2]: e[2]] + jobs.append(p.apply_async(DetectTiledImage.process_find_soma, + (self._find_soma, tile, res, thr, s, self._omit_border, self._tile_size))) + + for i in tqdm(jobs): + prefilter.extend(i.get()) + else: for zz in z: for yy in y: for xx in x: s = (zz, yy, xx) - hf e = (zz, yy, xx) + hf tile = img[s[0]: e[0], s[1]: e[1], s[2]: e[2]] - jobs.append(p.apply_async(DetectTiledImage.process_find_soma, - (self._find_soma, tile, res, thr, s, self._omit_border, self._tile_size))) - - prefilter = [] - for i in tqdm(jobs): - prefilter.extend(i.get()) + prefilter.extend(self.process_find_soma(self._find_soma, tile, res, thr, s, self._omit_border, self._tile_size)) if len(prefilter) == 0: return [] @@ -156,7 +165,7 @@ def predict(self, img: np.ndarray, res: list[float]) -> list[np.ndarray]: class DetectTracingMask: - def __init__(self, min_radius=2, merge_dist=15, diam_range=(5, 20)): + def __init__(self, min_radius=2., merge_dist=15., diam_range=(5., 20.)): """ :param min_radius: the minimum radius of the swc nodes to consider, in micrometer @@ -178,7 +187,8 @@ def predict(self, swc: ListNeuron, res: list[float]) -> list[np.ndarray]: candid = [t for t in swc if t[5] * sf >= self._min_radius] pos = np.array([t[2:5] for t in candid]) rad = np.array([t[5] for t in candid]) - + if len(pos) == 0: + return [] db = DBSCAN(self._merge_dist, min_samples=1) db.fit(pos * res) labels = 
db.labels_ # Get cluster labels. diff --git a/ecut/swc_handler.py b/ecut/swc_handler.py index 384958c..b0ceecd 100644 --- a/ecut/swc_handler.py +++ b/ecut/swc_handler.py @@ -10,6 +10,7 @@ import re import numpy as np from copy import deepcopy +from queue import SimpleQueue NEURITE_TYPES = { @@ -385,3 +386,31 @@ def get_soma_from_swc(swcfile): soma_str = re.search('.* -1\n', fp.read()).group() soma = soma_str.split() return soma + + +def sort_swc(tree: list, root=1): + ch_dict = get_child_dict(tree) + ind = get_index_dict(tree) + count = 1 + temp = list(tree[ind[root]]) + temp[6] = -1 + temp[0] = count + new_tree = [tuple(temp)] + new_dict = {root: 1} + + q = SimpleQueue() + q.put_nowait(root) + + while not q.empty(): + head = tree[ind[q.get_nowait()]][0] + if head in ch_dict: + for i in ch_dict[head]: + count += 1 + temp = list(tree[ind[i]]) + temp[0] = count + temp[6] = new_dict[head] + new_tree.append(tuple(temp)) + new_dict[i] = count + q.put_nowait(i) + return new_tree + diff --git a/example/app2_processing.py b/example/app2_processing.py index 272a977..0896df4 100644 --- a/example/app2_processing.py +++ b/example/app2_processing.py @@ -9,33 +9,36 @@ if __name__ == '__main__': # tree = swc_handler.parse_swc('../test/data/gcut_input.swc_sorted.swc') - tree = swc_handler.parse_swc(r'D:\rectify\my_app2\18452_26569_3425_5509.swc') - - # detect soma - d = DetectTracingMask(5) - soma = d.predict(tree, [.3, .3, 1.]) + tree = swc_handler.parse_swc(r'D:\rectify\my_app2\17302_14358_42117_2799.swc') + tree = [t for t in tree if not (t[1] == t[2] == t[3] == 0)] + # tree = swc_handler.parse_swc(r'D:\rectify\my_app2\15257_16445_16836_4489.swc') + maxr = max([t[5] for t in tree]) * .3 + rad = max(maxr * .5, 5.) 
+ centers = DetectTracingMask(rad, 20.).predict(tree, [.3, .3, 1]) # anneal a = MorphAnneal(tree) tree = a.run() + # graph cut + if len(centers) < 1: + centers = [[512, 512, 128]] kd = KDTree([t[2:5] for t in tree]) - inds = kd.query(soma, return_distance=False) + inds = kd.query(centers, return_distance=False) inds = [tree[i[0]][0] for i in inds] print(inds) - - # graph cut e = ECut(tree, inds) e.run() trees = e.export_swc() - # pruning for k, v in trees.items(): - p = ErrorPruning([.25, .25, 1], anchor_reach=(5., 20.)) + v = swc_handler.sort_swc(v) + p = ErrorPruning([.3,.3,1], anchor_dist=20., soma_radius=10.) morph = Morphology(v) - a = p.branch_prune(morph, 45, 2) - b = p.crossover_prune(morph, 5, 90) - # c = p.crossover_prune(morph, check_bif=True) - t = swc_handler.prune(v, a | b) - swc_handler.write_swc(t, f'../test/data/ whole_{k}.swc') \ No newline at end of file + a = p.branch_prune(morph, 60, 1.5) + b = p.crossover_prune(morph, 2, 60, 90, short_tips_thr=10., no_multi=False) + c = p.crossover_prune(morph, 2, 60, 90, check_bif=True, short_tips_thr=10.) 
+ v = swc_handler.prune(v, a | b | c) + swc_handler.write_swc(v, f'../test/data/multi_{k}.swc') + \ No newline at end of file diff --git a/example/batch_1891.py b/example/batch_1891.py index 47a8fbf..31d0eef 100644 --- a/example/batch_1891.py +++ b/example/batch_1891.py @@ -1,63 +1,63 @@ -from ecut import swc_handler -from ecut.annealing import MorphAnneal -from ecut.graph_cut import ECut -from ecut.soma_detection import DetectTracingMask -from sklearn.neighbors import KDTree -from ecut.error_prune import ErrorPruning -from ecut.morphology import Morphology -from traceback import print_exc - - -def main(args): - in_path, out_path = args - try: - tree = [t for t in swc_handler.parse_swc(in_path) if not (t[1] == t[2] == t[3] == 0)] - - # detect soma - d = DetectTracingMask(3) - soma = d.predict(tree, [.3, .3, 1]) - - # anneal - a = MorphAnneal(tree) - tree = a.run() - - # map soma - kd = KDTree([t[2:5] for t in tree]) - inds = kd.query(soma, return_distance=False) - inds = [tree[i[0]][0] for i in inds] - - # graph cut - if len(inds) > 1: - e = ECut(tree, inds) - e.run() - trees = e.export_swc() - else: - trees = {0: tree} - - # pruning - for k, v in trees.items(): - p = ErrorPruning([.3, .3, 1], anchor_reach=(5., 20.)) - morph = Morphology(v) - a = p.branch_prune(morph, 45, 2) - b = p.crossover_prune(morph, 5, 90) - # c = p.crossover_prune(morph, check_bif=True) - t = swc_handler.prune(v, a | b) - swc_handler.write_swc(t, str(out_path) + f'_{k}.swc') - except: - print_exc() - print(in_path) - - -if __name__ == '__main__': - from pathlib import Path - from tqdm import tqdm - from multiprocessing import Pool - indir = Path('D:/rectify/my_app2') - outdir = Path('D:/rectify/pruned') - outdir.mkdir(exist_ok=True) - files = sorted(indir.glob('*.swc')) - outfiles = [outdir / f.name for f in files] - arglist = [*zip(files, outfiles)] - with Pool(12) as p: - for i in tqdm(p.imap(main, arglist), total=len(arglist)): +from ecut import swc_handler +from ecut.annealing import 
MorphAnneal +from ecut.graph_cut import ECut +from ecut.soma_detection import DetectTracingMask +from sklearn.neighbors import KDTree +from ecut.error_prune import ErrorPruning +from ecut.morphology import Morphology +from traceback import print_exc + + +def main(args): + in_path, out_path = args + try: + tree = [t for t in swc_handler.parse_swc(in_path) if not (t[1] == t[2] == t[3] == 0)] + + # detect soma + d = DetectTracingMask(3) + soma = d.predict(tree, [.3, .3, 1]) + + # anneal + a = MorphAnneal(tree) + tree = a.run() + + # map soma + kd = KDTree([t[2:5] for t in tree]) + inds = kd.query(soma, return_distance=False) + inds = [tree[i[0]][0] for i in inds] + + # graph cut + if len(inds) > 1: + e = ECut(tree, inds) + e.run() + trees = e.export_swc() + else: + trees = {0: tree} + + # pruning + for k, v in trees.items(): + p = ErrorPruning([.3, .3, 1], anchor_reach=(5., 20.)) + morph = Morphology(v) + a = p.branch_prune(morph, 45, 2) + b = p.crossover_prune(morph, 5, 90) + # c = p.crossover_prune(morph, check_bif=True) + t = swc_handler.prune(v, a | b) + swc_handler.write_swc(t, str(out_path) + f'_{k}.swc') + except: + print_exc() + print(in_path) + + +if __name__ == '__main__': + from pathlib import Path + from tqdm import tqdm + from multiprocessing import Pool + indir = Path('D:/rectify/my_app2') + outdir = Path('D:/rectify/pruned') + outdir.mkdir(exist_ok=True) + files = sorted(indir.glob('*.swc')) + outfiles = [outdir / f.name for f in files] + arglist = [*zip(files, outfiles)] + with Pool(12) as p: + for i in tqdm(p.imap(main, arglist), total=len(arglist)): pass \ No newline at end of file diff --git a/experiment/batch_1891.py b/experiment/batch_1891.py new file mode 100644 index 0000000..f591c84 --- /dev/null +++ b/experiment/batch_1891.py @@ -0,0 +1,82 @@ +from ecut import swc_handler +from ecut.annealing import MorphAnneal +from ecut.graph_cut import ECut +from ecut.soma_detection import * +from sklearn.neighbors import KDTree +from ecut.error_prune 
import ErrorPruning +from ecut.morphology import Morphology +from traceback import print_exc +from pathlib import Path +from v3dpy.loaders import PBD + +img_dir = Path(r"D:\rectify\crop_8bit") + + +def main(args): + in_path, out_path = args + try: + tree = [t for t in swc_handler.parse_swc(in_path) if not (t[1] == t[2] == t[3] == 0)] + res = [.3, .3, 1.] + + # detect soma + # img = PBD().load(r"D:\rectify\crop_8bit\18453_9442_3817_6561.v3dpbd")[0] + # centers_list = [] + # centers = DetectImage().predict(img, res) + # centers_list.append(centers) + # centers = DetectTiledImage([300, 300, 200]).predict(img, res) + # centers_list.append(centers) + # maxr = max([t[5] for t in tree]) * res[0] + # centers = DetectTracingMask(maxr * .75, maxr * 3).predict(tree, res) + # centers_list.append(centers) + # centers = DetectDistanceTransform().predict(img, res) + # centers_list.append(centers) + # centers = DetectTiledImage(base_detector=DetectDistanceTransform()).predict(img, res) + # centers_list.append(centers) + # centers = soma_consensus(*centers_list, res=res) + + maxr = max([t[5] for t in tree]) * res[0] + rad = max(maxr * .5, 5.) + centers = DetectTracingMask(rad, 20.).predict(tree, res) + + # anneal + a = MorphAnneal(tree) + tree = a.run() + + # graph cut + if len(centers) < 1: + centers = [[512, 512, 128]] + kd = KDTree([t[2:5] for t in tree]) + inds = kd.query(centers, return_distance=False) + inds = [tree[i[0]][0] for i in inds] + e = ECut(tree, inds) + e.run() + trees = e.export_swc() + + # pruning + for k, v in trees.items(): + v = swc_handler.sort_swc(v) + # p = ErrorPruning(res, anchor_dist=20., soma_radius=10.) + # morph = Morphology(v) + # a = p.branch_prune(morph, 60, 1.5) + # b = p.crossover_prune(morph, 2, 60, 90, short_tips_thr=10., no_multi=False) + # c = p.crossover_prune(morph, 2, 60, 90, check_bif=True, short_tips_thr=10.) 
+ # v = swc_handler.prune(v, a | b | c) + swc_handler.write_swc(v, str(out_path) + f'_{k}.swc') + except: + print_exc() + print(in_path) + + +if __name__ == '__main__': + from tqdm import tqdm + from multiprocessing import Pool + + indir = Path('D:/rectify/my_app2') + outdir = Path('D:/rectify/pruned_3') + outdir.mkdir(exist_ok=True) + files = sorted(indir.glob('*.swc')) + outfiles = [outdir / f.name for f in files] + arglist = [*zip(files, outfiles)] + with Pool(12) as p: + for i in tqdm(p.imap(main, arglist), total=len(arglist)): + pass \ No newline at end of file diff --git a/experiment/eval.py b/experiment/eval.py new file mode 100644 index 0000000..28d77cb --- /dev/null +++ b/experiment/eval.py @@ -0,0 +1,85 @@ +""" +Comparing APP2 reconstruction before and after pruning + +categorized as sparse (863) and dense image blocks, using the tracing result of NIEND enhanced images. + +Here are the pairs: + +sparse against GS +sparse prune against GS +dense against GS +dense prune against GS + +""" + +# evaluate different reconstruction against gold standard + +from utils.metrics import DistanceEvaluation +import pandas as pd +from pathlib import Path +import sys +import os + + +wkdir = Path(r"D:\rectify") +pruned_path = wkdir / 'pruned_3' + + + +class HidePrint: + def __enter__(self): + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stdout.close() + sys.stdout = self._original_stdout + + +def main(name): + man = wkdir / 'manual' / name + before = wkdir / 'my_app2' / name + afters = pruned_path.glob(f'{before.stem}*') + de = DistanceEvaluation(15) + with HidePrint(): + before = de.run(before, man) if before.exists() else None + ret = { + 'before_recall': 1 - before[2, 1] if before is not None else 0, + 'before_precision': 1 - before[2, 0] if before is not None else 1, + } + for after in afters: + after = de.run(after, man) if after.exists() else None + with HidePrint(): + recall = 1 - 
after[2, 1] if after is not None else 0 + precision = 1 - after[2, 0] if after is not None else 1 + if 'after_recall' not in ret or ret['after_recall'] < recall: + ret['after_recall'] = recall + ret['after_precision'] = precision + return ret + + +if __name__ == '__main__': + from multiprocessing import Pool + from tqdm import tqdm + + # get the sparse and dense labels + files = [i.name for i in (wkdir / 'manual').glob('*.swc')] + tab = pd.read_csv(wkdir / 'filter.csv', index_col=0) + + # main('18457_14455_13499_5478.swc') + with Pool(14) as p: + # sparse + sparse = [*filter(lambda f: tab.at[f, 'sparse'] == 1, files)] + res = [] + for r in tqdm(p.imap(main, sparse), total=len(sparse)): + res.append(r) + pd.DataFrame.from_records(res, index=sparse).to_csv('../results/eval_sparse.csv') + + # dense + dense = [*filter(lambda f: tab.at[f, 'sparse'] == 0, files)] + res = [] + for r in tqdm(p.imap(main, dense), total=len(dense)): + res.append(r) + pd.DataFrame.from_records(res, index=dense).to_csv('../results/eval_dense.csv') + + diff --git a/experiment/plot.ipynb b/experiment/plot.ipynb new file mode 100644 index 0000000..394725f --- /dev/null +++ b/experiment/plot.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Plot precision & recall" + ], + "metadata": { + "collapsed": false + }, + "id": "8348ff5edb3bd621" + }, + { + "cell_type": "code", + "source": [ + "import pandas as pd\n", + "from matplotlib import ticker\n", + "import seaborn as sns\n", + "\n", + "\n", + "df1 = pd.read_csv(\"../results/eval_sparse.csv\", index_col=0)\n", + "df2 = pd.read_csv(\"../results/eval_dense.csv\", index_col=0)\n", + "df1['before_f1'] = 2 * (df1['before_precision'] * df1['before_recall']) / (df1['before_precision'] + df1['before_recall'] + .00001)\n", + "df1['after_f1'] = 2 * (df1['after_precision'] * df1['after_recall']) / (df1['after_precision'] + df1['after_recall'] + .00001)\n", + "df2['before_f1'] = 2 * (df2['before_precision'] * 
df2['before_recall']) / (df2['before_precision'] + df2['before_recall'] + .00001)\n", + "df2['after_f1'] = 2 * (df2['after_precision'] * df2['after_recall']) / (df2['after_precision'] + df2['after_recall'] + .00001)\n", + "df1['type'] = 'sparse'\n", + "df2['type'] = 'dense'\n", + "df = pd.concat([df1, df2], axis=0)\n", + "prec = df.reset_index().melt(id_vars=['index', 'type'], var_name='stat', value_name='value', value_vars=['before_precision', 'after_precision'])\n", + "recall = df.reset_index().melt(id_vars=['index', 'type'], var_name='stat', value_name='value', value_vars=['before_recall', 'after_recall'])\n", + "f1 = df.reset_index().melt(id_vars=['index', 'type'], var_name='stat', value_name='value', value_vars=['before_f1', 'after_f1'])\n", + "\n", + "\n", + "def convert_pvalue_to_asterisks(pvalue):\n", + " if pvalue <= 0.0001:\n", + " return \"****\"\n", + " elif pvalue <= 0.001:\n", + " return \"***\"\n", + " elif pvalue <= 0.01:\n", + " return \"**\"\n", + " elif pvalue <= 0.05:\n", + " return \"*\"\n", + " return \"ns\"\n", + "\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-24T11:27:36.002419Z", + "start_time": "2024-04-24T11:27:35.974900Z" + } + }, + "id": "41ba084b8b450e5c", + "outputs": [], + "execution_count": 5 + }, + { + "cell_type": "code", + "source": [ + "from scipy.stats import ttest_ind\n", + "\n", + "def test(ax, x1, x2, y, h, fs, a, b):\n", + " ax.plot([x1, x1, x2, x2], [y, y+h, y+h, y], lw=1, c=\"k\")\n", + " stat,p_value = ttest_ind(a, b)\n", + " ax.text((x1+x2)*.5, y - h, convert_pvalue_to_asterisks(p_value), ha='center', va='bottom', color=\"k\")\n", + "\n", + "# plot precision\n", + "sns.set(font_scale=1, style='white')\n", + "import matplotlib.pyplot as plt\n", + "fig, axs = plt.subplots(1, 3, figsize=(12, 4), dpi=300)\n", + "plt.subplots_adjust(wspace=0.4)\n", + "sns.despine(fig, top=True, right=True)\n", + "\n", + "# precision\n", + "ax = sns.barplot(data=prec, hue='stat', x='type', y='value', 
ax=axs[1], legend=False)\n", + "ax.set_xticks([*range(2)], ['Sparse', 'Dense'])\n", + "ax.set_ylabel(None)\n", + "ax.set_title('Precision')\n", + "ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1, decimals=0))\n", + "ax.tick_params(left=True, direction='out')\n", + "ax.set_xlabel(None)\n", + "# test(ax, 0, 4, 1.15 - .05, .02, 'small', df['before_precision'], df['my_precision'])\n", + "# test(ax, 3, 4, 1.15 - .12, .02, 'small', df['after_precision'], df['my_precision'])\n", + "# ax.text(0, m['raw_precision']+.03, f\"{m['raw_precision']*100:.1f}%\", ha='center', va='bottom', fontsize=15)\n", + "# ax.text(1, m['ada_precision']+.03, f\"{m['ada_precision']*100:.1f}%\", ha='center', va='bottom', fontsize=15)\n", + "# ax.text(2, m['mul_precision']+.03, f\"{m['mul_precision']*100:.1f}%\", ha='center', va='bottom', fontsize=15)\n", + "# ax.text(3, m['guo_precision']+.03, f\"{m['guo_precision']*100:.1f}%\", ha='center', va='bottom', fontsize=15)\n", + "# ax.text(4, m['my_precision']+.03, f\"{m['my_precision']*100:.1f}%\", ha='center', va='bottom', fontsize=15)\n", + "\n", + "# recall\n", + "ax = sns.barplot(data=recall, hue='stat', x='type', y='value', ax=axs[2])\n", + "ax.set_xticks([*range(2)], ['Sparse', 'Dense'])\n", + "ax.set_ylabel(None)\n", + "ax.set_title('Recall')\n", + "ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1, decimals=0))\n", + "ax.tick_params(left=True, direction='out')\n", + "ax.set_xlabel(None)\n", + "\n", + "# f1\n", + "ax = sns.barplot(data=f1, hue='stat', x='type', y='value', ax=axs[0], legend=False)\n", + "ax.set_xticks([*range(2)], ['Sparse', 'Dense'])\n", + "ax.set_ylabel(None)\n", + "ax.set_title('F1')\n", + "ax.set_ylim(0, 1)\n", + "ax.tick_params(left=True, direction='out')\n", + "ax.set_xlabel(None)\n", + "handles, labels = plt.gca().get_legend_handles_labels()\n", + "plt.legend(handles=handles[:2], labels=['Before', 'After'], loc='lower right')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": 
"2024-04-24T12:47:04.832898Z", + "start_time": "2024-04-24T12:47:03.934179Z" + } + }, + "id": "2883af7f9a2b9634", + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 9 + }, + { + "cell_type": "code", + "outputs": [ + { + "data": { + "text/plain": "([], [])" + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "handles, labels" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-03-25T12:41:19.201746Z", + "start_time": "2024-03-25T12:41:19.193747Z" + } + }, + "id": "62be2d4d86bf327", + "execution_count": 32 + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/test/test_prune.py b/test/test_prune.py index fc3ddc4..547e526 100644 --- a/test/test_prune.py +++ b/test/test_prune.py @@ -9,7 +9,7 @@ def setUp(self): self.swc = parse_swc('data/anneal_output.swc') def test_prune(self): - pruner = ErrorPruning([.25, .25, 1], anchor_reach=(5., 20.)) + pruner = ErrorPruning([.25, .25, 1], anchor_dist=(5., 20.)) morph = Morphology(self.swc) a = pruner.branch_prune(morph, 45) b = pruner.crossover_prune(morph, 5) diff --git a/test/test_soma.py b/test/test_soma.py index b94d101..643fb2b 100644 --- a/test/test_soma.py +++ b/test/test_soma.py @@ -62,8 +62,17 @@ def test5_tile_dt(self): plt.show() def test6_consensus(self): + path = r"D:\rectify\crop_8bit\18453_9442_3817_6561.v3dpbd" + img = PBD().load(path)[0] centers_list = [] centers = DetectImage().predict(self.img, self.res) + fig, ax = plt.subplots() + ax.imshow(img.max(axis=0), cmap='gray') + for p in centers: + ax.plot(p[2], p[1], '.r', markersize=15) + plt.show() + plt.title('Image') + centers_list.append(centers) centers 
= DetectTiledImage([300, 300, 200], nproc=16).predict( self.img, [.25, .25, 1.]) @@ -77,8 +86,7 @@ def test6_consensus(self): centers_list.append(centers) centers = soma_consensus(*centers_list, res=self.res) print(centers) - path = r"D:\rectify\crop_8bit\18453_9442_3817_6561.v3dpbd" - img = PBD().load(path)[0] + fig, ax = plt.subplots() ax.imshow(img.max(axis=0), cmap='gray') for p in centers: diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/math_utils.py b/utils/math_utils.py new file mode 100644 index 0000000..1d36720 --- /dev/null +++ b/utils/math_utils.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python + +#================================================================ +# Copyright (C) 2021 Yufeng Liu (Braintell, Southeast University). All rights reserved. +# +# Filename : math_utils.py +# Author : Yufeng Liu +# Date : 2021-07-19 +# Description : +# +#================================================================ + +import math +import numpy as np +from scipy.spatial import distance_matrix +from sklearn.neighbors import KDTree + + +def calc_included_angles_from_vectors(vecs1, vecs2, return_rad=False, epsilon=1e-7, spacing=None, return_cos=False): + if vecs1.ndim == 1: + vecs1 = vecs1.reshape((1,-1)) + if vecs2.ndim == 1: + vecs2 = vecs2.reshape((1,-1)) + + if spacing is not None: + spacing_reshape = np.array(spacing).reshape(1,-1) + # rescale vectors according to spacing + vecs1 = vecs1 * spacing_reshape + vecs2 = vecs2 * spacing_reshape + + inner = (vecs1 * vecs2).sum(axis=1) + norms = np.linalg.norm(vecs1, axis=1) * np.linalg.norm(vecs2, axis=1) + cos_ang = inner / (norms + epsilon) + + if return_cos: + return_val = cos_ang + else: + rads = np.arccos(np.clip(cos_ang, -1, 1)) + if return_rad: + return_val = rads + else: + return_val = np.rad2deg(rads) + return return_val + + +def calc_included_angles_from_coords(anchor_coords, coords1, coords2, return_rad=False, epsilon=1e-7, spacing=None, 
return_cos=False): + anchor_coords = np.array(anchor_coords) + coords1 = np.array(coords1) + coords2 = np.array(coords2) + v1 = coords1 - anchor_coords + v2 = coords2 - anchor_coords + angs = calc_included_angles_from_vectors( + v1, v2, return_rad=return_rad, + epsilon=epsilon, spacing=spacing, + return_cos=return_cos) + return angs + + +def memory_safe_min_distances(voxels1, voxels2, num_thresh=50000, return_index=False): + # verified + nv1 = len(voxels1) + nv2 = len(voxels2) + if (nv1 > num_thresh) or (nv2 > num_thresh): + # use block wise calculation + vq1 = [voxels1[i*num_thresh:(i+1)*num_thresh] for i in range(int(math.ceil(nv1/num_thresh)))] + vq2 = [voxels2[i*num_thresh:(i+1)*num_thresh] for i in range(int(math.ceil(nv2/num_thresh)))] + + dists1 = np.ones(nv1) * 1000000. + dists2 = np.ones(nv2) * 1000000. + if return_index: + min_indices1 = np.ones(nv1) * -1 + min_indices2 = np.ones(nv2) * -1 + for i,v1 in enumerate(vq1): + idx00 = i * num_thresh + idx01 = i * num_thresh + len(v1) + for j,v2 in enumerate(vq2): + idx10 = j * num_thresh + idx11 = j * num_thresh + len(v2) + + d = distance_matrix(v1, v2) + dmin1 = d.min(axis=1) + dmin0 = d.min(axis=0) + dists1[idx00:idx01] = np.minimum(dmin1, dists1[idx00:idx01]) + dists2[idx10:idx11] = np.minimum(dmin0, dists2[idx10:idx11]) + if return_index: + dargmin1 = np.argmin(d, axis=1) + dargmin0 = np.argmin(d, axis=0) + mask1 = np.nonzero(dmin1 < dists1[idx00:idx01]) + min_indices1[idx00:idx01][mask1[0]] = dargmin1[mask1[0]] + idx00 + mask0 = np.nonzero(dmin0 < dists2[idx10:idx11]) + min_indices2[idx10:idx11][mask0[0]] = dargmin0[mask0[0]] + idx10 + else: + pdist = distance_matrix(voxels1, voxels2) + dists1 = pdist.min(axis=1) + dists2 = pdist.min(axis=0) + if return_index: + min_indices1 = pdist.argmin(axis=1) + min_indices2 = pdist.argmin(axis=0) + + if return_index: + return dists1, dists2, min_indices1, min_indices2 + else: + return dists1, dists2 + + +def min_distances_between_two_sets(voxels1, voxels2, topk=1, 
reciprocal=True, return_index=False): + """ + We should use kd-tree instead of brute-force method for large-scale data inputs. Arguments are: + @params voxels1: coordinates of points, np.ndarray in shape[N, 3] + @params voxels2: coordinates of points, np.ndarray in shape[M, 3] + @params topk: the number of top-ranking match + @params reciprocal: whether to calculate 2->1, except for 1->2 + @params return_index: whehter to return the indices of points with minimal distances + """ + tree2 = KDTree(voxels2, leaf_size=2) + dmin1, imin1 = tree2.query(voxels1, k=topk) + if reciprocal: + tree1 = KDTree(voxels1, leaf_size=2) + dmin2, imin2 = tree1.query(voxels2, k=topk) + if return_index: + return dmin1, dmin2, imin1, imin2 + else: + return dmin1, dmin2 + else: + if return_index: + return dmin1, imin1 + else: + return dmin1 + + + diff --git a/utils/metrics.py b/utils/metrics.py new file mode 100644 index 0000000..744bf6d --- /dev/null +++ b/utils/metrics.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +import os + +#================================================================ +# Copyright (C) 2023 Yufeng Liu (Braintell, Southeast University). All rights reserved. 
+# +# Filename : metrics.py +# Author : Yufeng Liu +# Date : 2023-04-04 +# Description : +# +#================================================================ + +import numpy as np + +from utils.swc_handler import tree_to_voxels, parse_swc +from utils.math_utils import min_distances_between_two_sets + +class DistanceEvaluation(object): + def __init__(self, dsa_thr=2., resample1=True, resample2=True): + self.dsa_thr = dsa_thr + self.resample1 = resample1 + self.resample2 = resample2 + + def calc_dist(self, voxels1, voxels2): + ds = { + 'ESA': None, + 'DSA': None, + 'PDS': None, + } + + dists1, dists2 = min_distances_between_two_sets(voxels1, voxels2, reciprocal=True, return_index=False) + for key in ds.keys(): + if key == 'DSA': + dists1_ = dists1[dists1 > self.dsa_thr] + dists2_ = dists2[dists2 > self.dsa_thr] + if dists1_.shape[0] == 0: + dists1_ = np.array([0.]) + if dists2_.shape[0] == 0: + dists2_ = np.array([0.]) + elif key == 'PDS': + dists1_ = (dists1 > self.dsa_thr).astype(np.float32) + dists2_ = (dists2 > self.dsa_thr).astype(np.float32) + elif key == 'ESA': + dists1_ = dists1 + dists2_ = dists2 + ds[key] = dists1_.mean(), dists2_.mean(), (dists1_.sum() + dists2_.sum()) / (len(dists1) + len(dists2)) + ds = np.array(list(ds.values())) + return ds + + def run(self, reconfile, gsfile): + if type(reconfile) is str or isinstance(reconfile, os.PathLike): + tree1 = parse_swc(reconfile) + else: + tree1 = reconfile + if type(gsfile) is str or isinstance(reconfile, os.PathLike): + tree2 = parse_swc(gsfile) + else: + tree2 = gsfile + #print(f'#nodes for recon and gs: {len(tree1)}, {len(tree2)}') + + if self.resample1: + voxels1 = tree_to_voxels(tree1, crop_box=(10000,10000,10000)) + else: + voxels1 = np.array([node[2:5] for node in tree1]) + if self.resample2: + voxels2 = tree_to_voxels(tree2, crop_box=(10000,10000,10000)) + else: + voxels2 = np.array([node[2:5] for node in tree2]) + + if len(voxels1) == 0 or len(voxels2) == 0: + print(len(voxels1), len(voxels2)) + 
return None + + ds = self.calc_dist(voxels1, voxels2) + return ds + + +if __name__ == '__main__': + gsfile = '/home/lyf/Research/cloud_paper/micro_environ/benchmark/gs_crop/18452_4536_x11274_y21067.swc' + reconfile = '/home/lyf/Research/cloud_paper/micro_environ/benchmark/recon1891_weak/18452/5642_10537_2271.swc' + de = DistanceEvaluation() + ds = de.run(reconfile, gsfile) + print(ds) + + diff --git a/utils/swc_handler.py b/utils/swc_handler.py new file mode 100644 index 0000000..d5419cf --- /dev/null +++ b/utils/swc_handler.py @@ -0,0 +1,430 @@ +"""*================================================================ +* Copyright (C) 2021 Yufeng Liu (Braintell, Southeast University). All rights reserved. +* +* Filename : swc_handler.py +* Author : Yufeng Liu +* Date : 2021-03-15 +* Description : +* +================================================================*""" +import re +import numpy as np +from copy import deepcopy +from skimage.draw import line_nd + +NEURITE_TYPES = { + 'soma': [1], + 'axon': [2], + 'basal dendrite': [3], + 'apical dendrite': [4], + 'dendrite': [3,4], +} + + +def load_spacings(spacing_file, zxy_order=False): + """ + Load the spacing information for each brain. The spacing here refers to + the resolution along x,y,z axes. 
+ """ + spacing_dict = {} + with open(spacing_file, 'r') as fp: + for line in fp.readlines(): + line = line.strip() + if not line: continue + ctxts = line.split(',') + brain_id = ctxts[0] + if not brain_id.isdigit(): + continue # the brain is encoded as digits + + brain_id = int(brain_id) + spacing = tuple(map(float, ctxts[1:])) + if zxy_order: + spacing = (spacing[2],spacing[0],spacing[1]) + spacing_dict[brain_id] = spacing + + return spacing_dict + + +def parse_swc(swc_file): + tree = [] + with open(swc_file) as fp: + for line in fp.readlines(): + line = line.strip() + if not line: continue + if line[0] == '#': continue + idx, type_, x, y, z, r, p = line.split()[:7] + idx = int(idx) + type_ = int(type_) + x = float(x) + y = float(y) + z = float(z) + r = float(r) + p = int(p) + tree.append((idx, type_, x, y, z, r, p)) + + return tree + + +def write_swc(tree, swc_file, header=tuple()): + if header is None: + header = [] + with open(swc_file, 'w') as fp: + for s in header: + if not s.startswith("#"): + s = "#" + s + if not s.endswith("\n") or not s.endswith("\r"): + s += "\n" + fp.write(s) + fp.write(f'##n type x y z r parent\n') + for leaf in tree: + idx, type_, x, y, z, r, p = leaf + fp.write(f'{idx:d} {type_:d} {x:.5f} {y:.5f} {z:.5f} {r:.1f} {p:d}\n') + + +def find_soma_node(tree, p_soma=-1, p_idx_in_leaf=6): + for leaf in tree: + if leaf[p_idx_in_leaf] == p_soma: + #print('Soma: ', leaf) + return leaf[0] + #raise ValueError("Could not find the soma node!") + return -99 + + +def find_soma_index(tree, p_soma=-1): + for i, leaf in enumerate(tree): + if leaf[6] == p_soma: + return i + #raise ValueError("find_soma_index: Could not find the somma node!") + return -99 + + +def get_child_dict(tree, p_idx_in_leaf=6): + child_dict = {} + for leaf in tree: + p_idx = leaf[p_idx_in_leaf] + if p_idx in child_dict: + child_dict[p_idx].append(leaf[0]) + else: + child_dict[p_idx] = [leaf[0]] + return child_dict + + +def get_index_dict(tree): + index_dict = {} + for i, leaf in 
enumerate(tree): + idx = leaf[0] + index_dict[idx] = i + return index_dict + + +def is_in_box(x, y, z, imgshape): + """ + imgshape must be in (z,y,x) order + """ + if x < 0 or y < 0 or z < 0 or \ + x > imgshape[2] - 1 or \ + y > imgshape[1] - 1 or \ + z > imgshape[0] - 1: + return False + return True + +def is_in_bbox(x, y, z, zyxzyx): + """ + zyxzyx is bbox in format of [(zmin, ymin, xmin), (zmax, ymax, xmax)] + """ + (zmin, ymin, xmin), (zmax, ymax, xmax) = zyxzyx + if x < xmin or y < ymin or z < zmin or \ + x > xmax or \ + y > ymax or \ + z > zmax: + return False + return True + +def prune(tree: list, ind_set: set): + """ + prune all nodes given by ind_set in morph + """ + child_dict = get_child_dict(tree) + index_dict = get_index_dict(tree) + tree = deepcopy(tree) + for i in ind_set: + q = [] + ind = index_dict[i] + if tree[ind] is None: + continue + tree[ind] = None + if i in child_dict: + q.extend(child_dict[i]) + while len(q) > 0: + head = q.pop(0) + ind = index_dict[head] + if tree[ind] is None: + continue + tree[ind] = None + if head in child_dict: + q.extend(child_dict[head]) + return [t for t in tree if t is not None] + + +def trim_swc(tree_orig, imgshape, keep_candidate_points=True, bfs=True): + """ + Trim the out-of-box and non_connecting leaves + """ + if bfs: + ib = set(t[0] for t in tree_orig if is_in_box(*t[2:5], imgshape)) + if keep_candidate_points: + child_dict = get_child_dict(tree_orig) + ib = ib.union(*(child_dict[i] for i in ib if i in child_dict)) + return prune(tree_orig, set(t[0] for t in tree_orig) - ib) + + def traverse_leaves(idx, child_dict, good_points, cand_pints, pos_dict): + leaf = pos_dict[idx] + p_idx, ib = leaf[-2:] + + if (p_idx in good_points) or (p_idx == -1): + if ib: + good_points.add(idx) # current node + else: + cand_points.add(idx) + return + + if idx not in child_dict: + return + + for new_idx in child_dict[idx]: + traverse_leaves(new_idx, child_dict, good_points, cand_pints, pos_dict) + + # execute trimming + pos_dict 
def _as_swc_tree(swc_file):
    """Return `swc_file` unchanged when it is already a node list, else parse it with `parse_swc`."""
    if isinstance(swc_file, list):
        return swc_file
    return parse_swc(swc_file)


def trim_out_of_box(tree_orig, imgshape, keep_candidate_points=True):
    """
    Trim the out-of-box leaves.

    :param tree_orig: iterable of 7-field nodes ``(idx, type, x, y, z, r, parent)``
    :param imgshape: box description, forwarded unchanged to `is_in_box`
    :param keep_candidate_points: also keep an out-of-box node when its parent, or
        failing that any of its children, lies inside the box
    :return: list of the kept nodes, in their original order
    """
    # Index nodes by id and group child ids under their parent id in one pass.
    child_dict = {}
    pos_dict = {}
    for leaf in tree_orig:
        child_dict.setdefault(leaf[-1], []).append(leaf[0])
        pos_dict[leaf[0]] = leaf

    def _inside(node_id):
        return is_in_box(*pos_dict[node_id][2:5], imgshape)

    tree = []
    for leaf in tree_orig:
        idx, _type, x, y, z, _r, p = leaf
        if is_in_box(x, y, z, imgshape):
            tree.append(leaf)
        elif keep_candidate_points:
            # The parent is checked first; children are only consulted when the
            # parent is missing from the tree or is itself out of the box.
            if p in pos_dict and _inside(p):
                tree.append(leaf)
            elif any(_inside(ch) for ch in child_dict.get(idx, ())):
                tree.append(leaf)
    return tree


def get_specific_neurite(tree, type_id):
    """
    Select the nodes whose type field (column 1) is one of `type_id`.

    :param tree: iterable of nodes
    :param type_id: a single type value, or a list/tuple/set of type values
    :return: list of matching nodes, in their original order
    """
    # A bare value is wrapped so that the membership test below works uniformly.
    if not isinstance(type_id, (list, tuple, set, frozenset)):
        type_id = (type_id,)
    return [leaf for leaf in tree if leaf[1] in type_id]


def shift_swc(swc_file, sx, sy, sz):
    """
    Translate every node by ``(-sx, -sy, -sz)``.

    :param swc_file: a node list, or anything accepted by `parse_swc`
    :param sx: amount subtracted from x
    :param sy: amount subtracted from y
    :param sz: amount subtracted from z
    :return: new list of node tuples; the input is not modified
    """
    new_tree = []
    for node in _as_swc_tree(swc_file):
        # Fields after z (radius, parent, any extra columns) pass through untouched.
        idx, type_, x, y, z, *rest = node
        new_tree.append((idx, type_, x - sx, y - sy, z - sz, *rest))
    return new_tree


def scale_swc(swc_file, scale):
    """
    Multiply node coordinates by a scale factor.

    :param swc_file: a node list, or anything accepted by `parse_swc`
    :param scale: one number applied to all axes, or a 3-item tuple/list/array
        ``(scale_x, scale_y, scale_z)``
    :return: new list of node tuples; the input is not modified
    :raises NotImplementedError: if `scale` is of any other type
    """
    if isinstance(scale, (int, float, np.integer, np.floating)):
        scale_x = scale_y = scale_z = scale
    elif isinstance(scale, (tuple, list, np.ndarray)):
        scale_x, scale_y, scale_z = scale
    else:
        raise NotImplementedError(f"Type of parameter scale {type(scale)} is not supported!")

    new_tree = []
    for node in _as_swc_tree(swc_file):
        # Only x, y, z are scaled; the radius and later fields are left as they are.
        idx, type_, x, y, z, *rest = node
        new_tree.append((idx, type_, x * scale_x, y * scale_y, z * scale_z, *rest))
    return new_tree


def flip_swc(swc_file, axis='y', dim=None):
    """
    Mirror node coordinates along one axis: ``coord -> dim - coord``.

    :param swc_file: a node list, or anything accepted by `parse_swc`
    :param axis: ``'x'``, ``'y'`` or ``'z'``; any other value leaves coordinates unchanged
    :param dim: the value the coordinate is subtracted from; required for a real flip
    :return: new list of node tuples; the input is not modified
    :raises TypeError: if `axis` names a real axis but `dim` is None
    """
    if dim is None and axis in ('x', 'y', 'z'):
        # The default used to surface as an opaque "NoneType - float" TypeError.
        raise TypeError("flip_swc() requires `dim` when flipping along 'x', 'y' or 'z'")

    new_tree = []
    for node in _as_swc_tree(swc_file):
        idx, type_, x, y, z, *rest = node
        if axis == 'x':
            x = dim - x
        elif axis == 'y':
            y = dim - y
        elif axis == 'z':
            z = dim - z
        new_tree.append((idx, type_, x, y, z, *rest))
    return new_tree


def crop_tree_by_bbox(morph, bbox, keep_candidate_points=True):
    """
    Crop swc by trimming all nodes that are out of `bbox`. Unlike `trim_out_of_box`,
    this function does not assume center cropping.

    :param morph: a node list, or an object exposing `tree`, `pos_dict` and `child_dict`
    :param bbox: bounding box, forwarded unchanged to `is_in_bbox`
    :param keep_candidate_points: also keep the out-of-bbox children of in-bbox nodes
    :return: list of kept nodes; each candidate child directly follows its parent
    """
    if isinstance(morph, list):
        # A plain list has no lookup tables, so build them here. The original
        # code computed `mtree` but then read `morph.tree`, which always failed
        # for list input.
        mtree = morph
        pos_dict = {}
        child_dict = {}
        for leaf in mtree:
            pos_dict[leaf[0]] = leaf
            child_dict.setdefault(leaf[6], []).append(leaf[0])
    else:
        mtree = morph.tree
        pos_dict = morph.pos_dict
        child_dict = morph.child_dict

    tree = []
    for leaf in mtree:
        # Nodes may carry more than 7 fields; only the first 7 are interpreted.
        idx, _type, x, y, z, _r, _p = leaf[:7]
        if not is_in_bbox(x, y, z, bbox):
            continue
        tree.append(leaf)
        if keep_candidate_points and idx in child_dict:
            for ch_leaf in child_dict[idx]:
                if not is_in_bbox(*pos_dict[ch_leaf][2:5], bbox):
                    tree.append(pos_dict[ch_leaf])
    return tree


def tree_to_voxels(tree, crop_box):
    """
    Rasterise every parent-child segment of `tree` and return the distinct
    voxels that fall inside `crop_box`.

    :param tree: iterable of 7-field nodes ``(idx, type, x, y, z, r, parent)``
    :param crop_box: box in (z, y, x) order, forwarded unchanged to `is_in_box`
    :return: float32 array built from the distinct in-box ``(x, y, z)`` voxels
    :raises ValueError: if both ends of a segment are out of the box
    """
    # Map node id -> node with its in-box flag appended as the last field.
    pos_dict = {}
    for leaf in tree:
        idx, type_, x, y, z, r, p = leaf
        pos_dict[idx] = (*leaf, is_in_box(x, y, z, crop_box))

    xl, yl, zl = [], [], []
    for leaf in pos_dict.values():
        p, ib = leaf[6], leaf[7]
        # Roots (parent == -1) and nodes whose parent is absent have no segment.
        if p == -1 or p not in pos_dict:
            continue

        parent_leaf = pos_dict[p]
        # The parent's own flag is its last field. The original indexed the
        # parent with the child's boolean flag, which read the parent's id.
        if (not ib) and (not parent_leaf[-1]):
            raise ValueError('All points are out of box! do trim_swc before!')

        # Draw the line connecting the pair. Coordinates are reversed to
        # (z, y, x) for line_nd (presumably skimage.draw.line_nd, TODO confirm),
        # so lin[0] is z and lin[2] is x.
        lin = line_nd(leaf[2:5][::-1], parent_leaf[2:5][::-1], endpoint=True)
        xl.extend(lin[2])
        yl.extend(lin[1])
        zl.extend(lin[0])

    # The set removes voxels shared by adjacent segments.
    voxels = {
        (xi, yi, zi)
        for xi, yi, zi in zip(xl, yl, zl)
        if is_in_box(xi, yi, zi, crop_box)
    }
    return np.array(list(voxels), dtype=np.float32)


def rm_disconnected(tree: list, anchor: int):
    """
    Remove every node that is not in the same rooted component as `anchor`.

    :param tree: list of nodes whose field 6 is the parent id (-1 for a root)
    :param anchor: id of a node in the component to keep
    :return: whatever `prune` returns for `tree` and the set of node ids to drop
    """
    roots = [t[0] for t in tree if t[6] == -1]
    ch = get_child_dict(tree)
    idx = get_index_dict(tree)  # presumably node id -> position in `tree`; verify
    # -1 marks "not reachable from any root". It cannot clash with a real root
    # id because -1 is the no-parent marker; the former 0 default could.
    flag = np.full(len(tree), -1, dtype=int)
    for r in roots:
        # Visit order does not affect the labelling, so an O(1) stack replaces
        # the O(n) list.pop(0) queue.
        stack = [r]
        while stack:
            head = stack.pop()
            flag[idx[head]] = r
            if head in ch:
                stack.extend(ch[head])
    ind = flag[idx[anchor]]
    return prune(tree, {t[0] for t, f in zip(tree, flag) if f != ind})


def get_soma_from_swc(swcfile):
    """
    Fast lookup of the first root record (parent field ``-1``) of an swc file.
    Only for swc, not eswc.

    :param swcfile: path of the swc file
    :return: the whitespace-separated fields of that record, as a list of strings
    :raises ValueError: if the file contains no record whose last field is ``-1``
    """
    with open(swcfile) as fp:
        # Streaming stops at the first hit instead of loading the whole file.
        for line in fp:
            if line.lstrip().startswith('#'):
                continue  # a header comment ending in "-1" is not a node
            fields = line.split()
            if len(fields) > 1 and fields[-1] == '-1':
                return fields
    raise ValueError(f"No root node (parent == -1) found in {swcfile!r}")