Skip to content

Commit

Permalink
Ported over mgs.py to use new data by Will.
Browse files Browse the repository at this point in the history
  • Loading branch information
simonleandergrimm committed May 28, 2024
1 parent 6ffee77 commit 8a85117
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 68 deletions.
4 changes: 1 addition & 3 deletions fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ def summarize_output(coeffs: pd.DataFrame) -> pd.DataFrame:


def start(num_samples: int, plot: bool) -> None:
branch = "simon-p2ra-manuscript"
print("Using mgs-pipeline branch simon-p2ra-manuscript")
figdir = Path("fig")
if plot:
figdir.mkdir(exist_ok=True)
mgs_data = MGSData.from_repo(ref=branch)
mgs_data = MGSData.from_repo()
input_data = []
output_data = []
for (
Expand Down
219 changes: 154 additions & 65 deletions mgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,59 +6,30 @@
from datetime import date
from enum import Enum
from typing import NewType, Optional
from collections import defaultdict
import os
import csv

from pydantic import BaseModel

from pathogen_properties import TaxID
from tree import Tree

MGS_REPO_DEFAULTS = {
"user": "naobservatory",
"repo": "mgs-pipeline",
"ref": "data-2023-07-21",
}
BIOPROJECTS_DIR = "bioprojects"


BioProject = NewType("BioProject", str)
Sample = NewType("Sample", str)


target_bioprojects = {
"crits_christoph": [BioProject("PRJNA661613")],
"rothman": [BioProject("PRJNA729801")],
"spurbeck": [BioProject("PRJNA924011")],
"brinch": [BioProject("PRJEB13832"), BioProject("PRJEB34633")],
"crits_christoph": [BioProject("CC-PRJNA661613")],
"rothman": [BioProject("Rothman-PRJNA729801")],
"spurbeck": [BioProject("Spurbeck-PRJNA924011")],
"brinch": [BioProject("Brinch-PRJEB13832"), BioProject("Brinch-PRJEB34633")],
}


@dataclass
class GitHubRepo:
    """A (user, repo, ref) triple naming one revision of a GitHub repository.

    Files are fetched over HTTPS from ``raw.githubusercontent.com``.
    """

    # GitHub account or organisation that owns the repository.
    user: str
    # Repository name.
    repo: str
    # Git ref (branch, tag or commit) substituted into the raw-content URL.
    ref: str

    def _raw_url(self, path: str) -> str:
        """Return the raw-content URL of *path* at this repo/ref."""
        return (
            f"https://raw.githubusercontent.com/"
            f"{self.user}/{self.repo}/{self.ref}/{path}"
        )

    def get_file(self, path: str) -> bytes:
        """Download *path* from the repository and return its raw bytes.

        The payload is returned undecoded, exactly as ``response.read()``
        yields it (the previous ``-> str`` annotation was inaccurate).

        Raises:
            ValueError: if the response status is anything other than 200.
        """
        file_url = self._raw_url(path)
        with urllib.request.urlopen(file_url) as response:
            # NOTE(review): urlopen normally raises urllib.error.HTTPError for
            # 4xx/5xx on its own, so this guard mainly covers other non-200
            # responses -- confirm against the urllib documentation.
            if response.status != 200:
                raise ValueError(
                    f"Failed to download {file_url}. "
                    f"Response status code: {response.status}"
                )
            return response.read()


def load_bioprojects(repo: GitHubRepo) -> dict[BioProject, list[Sample]]:
    """Fetch ``dashboard/metadata_bioprojects.json`` and type its contents.

    The JSON object's keys are wrapped as ``BioProject`` and each element of
    the associated list as ``Sample``.
    """
    raw = json.loads(repo.get_file("dashboard/metadata_bioprojects.json"))
    by_project = {}
    for accession, sample_ids in raw.items():
        by_project[BioProject(accession)] = [
            Sample(sample_id) for sample_id in sample_ids
        ]
    return by_project


class Enrichment(Enum):
VIRAL = "viral"
PANEL = "panel"
Expand All @@ -77,28 +48,151 @@ class SampleAttributes(BaseModel):
method: Optional[str] = None


def load_sample_attributes(repo: GitHubRepo) -> dict[Sample, SampleAttributes]:
    """Fetch ``dashboard/metadata_samples.json`` as ``SampleAttributes``.

    Each key of the JSON object is wrapped as ``Sample``; each value is
    expanded as keyword arguments to ``SampleAttributes``.
    """
    raw = json.loads(repo.get_file("dashboard/metadata_samples.json"))
    attrs_by_sample = {}
    for sample_id, fields in raw.items():
        attrs_by_sample[Sample(sample_id)] = SampleAttributes(**fields)
    return attrs_by_sample


SampleCounts = dict[TaxID, dict[Sample, int]]
def european_to_iso(date):
    """Convert a ``dd/mm/yyyy`` date string to ISO 8601 ``yyyy-mm-dd``.

    Day and month are left-padded with zeros to two digits, so an input such
    as ``"1/2/2023"`` yields ``"2023-02-01"`` rather than ``"2023-2-1"``.
    Already-padded input is returned exactly as before.

    Raises:
        ValueError: if *date* does not consist of exactly three
            ``/``-separated parts.
    """
    dd, mm, yyyy = date.split("/")
    return f"{yyyy}-{mm.zfill(2)}-{dd.zfill(2)}"

# Rothman: treatment-plant code (prefix of sample_alias) -> county.
_ROTHMAN_COUNTY_BY_WTP = {
    # Hyperion
    "HTP": "Los Angeles County",
    # San Jose Creek
    "SJ": "Los Angeles County",
    # Joint Water Pollution Control Plant
    "JWPCP": "Los Angeles County",
    # Orange County
    "OC": "Orange County",
    # Point Loma
    "PL": "San Diego County",
    # South Bay
    "SB": "San Diego County",
    # North City
    "NC": "San Diego County",
    # Escondido Hale Avenue Resource Recovery Facility
    "ESC": "San Diego County",
}

# Crits-Christoph: location column -> county.
_CRITS_CHRISTOPH_COUNTY_BY_LOCATION = {
    "Berkeley": "Alameda County",
    "Marin": "Marin County",
    "Oakland": "Alameda County",
    "SF": "San Francisco County",
}

# Spurbeck: group letter -> county.
# https://github.com/naobservatory/mgs-pipeline/issues/9
_SPURBECK_COUNTY_BY_GROUP = {
    "A": "Summit County",
    "B": "Trumbull County",
    "C": "Lucas County",
    "D": "Lawrence County",
    "E": "Sandusky County",
    "F": "Franklin County",
    "G": "Licking County",
    "H": "Franklin County",
    "I": "Greene County",
    "J": "Montgomery County",
}

# Spurbeck: group letter -> method label.
_SPURBECK_METHOD_BY_GROUP = {
    "A": "AB",
    "B": "AB",
    "C": "C",
    "D": "D",
    "E": "EFGH",
    "F": "EFGH",
    "G": "EFGH",
    "H": "EFGH",
    "I": "IJ",
    "J": "IJ",
}


def parse_metadata(record, paper):
    """Turn one ``sample-metadata.csv`` row into ``(sample, SampleAttributes)``.

    Args:
        record: the row's fields, in the column order used by *paper*'s
            metadata file (the order is spelled out in each branch below).
        paper: one of ``"rothman"``, ``"crits_christoph"``, ``"spurbeck"``
            or ``"brinch"``.

    Returns:
        A 2-tuple of the sample identifier taken from the row and the
        ``SampleAttributes`` built for it.

    Raises:
        ValueError: if *record* does not have the number of fields the
            paper's layout expects (tuple unpacking fails).
        KeyError: if a location/group code is missing from the lookup tables.
        AssertionError: if *paper* is not one of the supported names. This is
            raised explicitly rather than via ``assert`` so that the check
            still fires under ``python -O``.
    """
    if paper == "rothman":
        (
            sample,
            _library,
            date,
            _location,
            enrichment,
            sample_alias,
            _dataset,
            _bioproject,
        ) = record
        wtp = sample_alias.split("_")[0]
        if wtp == "JW":
            # Rothman confirmed over email that JW = JWPCP.
            wtp = "JWPCP"
        return sample, SampleAttributes(
            country="United States",
            date=date,
            state="California",
            location="Los Angeles",
            county=_ROTHMAN_COUNTY_BY_WTP[wtp],
            fine_location=wtp,
            enrichment="panel" if enrichment == "1" else "viral",
        )

    if paper == "crits_christoph":
        (
            _library,
            sample,
            location,
            date,
            method,
            enrichment,
            _sample_alias,
            _dataset,
            _bioproject,
        ) = record
        return sample, SampleAttributes(
            date=european_to_iso(date),
            country="United States",
            state="California",
            location="San Francisco",
            county=_CRITS_CHRISTOPH_COUNTY_BY_LOCATION[location],
            fine_location=location,
            method=method,
            enrichment="panel" if enrichment == "enriched" else "viral",
        )

    if paper == "spurbeck":
        (
            _library,
            sample,
            group,
            date,
            _instrument_model,
            _sample_alias,
            _bioproject,
            _dataset,
        ) = record
        return sample, SampleAttributes(
            date=european_to_iso(date),
            country="United States",
            state="Ohio",
            location="Ohio",
            county=_SPURBECK_COUNTY_BY_GROUP[group],
            fine_location=group,
            enrichment="viral",
            method=_SPURBECK_METHOD_BY_GROUP[group],
        )

    if paper == "brinch":
        _library, sample, location, date = record
        return sample, SampleAttributes(
            date=date,
            country="Denmark",
            location="Copenhagen",
            fine_location=location,
        )

    raise AssertionError(f"Unknown paper: {paper!r}")


def load_sample_counts(repo: GitHubRepo) -> SampleCounts:
    """Fetch ``dashboard/human_virus_sample_counts.json`` as ``SampleCounts``.

    Outer keys are converted with ``TaxID(int(...))``; inner keys are wrapped
    as ``Sample``; the counts themselves are passed through unchanged.
    """
    raw: dict[str, dict[str, int]] = json.loads(
        repo.get_file("dashboard/human_virus_sample_counts.json")
    )
    counts_by_taxid = {}
    for taxid_text, per_sample in raw.items():
        counts_by_taxid[TaxID(int(taxid_text))] = {
            Sample(sample_id): n_reads
            for sample_id, n_reads in per_sample.items()
        }
    return counts_by_taxid
import pprint

SampleCounts = dict[TaxID, dict[Sample, int]]

def load_tax_tree(repo: GitHubRepo) -> Tree[TaxID]:
data = json.loads(repo.get_file("dashboard/human_virus_tree.json"))
# Build the module-level MGS tables from the local BIOPROJECTS_DIR checkout.
# For every bioproject listed in target_bioprojects, three files are read:
#   sample-metadata.csv  -> metadata_samples / metadata_bioprojects
#   hv_clade_counts.tsv  -> sample_counts (non-zero direct read counts only)
#   qc_basic_stats.tsv   -> the .reads field of each sample's attributes
# NOTE: this runs at import time and reads from paths relative to the current
# working directory.
metadata_bioprojects = {}
metadata_samples = {}
sample_counts = defaultdict(dict)
for paper, bioprojects in target_bioprojects.items():
    for bioproject in bioprojects:
        bioproject_dir = os.path.join(BIOPROJECTS_DIR, bioproject)

        samples = []
        with open(os.path.join(bioproject_dir, "sample-metadata.csv")) as inf:
            reader = csv.reader(inf)
            next(reader, None)  # Skip the header row (no-op on an empty file).
            for record in reader:
                sample, sample_attributes = parse_metadata(record, paper)
                samples.append(sample)
                metadata_samples[sample] = sample_attributes
        metadata_bioprojects[bioproject] = samples

        with open(os.path.join(bioproject_dir, "hv_clade_counts.tsv")) as inf:
            next(inf, None)  # Skip the header row.
            for row in inf:
                # Unpacking all seven columns doubles as a schema check: a
                # row with a different number of fields raises ValueError.
                (
                    taxid,
                    _name,
                    _rank,
                    _parent_taxid,
                    sample,
                    n_reads_direct,
                    _n_reads_clade,
                ) = row.rstrip("\n").split("\t")
                n_reads_direct = int(n_reads_direct)
                if n_reads_direct:
                    # Key with TaxID/Sample, matching how the rest of this
                    # module builds its taxid- and sample-keyed structures.
                    sample_counts[TaxID(int(taxid))][
                        Sample(sample)
                    ] = n_reads_direct

        with open(os.path.join(bioproject_dir, "qc_basic_stats.tsv")) as inf:
            header = next(inf, None)
            if header is not None:
                cols = header.rstrip("\n").split("\t")
                # Resolve the column positions once instead of per row.
                sample_col = cols.index("sample")
                reads_col = cols.index("n_read_pairs")
                for row in inf:
                    fields = row.rstrip("\n").split("\t")
                    metadata_samples[fields[sample_col]].reads = int(
                        fields[reads_col]
                    )



def load_tax_tree() -> Tree[TaxID]:
    """Load the human-virus taxonomy tree from the bundled JSON file.

    Reads ``human_virus_tree-2022-12.json`` (resolved relative to the current
    working directory), builds a ``Tree`` from the decoded list and converts
    every node value with ``TaxID(int(...))``.

    The file is opened explicitly as UTF-8 so that decoding does not depend
    on the platform's locale default.
    """
    with open("human_virus_tree-2022-12.json", encoding="utf-8") as inf:
        data = json.load(inf)
    return Tree.tree_from_list(data).map(lambda x: TaxID(int(x)))


Expand Down Expand Up @@ -133,17 +227,12 @@ class MGSData:

@staticmethod
def from_repo(
user=MGS_REPO_DEFAULTS["user"],
repo=MGS_REPO_DEFAULTS["repo"],
ref=MGS_REPO_DEFAULTS["ref"],
):
repo = GitHubRepo(user, repo, ref)
print(repo, type(repo))
return MGSData(
bioprojects=load_bioprojects(repo),
sample_attrs=load_sample_attributes(repo),
read_counts=load_sample_counts(repo),
tax_tree=load_tax_tree(repo),
bioprojects=metadata_bioprojects,
sample_attrs=metadata_samples,
read_counts=sample_counts,
tax_tree=load_tax_tree(),
)

def sample_attributes(
Expand Down

0 comments on commit 8a85117

Please sign in to comment.