From dcb538fd5bf2881c00e312796a135f12fefc6dcd Mon Sep 17 00:00:00 2001 From: Arya Massarat <23412689+aryarm@users.noreply.github.com> Date: Wed, 5 Feb 2025 21:18:01 +0000 Subject: [PATCH 1/4] wip: read walks from .walk in complexity - initial implementation --- panct/data/walks.py | 4 +-- panct/graph_utils.py | 70 +++++++++++++++++++++++++++++++++----------- 2 files changed, 55 insertions(+), 19 deletions(-) diff --git a/panct/data/walks.py b/panct/data/walks.py index bf28eb9..05d0a2f 100644 --- a/panct/data/walks.py +++ b/panct/data/walks.py @@ -19,13 +19,13 @@ class Walks(Data): Attributes ---------- - data : dict[str, Counter[tuple[str, int]]] + data : dict[int, Counter[tuple[str, int]]] A bunch of nodes, stored as a mapping of node IDs to tuples of (sample labels, haplotype ID) log: Logger A logging instance for recording debug statements. """ - def __init__(self, data: dict[str, Counter[tuple[str, int]]], log: Logger = None): + def __init__(self, data: dict[int, Counter[tuple[str, int]]], log: Logger = None): super().__init__(log=log) self.data = data diff --git a/panct/graph_utils.py b/panct/graph_utils.py index dac4f2c..aa98cd8 100644 --- a/panct/graph_utils.py +++ b/panct/graph_utils.py @@ -3,9 +3,12 @@ """ from pathlib import Path +from collections import Counter import numpy as np +from .data import Walks + class Node: """ @@ -193,6 +196,9 @@ def get_nodes_from_walk(self, walk_string: str) -> list[str]: return ws.split(":") def load_from_gfa(self, gfa_file: Path, exclude_samples: list[str] = []): + # keep track of smallest and largest node for extracting from .walk file + smallest_node, largest_node = -float("inf"), float("inf") + # First parse all the nodes with open(gfa_file, "r") as f: for line in f: @@ -211,6 +217,22 @@ def load_from_gfa(self, gfa_file: Path, exclude_samples: list[str] = []): if nodelen == 0: raise ValueError(f"Could not determine node length for {nodeid}") self.add_node(Node(nodeid, length=nodelen)) + # keep track of the smallest and largest nodes + try: + # if the node can't be parsed into an int, then just move on + nodeid = int(nodeid) + except: + continue + if nodeid < smallest_node: + smallest_node = nodeid + elif nodeid > largest_node: + largest_node = nodeid + + # fix smallest and largest node for processing walks + if smallest_node == -float("inf"): + smallest_node = "" + if largest_node == float("inf"): + largest_node = "" # try to find the .walk file walk_file = Path("") @@ -221,20 +243,34 @@ def load_from_gfa(self, gfa_file: Path, exclude_samples: list[str] = []): if not walk_file.exists(): walk_file = walk_file.with_suffix(".walk.gz") - # TODO: get nodes from .walk file and add with self.add_walk() - # if walk_file.exists(): - # else: - - # Second pass to get the walks - with open(gfa_file, "r") as f: - for line in f: - linetype = line.split()[0] - if linetype != "W": - continue - sampid = line.split()[1] - if sampid in exclude_samples: - continue - hapid = line.split()[2] - walk = line.split()[6] - nodes = self.get_nodes_from_walk(walk) - self.add_walk(f"{sampid}:{hapid}", nodes) + if walk_file.exists(): + # Get nodes from .walk file and add with self.add_walk() + walks = Walks.read(walk_file, region=f"{smallest_node}-{largest_node}") + # TODO: implement exclude_samples + walk_lengths = Counter() + all_samples = set() + for node, node_val in self.nodes.items(): + node_int = int(node) + samples = set(f"{hap[0]}:{hap[1]}" for hap in walks.data[node_int]) + all_samples.update(samples) + node_val.samples.update(samples) + for sampid, hapid in walks.data[node_int]: + # how many times did this haplotype pass through this node? + num_times = walks.data[node_int][(sampid, hapid)] + walk_lengths[f"{sampid}:{hapid}"] += node_val.length * num_times + self.numwalks += len(all_samples) + self.walk_lengths.extend(walk_lengths.values()) + else: + # Second pass over gfa file to get the walks + with open(gfa_file, "r") as f: + for line in f: + linetype = line.split()[0] + if linetype != "W": + continue + sampid = line.split()[1] + if sampid in exclude_samples: + continue + hapid = line.split()[2] + walk = line.split()[6] + nodes = self.get_nodes_from_walk(walk) + self.add_walk(f"{sampid}:{hapid}", nodes) From 7bc86884238afa5ae5846d3e05cd2119888dc9e8 Mon Sep 17 00:00:00 2001 From: Arya Massarat <23412689+aryarm@users.noreply.github.com> Date: Wed, 12 Feb 2025 21:38:43 +0000 Subject: [PATCH 2/4] add nodes arg to Walks class --- panct/data/walks.py | 31 +++++++++++++++++++++++-------- tests/test_data.py | 9 +++++++++ 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/panct/data/walks.py b/panct/data/walks.py index 05d0a2f..07abeb2 100644 --- a/panct/data/walks.py +++ b/panct/data/walks.py @@ -34,7 +34,11 @@ def __len__(self): @classmethod def read( - cls: Type[Walks], fname: Path | str, region: str = None, log: Logger = None + cls: Type[Walks], + fname: Path | str, + region: str = None, + nodes: set[int] = None, + log: Logger = None, ) -> Walks: """ Extract walks from a .walk file @@ -46,6 +50,8 @@ def read( region: str, optional A region string denoting the start and end node IDs in the form of f'{start}-{end}' + nodes: set[int], optional + A subset of nodes to load. Defaults to all nodes. log: Logger, optional A Logger object to use for debugging statements @@ -54,7 +60,7 @@ def read( Walks A Walks object loaded with a bunch of Node objects """ - nodes = {} + final_nodes = {} parse_samp = lambda samp: (samp[0], int(samp[1])) # Try to read the file with tabix if Path(fname).suffix == ".gz" and region is not None: @@ -66,11 +72,18 @@ def read( for line in f.fetch(region=region_str): samples = line.strip().split("\t") node = int(samples.pop(0)) - - nodes[node] = Counter( + if nodes is not None and node not in nodes: + continue + final_nodes[node] = Counter( parse_samp(samp.rsplit(":", 1)) for samp in samples ) - return cls(nodes, log) + if ( + log is not None + and nodes is not None + and len(final_nodes) < len(nodes) + ): + log.warning("Couldn't load all requested nodes") + return cls(final_nodes, log) except ValueError: pass # If we couldn't parse with tabix, then fall back to slow loading @@ -88,9 +101,11 @@ def read( for line in f: samples = str(line.strip()) node = int(samples.split("\t", maxsplit=1)[0]) - if node < start or node > end: + if (node < start or node > end) or ( + nodes is not None and node not in nodes + ): continue - nodes[node] = Counter( + final_nodes[node] = Counter( parse_samp(samp.rsplit(":", 1)) for samp in samples.split("\t")[1:] ) - return cls(nodes, log) + return cls(final_nodes, log) diff --git a/tests/test_data.py b/tests/test_data.py index 7e51e3c..99230e6 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -85,6 +85,9 @@ def test_parse_walks_file(self): nodes = Walks.read(DATADIR / "basic.walk.gz", region="-2") assert nodes.data == expected.data + nodes = Walks.read(DATADIR / "basic.walk.gz", region="-2", nodes={1, 2}) + assert nodes.data == expected.data + del expected.data[2] nodes = Walks.read(DATADIR / "basic.walk", region="1-1") @@ -92,3 +95,9 @@ def test_parse_walks_file(self): nodes = Walks.read(DATADIR / "basic.walk.gz", region="1-1") assert nodes.data == expected.data + + nodes = Walks.read(DATADIR / "basic.walk.gz", region="1-2", nodes=set((1,))) + assert nodes.data == expected.data + + nodes = Walks.read(DATADIR / "basic.walk.gz", nodes=set((1,))) + assert nodes.data == expected.data From 8f2605d3d7aabf351b24d31f32659be407d96624 Mon Sep 17 00:00:00 2001 From: Arya Massarat <23412689+aryarm@users.noreply.github.com> Date: Wed, 12 Feb 2025 21:39:39 +0000 Subject: [PATCH 3/4] use nodes arg when loading from walks file in graph_utils --- .devcontainer.json | 3 +++ panct/graph_utils.py | 34 ++++++++++++++-------------------- tests/test_graph_utils.py | 7 ++++--- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/.devcontainer.json b/.devcontainer.json index 4c65ff5..2dbb58c 100644 --- a/.devcontainer.json +++ b/.devcontainer.json @@ -27,6 +27,9 @@ ], "settings": { "python.analysis.typeCheckingMode": "strict", + // "python.analysis.exclude": [ + // "**/.vscode-remote/**" + // ], "python.condaPath": "/opt/conda/condabin/conda", "python.terminal.activateEnvironment": true, "python.terminal.activateEnvInCurrentTerminal": true, diff --git a/panct/graph_utils.py b/panct/graph_utils.py index aa98cd8..3065a2d 100644 --- a/panct/graph_utils.py +++ b/panct/graph_utils.py @@ -196,9 +196,6 @@ def get_nodes_from_walk(self, walk_string: str) -> list[str]: return ws.split(":") def load_from_gfa(self, gfa_file: Path, exclude_samples: list[str] = []): - # keep track of smallest and largest node for extracting from .walk file - smallest_node, largest_node = -float("inf"), float("inf") - # First parse all the nodes with open(gfa_file, "r") as f: for line in f: @@ -217,22 +214,6 @@ def load_from_gfa(self, gfa_file: Path, exclude_samples: list[str] = []): if nodelen == 0: raise ValueError(f"Could not determine node length for {nodeid}") self.add_node(Node(nodeid, length=nodelen)) - # keep track of the smallest and largest nodes - try: - # if the node can't be parsed into an int, then just move on - nodeid = int(nodeid) - except: - continue - if nodeid < smallest_node: - smallest_node = nodeid - elif nodeid > largest_node: - largest_node = nodeid - - # fix smallest and largest node for processing walks - if smallest_node == -float("inf"): - smallest_node = "" - if largest_node == float("inf"): - largest_node = "" # try to find the .walk file walk_file = Path("") @@ -244,8 +225,21 @@ def load_from_gfa(self, gfa_file: Path, exclude_samples: list[str] = []): walk_file = walk_file.with_suffix(".walk.gz") if walk_file.exists(): + node_set = set(int(n) for n in self.nodes.keys()) + # find smallest and largest node for processing walks + smallest_node = min(node_set, default="") + largest_node = max(node_set, default="") # Get nodes from .walk file and add with self.add_walk() - walks = Walks.read(walk_file, region=f"{smallest_node}-{largest_node}") + walks = Walks.read( + walk_file, + region=f"{smallest_node}-{largest_node}", + nodes=node_set, + log=None, # TODO: pass Logger + ) + # check that all of the nodes were loaded properly + # TODO: remove this check? or implement a fail-safe + assert len(walks.data) == len(node_set) + assert len(node_set) == len(self.nodes) # TODO: implement exclude_samples walk_lengths = Counter() all_samples = set() diff --git a/tests/test_graph_utils.py b/tests/test_graph_utils.py index 4620680..b62ccfc 100644 --- a/tests/test_graph_utils.py +++ b/tests/test_graph_utils.py @@ -1,8 +1,9 @@ -import numpy as np -import os -from panct.graph_utils import * from pathlib import Path + import pytest +import numpy as np + +from panct.graph_utils import Node, NodeTable DATADIR = Path(__file__).parent.joinpath("data") From eccdd00e82743454faa9a996affadbf7fc486a3f Mon Sep 17 00:00:00 2001 From: Arya Massarat <23412689+aryarm@users.noreply.github.com> Date: Wed, 19 Feb 2025 22:45:42 +0000 Subject: [PATCH 4/4] pass walk file from original gbz path with temp gfa file --- panct/gbz_utils.py | 7 ++++++- panct/graph_utils.py | 20 ++++++++++---------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/panct/gbz_utils.py b/panct/gbz_utils.py index 2b015c5..44be9bd 100644 --- a/panct/gbz_utils.py +++ b/panct/gbz_utils.py @@ -140,4 +140,9 @@ def load_node_table_from_gbz( gfa_file = extract_region_from_gbz(gbz_file, region, reference) if gfa_file is None: return gutils.NodeTable() - return gutils.NodeTable(gfa_file=gfa_file, exclude_samples=[reference]) + walk_file = gbz_file.with_suffix(".walk") + if not walk_file.exists(): + walk_file = walk_file.with_suffix(".walk.gz") + if not walk_file.exists(): + walk_file = None + return gutils.NodeTable(gfa_file=gfa_file, exclude_samples=[reference], walk_file=walk_file) diff --git a/panct/graph_utils.py b/panct/graph_utils.py index 3065a2d..a2a1592 100644 --- a/panct/graph_utils.py +++ b/panct/graph_utils.py @@ -82,12 +82,12 @@ class NodeTable: Get list of nodes from the walk """ - def __init__(self, gfa_file: Path = None, exclude_samples: list[str] = []): + def __init__(self, gfa_file: Path = None, exclude_samples: list[str] = [], walk_file: Path = None): self.nodes = {} # node ID-> Node self.numwalks = 0 self.walk_lengths = [] if gfa_file is not None: - self.load_from_gfa(gfa_file, exclude_samples) + self.load_from_gfa(gfa_file, exclude_samples, walk_file) def add_node(self, node: Node): """ @@ -195,7 +195,7 @@ def get_nodes_from_walk(self, walk_string: str) -> list[str]: ws = walk_string.replace(">", ":").replace("<", ":").strip(":") return ws.split(":") - def load_from_gfa(self, gfa_file: Path, exclude_samples: list[str] = []): + def load_from_gfa(self, gfa_file: Path, exclude_samples: list[str] = [], walk_file: Path = None): # First parse all the nodes with open(gfa_file, "r") as f: for line in f: @@ -216,13 +216,13 @@ def load_from_gfa(self, gfa_file: Path, exclude_samples: list[str] = []): self.add_node(Node(nodeid, length=nodelen)) # try to find the .walk file - walk_file = Path("") - if gfa_file.suffix == ".gz": - walk_file = gfa_file.with_suffix("").with_suffix(".walk") - else: - walk_file = gfa_file.with_suffix("") - if not walk_file.exists(): - walk_file = walk_file.with_suffix(".walk.gz") + if walk_file is None: + if gfa_file.suffix == ".gz": + walk_file = gfa_file.with_suffix("").with_suffix(".walk") + else: + walk_file = gfa_file.with_suffix(".walk") + if not walk_file.exists(): + walk_file = walk_file.with_suffix(".walk.gz") if walk_file.exists(): node_set = set(int(n) for n in self.nodes.keys())