Skip to content

Commit

Permalink
Tests working
Browse files Browse the repository at this point in the history
  • Loading branch information
jfnavarro committed Jan 12, 2025
1 parent 72aa065 commit 0d2a869
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 20 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ include = [
{ path = "README.md" },
{ path = "README_SHORT.md" },
{ path = "LICENSE" },
{ path = "doc/**" }
{ path = "doc/**" },
{ path = "scripts/**" }
]

[tool.poetry.dependencies]
Expand Down
18 changes: 10 additions & 8 deletions stpipeline/common/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ def _breadth_first_search(node: str, adj_list: Dict[str, List[str]]) -> Set[str]
Performs a breadth-first search (BFS) to find every node reachable from the starting node (the starting node itself included).
"""
searched = set()
found = set(node)
queue = set(node)
while len(queue) > 0:
node = (list(queue))[0]
found.update(adj_list[node])
queue.update(adj_list[node])
searched.add(node)
queue.difference_update(searched)
queue = {node}
found = set(queue)
while queue:
current = queue.pop()
searched.add(current)
# Convert neighbors to a set to handle list inputs
neighbors = set(adj_list[current]) - searched
found.update(neighbors)
# Add new neighbors to the queue
queue.update(neighbors)
return found


Expand Down
2 changes: 1 addition & 1 deletion stpipeline/common/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _compute_saturation_metrics(

results["reads"].append(stats["reads_after_duplicates_removal"])
results["genes"].append(stats["genes_found"])
results["avg_genes"].append(stats["average_gene_feature"])
results["avg_genes"].append(stats["average_genes_feature"])
results["avg_reads"].append(stats["average_reads_feature"])

return results
Expand Down
19 changes: 10 additions & 9 deletions tests/clustering_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def test_remove_umis():
"B": ["A", "C"],
"C": ["B"],
}
cluster = {"A", "B", "C"}
nodes = ["B"]
cluster = ["A", "B", "C"]
nodes = ["C"]
result = _remove_umis(adj_list, cluster, nodes)
assert result == {"A"}

Expand All @@ -50,9 +50,10 @@ def test_get_connected_components_adjacency():
}
counts = Counter({"A": 3, "B": 2, "C": 1, "D": 4})
result = _get_connected_components_adjacency(adj_list, counts)
result = [sorted(x) for x in result]
assert len(result) == 2
assert {"A", "B", "C"} in result
assert {"D"} in result
assert ["A", "B", "C"] in result
assert ["D"] in result


def test_get_adj_list_adjacency():
Expand All @@ -72,7 +73,7 @@ def test_get_best_adjacency():
cluster = ["A", "B", "C"]
counts = Counter({"A": 3, "B": 2, "C": 1})
result = _get_best_adjacency(cluster, adj_list, counts)
assert result == ["A"]
assert result == ["A", "B"]


def test_reduce_clusters_adjacency():
Expand All @@ -84,22 +85,22 @@ def test_reduce_clusters_adjacency():
clusters = [{"A", "B", "C"}]
counts = Counter({"A": 3, "B": 2, "C": 1})
result = _reduce_clusters_adjacency(adj_list, clusters, counts)
assert result == ["A"]
assert result == ["A", "B"]


def test_get_adj_list_directional_adjacency():
umis = ["AAAA", "AAAT", "AATT", "TTTT"]
counts = Counter({"AAAA": 4, "AAAT": 3, "AATT": 2, "TTTT": 1})
counts = Counter({"AAAA": 6, "AAAT": 3, "AATT": 2, "TTTT": 1})
allowed_mismatches = 1
result = _get_adj_list_directional_adjacency(umis, counts, allowed_mismatches)
assert "AAAA" in result and "AAAT" in result["AAAA"]
assert "AATT" not in result["AAAA"]


def test_reduce_clusters_directional_adjacency():
clusters = [{"A", "B", "C"}]
clusters = [["A", "B", "C"]]
result = _reduce_clusters_directional_adjacency(clusters)
assert result == ["A"]
assert result == ["C"]


def test_dedup_hierarchical():
Expand Down
7 changes: 6 additions & 1 deletion tests/saturation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def mock_bam_file(tmp_path):
segment.flag = 0
segment.reference_id = 0
segment.reference_start = i * 10
segment.cigar = [(0, len(segment.query_sequence))] # 0: MATCH
segment.set_tag("B1", i)
segment.set_tag("B2", i * 2)
segment.set_tag("XF", "gene1")
segment.set_tag("B3", "UMI1")
bam_file.write(segment)
return str(bam_path), 100

Expand Down Expand Up @@ -95,7 +100,7 @@ def test_compute_saturation_metrics(mock_bam_file, tmp_path):
saturation_points = [10, 50, 100]
temp_folder = tmp_path
gff_filename = tmp_path / "mock.gff"
gff_filename.write_text("chr1\tsource\tfeature\t1\t1000\t.\t+\t.\tID=gene1\n")
gff_filename.write_text("chr1\tsource\tfeature\t1\t1000\t.\t+\t.\tgene_id=gene1\n")

files, file_names, subsampling = _generate_subsamples(nreads, bam_file, saturation_points, temp_folder)
_write_subsamples_to_files(files, subsampling, bam_file, saturation_points)
Expand Down

0 comments on commit 0d2a869

Please sign in to comment.