Skip to content

Commit

Permalink
chore(dev): fix taxid species name description annotation blast parser
Browse files Browse the repository at this point in the history
  • Loading branch information
esteinig committed Feb 21, 2025
1 parent 3dd96f9 commit 37c998b
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 26 deletions.
11 changes: 7 additions & 4 deletions cerebro/stack/pipe/src/modules/pathogen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,15 +349,18 @@ impl PathogenDetection {
if let Some(records) = &output.assembly.blast {

for record in records {
let taxid = record.taxid.trim().to_string();
let mut taxid = record.taxid.trim().to_string();
let mut name = record.taxname.trim().to_string();

// Custom BLAST databases don't usually have the 'ssciname' associated
// we could use a taxonomy to add this later, but here we use the
// 'stitle' as taxname as configured in 'Cipher'
// 'stitle' as taxname as configured in 'Cipher' with sequence description
// in the format: {taxid}:::{name}

if name == "N/A".to_string() {
name = record.title.trim().to_string()
if name == "N/A".to_string() || taxid == "N/A".to_string() {
let description = record.title.trim().split(":::").collect::<Vec<_>>();
taxid = description.first().ok_or(WorkflowError::PathogenTaxidAnnotationMissing)?.to_string();
name = description.last().ok_or(WorkflowError::PathogenTaxnameAnnotationMissing)?.to_string();
}

let rank = PathogenDetectionRank::from_str(&record.taxrank);
Expand Down
75 changes: 75 additions & 0 deletions utils/ercc/terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def plot_qc_overview(
# Merge metadata with QC table
merged_data = qc.merge(metadata, on="id", how="left")

# Subset to experiment
merged_data = merged_data[merged_data["experiment"] == experiment]

# Remove the repeat identifier from the sequencing library
Expand All @@ -165,6 +166,80 @@ def plot_qc_overview(
plot_grouped_qc(merged_data=merged_data, column=column, log_scale=log_scale, title=title, output=output, hue="label")


@app.command()
def plot_pools_qubit_reads(
qc_reads: Path = typer.Option(..., help="Quality control summary table"),
metadata: Path = typer.Option(..., help="Experiment metadata table"),
experiment: str = typer.Option("pool", help="Subset by metadata column for experiment"),
):
"""Pooling quality control with attention to Qubit values """


qc = pd.read_csv(qc_reads, sep="\t", header=0)
metadata = pd.read_csv(metadata, sep="\t", header=0)

# Remove the sample identifier from the sequencing library
qc["id"] = qc["id"].str.replace(r"(_[^_]*)$", "", regex=True)

# Merge metadata with QC table
merged_data = qc.merge(metadata, on="id", how="left")

# Subset to experiment
merged_data = merged_data[merged_data["experiment"] == experiment]

# Remove the repeat identifier from the sequencing library
merged_data["id"] = merged_data["id"].str.replace(r"__P[0-9]+", "", regex=True)
merged_data["id"] = merged_data["id"].str.replace(r"__RPT[0-9]+", "", regex=True)


# Convert "host" to a numeric column if it is continuous
merged_data["host_qubit"] = pd.to_numeric(merged_data["host_qubit"], errors="coerce")

print(merged_data)

fig, axes = plt.subplots(
nrows=1,
ncols=2,
figsize=(20, 12)
)

for i, nucleic_acid in enumerate(("dna", "rna")):

# Ensure a reasonable number of unique markers
if merged_data["host_spike"].nunique() > 10:
raise ValueError(f"Too many unique marker values")

data = merged_data[merged_data["nucleic_acid"] == nucleic_acid]

ax = axes[i]

sns.scatterplot(
x="host_qubit", y=f"input_reads", hue="label", style="host_spike",
data=data, hue_order=["P1", "P2"],
ax=ax, palette="deep", edgecolor="black"
)

sns.regplot(
x="host_qubit", y=f"input_reads",
data=data, scatter=False, ax=ax,
color="black", line_kws={"linestyle": "dashed"}
)

ax.set_xlabel("Library Qubit (ng/ul)")
ax.set_ylabel(f"Input reads (n)\n")

ax.set_title(f"{nucleic_acid.upper()}")
ax.set_ylim(0)

legend = ax.get_legend()
if legend:
legend.set_title(None)

fig.suptitle(f"Input reads vs. library concentration\n", fontsize=18)
fig.tight_layout()
fig.savefig(f"input_reads_qubit_correlation.png", dpi=300, transparent=False)


@app.command()
def plot_internal_controls_overview(
qc_controls: Path = typer.Option(..., help="Quality control module table controls"),
Expand Down
216 changes: 194 additions & 22 deletions utils/taxa/terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,17 @@ def plot_pools(
metadata: Path = typer.Option(
..., help="Reference metadata table"
),
plot_species: str = typer.Option(
..., help="List of species names to plot, comma-separated"
),
experiment: str = typer.Option(
"pool", help="Experiment column subset"
),
output: str = typer.Option(
"pools.png", help="Plot output"
),
plot_species: str = typer.Option(
..., help="List of species names to plot, comma-separated"
qubit: bool = typer.Option(
False, help="Host values are not categorical but continuous measurements from Qubit"
),
):
"""
Expand All @@ -247,38 +250,75 @@ def plot_pools(

species = species.merge(metadata, on="id", how="left")

print(species)
for nucleic_acid in ("dna", "rna"):

fig, axes = plt.subplots(
nrows=len(classifiers),
ncols=len(plot_species),
figsize=(6 * len(plot_species), 20)
)

for nucleic_acid in ("dna", "rna"):
fig, axes = plt.subplots(nrows=len(classifiers), ncols=len(plot_species), figsize=(6 * len(plot_species), 20))
species_nucleic_acid = species[species["nucleic_acid"] == nucleic_acid]
species_nucleic_acid = species_nucleic_acid[species_nucleic_acid["experiment"] == experiment]

all_hosts = species_nucleic_acid["host"].unique()
for col_index, species_name in enumerate(plot_species):
species_data = species_nucleic_acid[species_nucleic_acid["name"] == species_name]

for i, classifier in enumerate(classifiers):
ax = axes[i][col_index]

sns.barplot(
x="host", y=f"{classifier}_rpm", hue="label",
data=species_data, hue_order=["P1", "P2"],
ax=ax, palette=YESTERDAY_MEDIUM,
order=all_hosts # Add this to ensure all hosts are shown
)

sns.stripplot(
x="host", y=f"{classifier}_rpm", hue="label",
data=species_data, hue_order=["P1", "P2"],
ax=ax, palette=YESTERDAY_MEDIUM, dodge=True,
edgecolor="black", linewidth=2, legend=None,
order=all_hosts # Add this to ensure all hosts are shown
)

if qubit:

# Convert "host" to a numeric column if it is continuous
species_data["host_qubit"] = pandas.to_numeric(species_data["host_qubit"], errors="coerce")

print(f"Species: {species_name} Classifier: {classifier}")
print(species_data)
print(f"=============================================================")

# Scatterplot for correlation
sns.scatterplot(
x="host_qubit", y=f"{classifier}_rpm", hue="label",
data=species_data, hue_order=["P1", "P2"],
ax=ax, palette="deep", edgecolor="black"
)

# Get colors from the scatterplot palette
palette = dict(zip(["P1", "P2"], sns.color_palette("deep", 2)))

# Separate regression lines for P1 and P2
for label in ["P1", "P2"]:
subset = species_data[species_data["label"] == label]
sns.regplot(
x="host_qubit", y=f"{classifier}_rpm",
data=subset, scatter=False, ax=ax,
color=palette[label], line_kws={"linewidth": 2, "alpha": 0.8}
)


ax.set_xlabel("Library Qubit (ng/ul)")
else:

all_hosts = species_nucleic_acid["host_spike"].unique()

sns.barplot(
x="host_spike", y=f"{classifier}_rpm", hue="label",
data=species_data, hue_order=["P1", "P2"],
ax=ax, palette=YESTERDAY_MEDIUM,
order=all_hosts # Add this to ensure all hosts are shown
)

sns.stripplot(
x="host_spike", y=f"{classifier}_rpm", hue="label",
data=species_data, hue_order=["P1", "P2"],
ax=ax, palette=YESTERDAY_MEDIUM, dodge=True,
edgecolor="black", linewidth=2, legend=None,
order=all_hosts # Add this to ensure all hosts are shown
)

ax.set_xlabel("\n")

ax.set_title(f"\n{species_name} ({classifier.capitalize()})")
ax.set_xlabel("\n")
ax.set_ylabel(f"{classifier.capitalize()} RPM\n")
ax.set_ylim(0)

Expand All @@ -291,6 +331,138 @@ def plot_pools(
fig.savefig(f"{nucleic_acid}_{output}", dpi=300, transparent=False)


@app.command()
def plot_targets(
species: Path = typer.Argument(
..., help="Pathogen detection table for species"
),
output: str = typer.Option(
"taxa_detection.png", help="Plot taxon detection output"
),
log_scale: bool = typer.Option(False, help="Log scale for plot"),
ids: str = typer.Option(None, help="Sample identifier start strings to subset dataset"),
plot_species: str = typer.Option(
None, help="List of species names to plot, comma-separated"
),
plot_labels: str = typer.Option(
None, help="List of species labels to plot, comma-separated"
),
exclude_species: str = typer.Option(
None, help="List of species to exclud, comma-separated"
)
):
"""
Simple taxa detection plot across classifiers
"""

classifiers = ["kraken", "bracken", "metabuli", "ganon", "vircov"]

species = pandas.read_csv(species, sep="\t", header=0)

# Remove the sample identifier from the sequencing library
species["id"] = species["id"].str.replace(r"(_[^_]*)$", "", regex=True)

if ids:
ids = tuple([s.strip() for s in ids.split(",")])
species = species[species["id"].str.startswith(ids)]

if plot_species:
plot_species = [s.strip() for s in plot_species.split(",")]
else:
plot_species = [
'Aspergillus niger',
'Cryptococcus neoformans',
"Haemophilus influenzae",
"Mycobacterium tuberculosis",
"Streptococcus pneumoniae",
"Toxoplasma gondii",
'Simplexvirus humanalpha1',
'Orthoflavivirus murrayense',
]

if plot_labels:
plot_labels = [s.strip() for s in plot_labels.split(",")]
else:
plot_labels = [
'ANIG',
'CNEO',
"HINF",
"MTB",
"SPNEUMO",
"TOXO",
'HSV-1',
'MVEV',
]

if len(plot_labels) != len(plot_species):
raise ValueError("Label and species designations are not of equal length")

if exclude_species:
exclude_species = [s.strip() for s in exclude_species.split(",")]

new_plot_species = []
exclude_indices = []
for i, sp in enumerate(plot_species):
if sp not in exclude_species:
new_plot_species.append(sp)
else:
exclude_indices.append(i)

plot_species = new_plot_species.copy()
plot_labels = [l for i, l in enumerate(plot_labels) if i not in exclude_indices]

print(plot_species, plot_labels)

fig, axes = plt.subplots(
nrows=len(classifiers),
ncols=2,
figsize=(6 * 2, 20)
)

print(species)

for ni, nucleic_acid in enumerate(("DNA", "RNA")):

species_nucleic_acid = species[species["id"].str.contains(nucleic_acid)]
species_data = species_nucleic_acid[species_nucleic_acid["name"].isin(plot_species)]

for i, classifier in enumerate(classifiers):
ax = axes[i][ni]

if log_scale:
species_data[f"{classifier}_rpm"] = np.log10(species_data[f"{classifier}_rpm"])

sns.barplot(
x="name", y=f"{classifier}_rpm", hue=None,
data=species_data, hue_order=None,
ax=ax, palette=YESTERDAY_MEDIUM,
order=plot_species # Add this to ensure all species are shown
)

sns.stripplot(
x="name", y=f"{classifier}_rpm", hue=None,
data=species_data, hue_order=None,
ax=ax, palette=YESTERDAY_MEDIUM, dodge=True,
edgecolor="black", linewidth=2, legend=None,
order=plot_species # Add this to ensure all species are shown
)

ax.set_title(f"\n{classifier.capitalize()}")
ax.set_xlabel("\n")
ax.set_ylabel(f"{classifier.capitalize()} RPM\n")
ax.set_ylim(0)

ax.set_xticklabels(plot_labels, rotation=45, ha="right")

legend = ax.get_legend()
if legend:
legend.set_title(None)

fig.suptitle(f"Target species (LOD)\n", fontsize=18)
fig.tight_layout()
fig.savefig(output, dpi=300, transparent=False)


@app.command()
def plot_simulation_evaluation(
summaries: List[Path] = typer.Argument(
Expand Down

0 comments on commit 37c998b

Please sign in to comment.