Skip to content

Commit

Permalink
Fix: More fixing (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon authored Feb 16, 2024
1 parent 91e1cf3 commit 173a0aa
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 41 deletions.
79 changes: 39 additions & 40 deletions emodel_generalisation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def cli(verbose, no_progress):

@cli.command("compute_currents")
@click.option("--input-path", type=click.Path(exists=True), required=True)
@click.option("--population_name", type=click.Path(exists=True), default=None)
@click.option("--population-name", type=str, default=None)
@click.option("--output-path", default="circuit_currents.h5", type=str)
@click.option("--morphology-path", type=click.Path(exists=True), required=False)
@click.option("--hoc-path", type=str, required=True)
Expand Down Expand Up @@ -184,30 +184,15 @@ def compute_currents(
only_rin=only_rin,
)

cols = ["resting_potential", "input_resistance", "exception"]
if not only_rin:
cols += ["holding_current", "threshold_current"]

# we populate the full circuit with duplicates if any
if len(cells_df) == len(unique_cells_df):
cells_df = unique_cells_df
else:
unique_cells_df = unique_cells_df.set_index(["morphology", "emodel"])
for entry, data in tqdm(
cells_df.groupby(["morphology", "emodel"]), disable=os.environ.get("NO_PROGRESS", False)
):
for col in cols:
cells_df.loc[data.index, col] = unique_cells_df.loc[entry, col]

failed_cells = cells_df[
cells_df["input_resistance"].isna() | (cells_df["input_resistance"] < 0)
failed_cells = unique_cells_df[
unique_cells_df["input_resistance"].isna() | (unique_cells_df["input_resistance"] < 0)
].index
if len(failed_cells) > 0:
L.info("%s failed cells, we retry with fixed timesteps:", len(failed_cells))
L.info(cells_df.loc[failed_cells])
protocol_config["deterministic"] = False
cells_df.loc[failed_cells] = evaluate_currents(
cells_df.loc[failed_cells],
unique_cells_df.loc[failed_cells] = evaluate_currents(
unique_cells_df.loc[failed_cells],
protocol_config,
hoc_path,
parallel_factory=parallel_factory,
Expand All @@ -216,13 +201,28 @@ def compute_currents(
only_rin=only_rin,
)

failed_cells = cells_df[
cells_df["input_resistance"].isna() | (cells_df["input_resistance"] < 0)
failed_cells = unique_cells_df[
unique_cells_df["input_resistance"].isna() | (unique_cells_df["input_resistance"] < 0)
].index
if len(failed_cells) > 0:
L.info("still %s failed cells (we drop):", len(failed_cells))
L.info(cells_df.loc[failed_cells])
cells_df.loc[failed_cells, "mtype"] = None
L.info(unique_cells_df.loc[failed_cells])
unique_cells_df.loc[failed_cells, "mtype"] = None

cols = ["resting_potential", "input_resistance", "exception"]
if not only_rin:
cols += ["holding_current", "threshold_current"]

# we populate the full circuit with duplicates if any
if len(cells_df) == len(unique_cells_df):
cells_df = unique_cells_df
else:
unique_cells_df = unique_cells_df.set_index(["morphology", "emodel"])
for entry, data in tqdm(
cells_df.groupby(["morphology", "emodel"]), disable=os.environ.get("NO_PROGRESS", False)
):
for col in cols:
cells_df.loc[data.index, col] = unique_cells_df.loc[entry, col]

cols_rename = {col: f"@dynamics:{col}" for col in cols if col != "exception"}

Expand All @@ -249,19 +249,10 @@ def plot_evaluation(cells_df, access_point, main_path="analysis_plot", clip=5, f
"""Make some plots of evaluations."""
main_path = Path(main_path)
main_path.mkdir(exist_ok=True)
if feature_filter is None or feature_filter == "":
feature_filter = FEATURE_FILTER
feature_filter.append("inv_time_to_first_spike")
feature_filter.append("burst_number")
feature_filter.append("ohmic_input_resistance_vb_ssse")
feature_filter.append("AHP_depth_abs")
feature_filter.append("AHP_depth")
else:
feature_filter = json.loads(feature_filter)

L.info("Plotting summary figure...")
scores = get_score_df(cells_df, filters=feature_filter)
cells_df["cost"] = np.clip(scores.max(1), 0, clip)
cells_df["cost"] = np.clip(scores.max(axis=1), 0, clip)
_df = cells_df[["emodel", "mtype", "cost"]].groupby(["emodel", "mtype"]).mean().reset_index()
plot_df = _df.pivot(index="emodel", columns="mtype", values="cost")
plt.figure(figsize=(10, 6))
Expand Down Expand Up @@ -370,10 +361,6 @@ def evaluate(
config_path, final_path, legacy, local_config=local_config_path
)
cells_df, _ = _load_circuit(input_path, morphology_path, population_name)
# cells_df = cells_df[cells_df.emodel == "bAC_L6BTC"]
# cells_df = cells_df[cells_df.mtype == "L23_LBC"].reset_index(drop=True)
# cells_df["@dynamics:AIS_scaler"] = 4.0
# cells_df["@dynamics:soma_scaler"] = 2.0

if n_cells_per_emodel is not None:
cells_df = (
Expand Down Expand Up @@ -443,6 +430,18 @@ def evaluate(

exemplar_df = exemplar_df.set_index("emodel").loc[cells_df.emodel].reset_index()

if feature_filter is None or feature_filter == "":
feature_filter = FEATURE_FILTER
feature_filter.append("inv_time_to_first_spike")
feature_filter.append("burst_number")
feature_filter.append("ohmic_input_resistance_vb_ssse")
feature_filter.append("AHP_depth_abs")
feature_filter.append("AHP_depth")
feature_filter.append("RMPProtocol")
feature_filter.append("SearchHoldingCurrent")
else:
feature_filter = json.loads(feature_filter)

pass_dfs = []
Path(validation_path).mkdir(parents=True, exist_ok=True)
with PdfPages(Path(validation_path) / "mm_features.pdf") as pdf:
Expand All @@ -458,7 +457,7 @@ def evaluate(
_pass[col] = cells_score_df[col] <= np.maximum(
5.0, 5.0 * exemplar_score_df[col].to_list()[0]
)
_pass = 1 - _pass.copy() # copy to make it less fragmented
_pass = _pass.copy() # copy to make it less fragmented
_pass["mtype"] = cells_df.mtype

data = _pass.groupby("mtype").mean()
Expand All @@ -471,7 +470,7 @@ def evaluate(
vmin=0.0,
vmax=1.0,
ax=ax,
cbar_kws={"label": "Fraction of failed features", "shrink": 0.5},
cbar_kws={"label": "Fraction of pass features", "shrink": 0.5},
)

ax.set_xlabel("mtype")
Expand Down
2 changes: 1 addition & 1 deletion emodel_generalisation/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
# or send a letter to Creative Commons, 171
# Second Street, Suite 300, San Francisco, California, 94105, USA.

VERSION = "0.2.7" # pragma: no cover
VERSION = "0.2.8.dev0" # pragma: no cover
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
"sqlalchemy-utils>=0.37.2",
"bluecellulab>=1.7.6",
"voxcell>=3.1.5",
"efel>=5.5.5",
]

doc_reqs = [
Expand Down

0 comments on commit 173a0aa

Please sign in to comment.