Skip to content

Commit

Permalink
Fixed a bug with composite_fig_1.py where some typing changes broke the get_family code.
Browse files Browse the repository at this point in the history
I'm not sure why this happened, but reverting to the old code fixed it.
  • Loading branch information
simonleandergrimm committed Feb 1, 2024
1 parent 42e04ee commit f58a217
Showing 1 changed file with 27 additions and 29 deletions.
56 changes: 27 additions & 29 deletions figures/composite_fig_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import os
import subprocess
import typing
from pathlib import Path

import matplotlib.pyplot as plt # type: ignore
Expand All @@ -14,9 +13,6 @@
import seaborn as sns # type: ignore
from matplotlib.gridspec import GridSpec # type: ignore

if os.path.basename(os.getcwd()) != "figures":
raise RuntimeError("Run this script from figures/")

dashboard = os.path.expanduser("~/code/mgs-pipeline/dashboard/")


Expand Down Expand Up @@ -100,7 +96,7 @@ def shape_barplot_df(barplot_df: pd.DataFrame) -> pd.DataFrame:
)

species_family_mapping = {
col: get_family(int(col), taxid_parents) for col in unique_numeric_cols
col: get_family(col, taxid_parents) for col in unique_numeric_cols
}

barplot_df.rename(columns=species_family_mapping, inplace=True)
Expand All @@ -126,16 +122,16 @@ def shape_barplot_df(barplot_df: pd.DataFrame) -> pd.DataFrame:

N_BIGGEST_FAMILIES = 9

family_mean_across_studies = df_normalized.apply( # type: ignore
family_mean_across_studies = df_normalized.apply(
lambda x: np.mean(x), axis=0
)

top_families = family_mean_across_studies.nlargest( # type: ignore
top_families = family_mean_across_studies.nlargest(
N_BIGGEST_FAMILIES
).index

df_normalized["Other Viral Families"] = df_normalized.loc[
:, ~df_normalized.columns.isin(top_families) # type: ignore
:, ~df_normalized.columns.isin(top_families)
].sum(axis=1)

df_normalized = df_normalized.loc[
Expand Down Expand Up @@ -164,22 +160,20 @@ def get_study_nucleic_acid_mapping() -> dict[str, str]:
return study_nucleic_acid_mapping


def load_taxonomic_data() -> dict[int, tuple[str, int]]:
    """Read ``nodes.dmp`` from the dashboard directory into a parent lookup.

    Every line is stripped of its trailing tab-pipe-newline terminator and
    split on the tab-pipe-tab field separator.  The first three fields are
    taken as child taxid, parent taxid and the child's rank; any remaining
    fields are ignored.

    Returns:
        A mapping from child taxid to ``(child_rank, parent_taxid)``.

    Raises:
        OSError: if ``nodes.dmp`` cannot be opened.
        ValueError: if a line has fewer than three fields or a taxid is not
            an integer.
    """
    parents: dict[int, tuple[str, int]] = {}
    with open(os.path.join(dashboard, "nodes.dmp")) as inf:
        for line in inf:
            child_taxid, parent_taxid, child_rank, *_ = line.replace(
                "\t|\n", ""
            ).split("\t|\t")
            # Convert into fresh values rather than rebinding the string
            # names, so every name keeps a single type.
            parents[int(child_taxid)] = (
                child_rank.strip(),
                int(parent_taxid),
            )
    return parents


@typing.no_type_check
def get_family(taxid: int, parents: dict[int, tuple[str, int]]) -> int:
iteration_count = 0

Expand All @@ -197,22 +191,22 @@ def get_family(taxid: int, parents: dict[int, tuple[str, int]]) -> int:
if iteration_count > 100:
break
else:
family_taxid = int(current_taxid)
family_taxid = current_taxid
return family_taxid


def get_taxid_name(
    target_taxid: int, taxonomic_names: dict[str, list[str]]
) -> str:
    """Return the first name recorded for ``target_taxid``.

    Args:
        target_taxid: Taxid to look up.
        taxonomic_names: Mapping keyed by the taxid rendered as a string;
            each value is a list of names, of which the first is returned.

    Raises:
        KeyError: if the taxid has no entry in ``taxonomic_names``.
        IndexError: if the entry exists but its name list is empty.
    """
    # Keys are strings, so the integer taxid is formatted before the lookup.
    tax_name = taxonomic_names[f"{target_taxid}"][0]
    return tax_name


def assemble_plotting_dfs() -> tuple[pd.DataFrame, pd.DataFrame]:
box_plot_data = []
bar_plot_data = []
for study in studies:
# Dropping studies that aren't WTP based
if study not in [
"Bengtsson-Palme 2016",
"Munk 2022",
Expand Down Expand Up @@ -260,36 +254,31 @@ def assemble_plotting_dfs() -> tuple[pd.DataFrame, pd.DataFrame]:
humanreads = "%s.humanviruses.tsv" % sample

if not os.path.exists(f"../humanviruses/{humanreads}"):
print(
"Downloading %s from %s" % (humanreads, bioproject),
flush=True,
)
subprocess.check_call(
[
"aws",
"s3",
"cp",
"s3://nao-mgs/%s/humanviruses/%s"
% (bioproject, humanreads),
"../humanviruses/",
"humanviruses/",
]
)

with open(f"../humanviruses/{humanreads}") as inf:
human_virus_counts = {}
human_virus_reads = 0

for line in inf:
(
line_taxid,
clade_assignments,
_,
) = line.strip().split("\t")
clade_hits = int(clade_assignments)
line_taxid = line_taxid
line_taxid = int(line_taxid)

human_virus_counts[line_taxid] = clade_hits
human_virus_reads += clade_hits
human_virus_reads += int(clade_hits)

human_virus_relative_abundance = (
human_virus_reads / metadata_samples[sample]["reads"]
Expand All @@ -312,12 +301,10 @@ def assemble_plotting_dfs() -> tuple[pd.DataFrame, pd.DataFrame]:
"cp",
"s3://nao-mgs/%s/cladecounts/%s"
% (bioproject, cladecounts),
"../cladecounts/",
"cladecounts/",
]
)
with gzip.open(
f"../cladecounts/{cladecounts}", mode="rt"
) as inf:
with gzip.open(f"../cladecounts/{cladecounts}") as inf:
taxa_abundances = {
"DNA Viruses": 0,
"RNA Viruses": 0,
Expand Down Expand Up @@ -366,6 +353,15 @@ def assemble_plotting_dfs() -> tuple[pd.DataFrame, pd.DataFrame]:
return boxplot_df, barplot_df


def return_study_order(boxplot_df: pd.DataFrame) -> list[str]:
    """Return the study names with DNA studies first, then RNA studies.

    The nucleic acid type of each row is looked up from its ``study`` value
    via ``get_study_nucleic_acid_mapping()``.  Within each group, studies
    keep their order of first appearance in ``boxplot_df``.  Studies whose
    mapped type is neither ``"DNA"`` nor ``"RNA"`` are left out.

    Args:
        boxplot_df: Frame with a ``study`` column.  It is not modified.

    Returns:
        The unique study names, DNA group followed by RNA group.
    """
    study_nucleic_acid_mapping = get_study_nucleic_acid_mapping()
    # Keep the type as a separate Series so the caller's frame is untouched.
    na_type = boxplot_df["study"].map(study_nucleic_acid_mapping)
    dna_studies = boxplot_df.loc[na_type == "DNA", "study"].unique()
    rna_studies = boxplot_df.loc[na_type == "RNA", "study"].unique()
    # ``unique()`` returns NumPy arrays, where ``+`` is element-wise; unpack
    # both into one list to concatenate them instead.
    return [*dna_studies, *rna_studies]


def boxplot(
ax: plt.Axes,
boxplot_df: pd.DataFrame,
Expand Down Expand Up @@ -404,7 +400,7 @@ def boxplot(
ax.set_ylabel("")
ax.tick_params(left=False, labelright=True, labelleft=False)
for label in ax.get_yticklabels():
label.set_ha("left") # type: ignore
label.set_ha("left")

ax.yaxis.set_label_position("right")
formatter = ticker.FuncFormatter(
Expand All @@ -427,6 +423,7 @@ def boxplot(
fontsize=10,
frameon=False,
)
# change x labels to log scale (8 -> 10^8)

for i in range(-7, 0):
ax.axvline(i, color="grey", linewidth=0.3, linestyle=":")
Expand Down Expand Up @@ -482,7 +479,7 @@ def barplot(
ax.set_ylabel("")
ax.tick_params(left=False, labelright=True, labelleft=False)
for label in ax.get_yticklabels():
label.set_ha("left") # type: ignore
label.set_ha("left")

ax.axhline(5.5, color="black", linewidth=1, linestyle="-")

Expand Down Expand Up @@ -536,3 +533,4 @@ def start():

if __name__ == "__main__":
start()

0 comments on commit f58a217

Please sign in to comment.