Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use .walk file in complexity command #10

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
],
"settings": {
"python.analysis.typeCheckingMode": "strict",
// "python.analysis.exclude": [
// "**/.vscode-remote/**"
// ],
"python.condaPath": "/opt/conda/condabin/conda",
"python.terminal.activateEnvironment": true,
"python.terminal.activateEnvInCurrentTerminal": true,
Expand Down
35 changes: 25 additions & 10 deletions panct/data/walks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ class Walks(Data):

Attributes
----------
data : dict[str, Counter[tuple[str, int]]]
data : dict[int, Counter[tuple[str, int]]]
A collection of nodes, stored as a mapping from each node ID to a Counter of (sample label, haplotype ID) tuples
log: Logger
A logging instance for recording debug statements.
"""

def __init__(self, data: dict[str, Counter[tuple[str, int]]], log: Logger = None):
def __init__(self, data: dict[int, Counter[tuple[str, int]]], log: Logger | None = None):
    """
    Store the parsed walk data on this instance.

    Parameters
    ----------
    data : dict[int, Counter[tuple[str, int]]]
        Mapping from each node ID to a Counter of (sample label, haplotype ID)
        tuples; the count records how many times that haplotype traversed the node
    log : Logger, optional
        A logging instance for recording debug statements; forwarded to the
        Data base class
    """
    # Base class handles logger setup (falls back to a default when log is None)
    super().__init__(log=log)
    self.data = data

Expand All @@ -34,7 +34,11 @@ def __len__(self):

@classmethod
def read(
cls: Type[Walks], fname: Path | str, region: str = None, log: Logger = None
cls: Type[Walks],
fname: Path | str,
region: str = None,
nodes: set[int] = None,
log: Logger = None,
) -> Walks:
"""
Extract walks from a .walk file
Expand All @@ -46,6 +50,8 @@ def read(
region: str, optional
A region string denoting the start and end node IDs in the form
of f'{start}-{end}'
nodes: set[int], optional
A subset of nodes to load. Defaults to all nodes.
log: Logger, optional
A Logger object to use for debugging statements

Expand All @@ -54,7 +60,7 @@ def read(
Walks
A Walks object containing the loaded nodes and their per-haplotype traversal counts
"""
nodes = {}
final_nodes = {}
parse_samp = lambda samp: (samp[0], int(samp[1]))
# Try to read the file with tabix
if Path(fname).suffix == ".gz" and region is not None:
Expand All @@ -66,11 +72,18 @@ def read(
for line in f.fetch(region=region_str):
samples = line.strip().split("\t")
node = int(samples.pop(0))

nodes[node] = Counter(
if nodes is not None and node not in nodes:
continue
final_nodes[node] = Counter(
parse_samp(samp.rsplit(":", 1)) for samp in samples
)
return cls(nodes, log)
if (
log is not None
and nodes is not None
and len(final_nodes) < len(nodes)
):
log.warning("Couldn't load all requested nodes")
return cls(final_nodes, log)
except ValueError:
pass
# If we couldn't parse with tabix, then fall back to slow loading
Expand All @@ -88,9 +101,11 @@ def read(
for line in f:
samples = str(line.strip())
node = int(samples.split("\t", maxsplit=1)[0])
if node < start or node > end:
if (node < start or node > end) or (
nodes is not None and node not in nodes
):
continue
nodes[node] = Counter(
final_nodes[node] = Counter(
parse_samp(samp.rsplit(":", 1)) for samp in samples.split("\t")[1:]
)
return cls(nodes, log)
return cls(final_nodes, log)
64 changes: 47 additions & 17 deletions panct/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
"""

from pathlib import Path
from collections import Counter

import numpy as np

from .data import Walks


class Node:
"""
Expand Down Expand Up @@ -221,20 +224,47 @@ def load_from_gfa(self, gfa_file: Path, exclude_samples: list[str] = []):
if not walk_file.exists():
walk_file = walk_file.with_suffix(".walk.gz")

# TODO: get nodes from .walk file and add with self.add_walk()
# if walk_file.exists():
# else:

# Second pass to get the walks
with open(gfa_file, "r") as f:
for line in f:
linetype = line.split()[0]
if linetype != "W":
continue
sampid = line.split()[1]
if sampid in exclude_samples:
continue
hapid = line.split()[2]
walk = line.split()[6]
nodes = self.get_nodes_from_walk(walk)
self.add_walk(f"{sampid}:{hapid}", nodes)
if walk_file.exists():
node_set = set(int(n) for n in self.nodes.keys())
# find smallest and largest node for processing walks
smallest_node = min(node_set, default="")
largest_node = max(node_set, default="")
# Get nodes from .walk file and add with self.add_walk()
walks = Walks.read(
walk_file,
region=f"{smallest_node}-{largest_node}",
nodes=node_set,
log=None, # TODO: pass Logger
)
# check that all of the nodes were loaded properly
# TODO: remove this check? or implement a fail-safe
assert len(walks.data) == len(node_set)
assert len(node_set) == len(self.nodes)
# TODO: implement exclude_samples
walk_lengths = Counter()
all_samples = set()
for node, node_val in self.nodes.items():
node_int = int(node)
samples = set(f"{hap[0]}:{hap[1]}" for hap in walks.data[node_int])
all_samples.update(samples)
node_val.samples.update(samples)
for sampid, hapid in walks.data[node_int]:
# how many times did this haplotype pass through this node?
num_times = walks.data[node_int][(sampid, hapid)]
walk_lengths[f"{sampid}:{hapid}"] += node_val.length * num_times
self.numwalks += len(all_samples)
self.walk_lengths.extend(walk_lengths.values())
else:
# Second pass over gfa file to get the walks
with open(gfa_file, "r") as f:
for line in f:
linetype = line.split()[0]
if linetype != "W":
continue
sampid = line.split()[1]
if sampid in exclude_samples:
continue
hapid = line.split()[2]
walk = line.split()[6]
nodes = self.get_nodes_from_walk(walk)
self.add_walk(f"{sampid}:{hapid}", nodes)
9 changes: 9 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,19 @@ def test_parse_walks_file(self):
nodes = Walks.read(DATADIR / "basic.walk.gz", region="-2")
assert nodes.data == expected.data

nodes = Walks.read(DATADIR / "basic.walk.gz", region="-2", nodes={1, 2})
assert nodes.data == expected.data

del expected.data[2]

nodes = Walks.read(DATADIR / "basic.walk", region="1-1")
assert nodes.data == expected.data

nodes = Walks.read(DATADIR / "basic.walk.gz", region="1-1")
assert nodes.data == expected.data

nodes = Walks.read(DATADIR / "basic.walk.gz", region="1-2", nodes=set((1,)))
assert nodes.data == expected.data

nodes = Walks.read(DATADIR / "basic.walk.gz", nodes=set((1,)))
assert nodes.data == expected.data
7 changes: 4 additions & 3 deletions tests/test_graph_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import os
from panct.graph_utils import *
from pathlib import Path

import pytest
import numpy as np

from panct.graph_utils import Node, NodeTable

DATADIR = Path(__file__).parent.joinpath("data")

Expand Down
Loading