Skip to content

Commit

Permalink
Started adding tests and simplifying composite figure 1. Note done yet.
Browse files Browse the repository at this point in the history
  • Loading branch information
simonleandergrimm committed Mar 7, 2024
1 parent cfd696c commit 8a5738f
Showing 1 changed file with 58 additions and 50 deletions.
108 changes: 58 additions & 50 deletions figures/composite_fig_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def order_df(

na_type_order = ["DNA", "RNA"]

df = df.sort_values(
df = df.sort_values( #FIXME -> check again what this function is doing
by="na_type",
key=lambda col: col.map({k: i for i, k in enumerate(na_type_order)}),
)
Expand Down Expand Up @@ -88,63 +88,71 @@ def shape_boxplot_df(boxplot_df: pd.DataFrame) -> pd.DataFrame:
return boxplot_df


def shape_barplot_df(barplot_df: pd.DataFrame) -> pd.DataFrame:
def shape_barplot_df(df: pd.DataFrame) -> pd.DataFrame:
taxid_parents = load_taxonomic_data()

unique_numeric_cols = set(
col for col in barplot_df.columns if isinstance(col, int)
)

species_family_mapping = {
col: get_family(col, taxid_parents) for col in unique_numeric_cols
}

barplot_df.rename(columns=species_family_mapping, inplace=True)

melted_df = barplot_df.melt(
species_taxids = df.columns[2:]
species_family_mapping = {}
for species_taxid in species_taxids:
family_taxid = get_family(species_taxid, taxid_parents)
species_family_mapping[species_taxid] = family_taxid

df.rename(columns=species_family_mapping, inplace=True)
df = df.melt(
id_vars=["study", "sample"],
value_vars=[
col
for col in barplot_df.columns
if col not in ["study", "sample"] and not pd.isna(col)
],
var_name="family_taxid",
value_name="read_count",
)

grouped_df = melted_df.groupby(["study", "family_taxid"]).read_count.sum()

grouped_df_w_o_zeroes = grouped_df[grouped_df != 0]

df_normalized = grouped_df_w_o_zeroes.unstack(level=-1).fillna(0)

df_normalized = df_normalized.div(df_normalized.sum(axis=1), axis=0)

df = df.groupby(["study", "family_taxid"]).read_count.sum().reset_index()
df = df[df !=0]
df["relative_abundance"] = df.groupby(["study"])["read_count"].transform(lambda x: x / x.sum())
N_BIGGEST_FAMILIES = 9

family_mean_across_studies = df_normalized.apply(
lambda x: np.mean(x), axis=0

print(df)
top_taxa = (
df.groupby("family_taxid").relative_abundance.sum().nlargest(N_BIGGEST_FAMILIES).index
)

top_families = family_mean_across_studies.nlargest(
N_BIGGEST_FAMILIES
).index

df_normalized["Other Viral Families"] = df_normalized.loc[
:, ~df_normalized.columns.isin(top_families)
].sum(axis=1)

df_normalized = df_normalized.loc[
:, list(top_families) + ["Other Viral Families"]
]
dict_family_name = {}

for taxid in top_families.tolist():
dict_family_name[taxid] = get_taxid_name(taxid, taxonomic_names)

barplot_df = df_normalized.rename(columns=dict_family_name)
top_taxa_rows = df[df.family_taxid.isin(top_taxa)]
top_taxa_rows["hv_family"] = top_taxa_rows["family_taxid"].apply(
lambda x: get_taxid_name(x, taxonomic_names)
)

return barplot_df
minor_taxa = df[~df.family_taxid.isin(top_taxa)]["family_taxid"].unique() #FIXME
print(minor_taxa)
minor_taxa_rows = (
df[df.family_taxid.isin(minor_taxa)]
.groupby(["study"])
.agg(
{
"relative_abundance": "sum",
}
)
).reset_index()
minor_taxa_rows["hv_family"] = "Other Viral Families"
#family_mean_across_studies = df_normalized.apply(
# lambda x: np.mean(x), axis=0
#)

#top_families = family_mean_across_studies.nlargest(
# N_BIGGEST_FAMILIES
#).index

#df_normalized["Other Viral Families"] = df_normalized.loc[
# :, ~df_normalized.columns.isin(top_families)
#].sum(axis=1)

#df_normalized = df_normalized.loc[
# :, list(top_families) + ["Other Viral Families"]
#]
#dict_family_name = {}

#for taxid in top_families.tolist():
# dict_family_name[taxid] = get_taxid_name(taxid, taxonomic_names)

#df = df_normalized.rename(columns=dict_family_name)
print(df)
return df


def get_study_nucleic_acid_mapping() -> dict[str, str]:
Expand Down Expand Up @@ -456,7 +464,7 @@ def barplot(
"#bc80bd",
"#d9d9d9",
]

df = barplot_df.
barplot_df.set_index("study", inplace=True)

barplot_df.loc[study_order].plot(
Expand Down Expand Up @@ -501,7 +509,7 @@ def barplot(


def save_plot(fig, figdir: Path, name: str) -> None:
for ext in ["pdf", "png"]:
for ext in ["svg", "png"]:
fig.savefig(figdir / f"{name}.{ext}", bbox_inches="tight", dpi=900)


Expand Down

0 comments on commit 8a5738f

Please sign in to comment.