Skip to content

Commit

Permalink
Most tests working
Browse files Browse the repository at this point in the history
  • Loading branch information
jfnavarro committed Jan 12, 2025
1 parent 2c78a00 commit a3a7d92
Show file tree
Hide file tree
Showing 13 changed files with 457 additions and 295 deletions.
10 changes: 7 additions & 3 deletions tests/annotation_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#! /usr/bin/env python
"""
"""
Unit-test the package annotation
"""
import pytest
import os
import pysam
from unittest.mock import patch, Mock
from stpipeline.core.annotation import invert_strand, count_reads_in_features, annotateReads
import HTSeq


@pytest.fixture
def mock_gff_file(tmp_path):
gff_content = (
Expand All @@ -21,6 +21,7 @@ def mock_gff_file(tmp_path):
f.write(gff_content)
return str(gff_file)


@pytest.fixture
def mock_bam_file(tmp_path):
bam_file = tmp_path / "mock.bam"
Expand All @@ -47,6 +48,7 @@ def mock_bam_file(tmp_path):
f.write(segment)
return str(bam_file)


def test_invert_strand():
iv = HTSeq.GenomicInterval("chr1", 0, 1000, "+")
inverted = invert_strand(iv)
Expand All @@ -60,6 +62,7 @@ def test_invert_strand():
iv = HTSeq.GenomicInterval("chr1", 0, 1000, ".")
invert_strand(iv)


def test_count_reads_in_features(mock_bam_file, mock_gff_file, tmp_path):
output_file = tmp_path / "output.bam"
discarded_file = tmp_path / "discarded.bam"
Expand All @@ -83,6 +86,7 @@ def test_count_reads_in_features(mock_bam_file, mock_gff_file, tmp_path):
assert os.path.exists(output_file)
assert os.path.exists(discarded_file)


def test_annotate_reads(mock_bam_file, mock_gff_file, tmp_path):
output_file = tmp_path / "output.bam"
discarded_file = tmp_path / "discarded.bam"
Expand All @@ -100,4 +104,4 @@ def test_annotate_reads(mock_bam_file, mock_gff_file, tmp_path):
)

assert os.path.exists(output_file)
assert os.path.exists(discarded_file)
assert os.path.exists(discarded_file)
51 changes: 37 additions & 14 deletions tests/clustering_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
#! /usr/bin/env python
"""
"""
Unit-test the package clustering
"""
from collections import Counter
from stpipeline.common.clustering import _breadth_first_search, _remove_umis, _get_connected_components_adjacency, _get_adj_list_adjacency, _get_best_adjacency, _reduce_clusters_adjacency, _get_adj_list_directional_adjacency, _reduce_clusters_directional_adjacency, dedup_hierarchical, dedup_adj, dedup_dir_adj
from stpipeline.common.clustering import (
_breadth_first_search,
_remove_umis,
_get_connected_components_adjacency,
_get_adj_list_adjacency,
_get_best_adjacency,
_reduce_clusters_adjacency,
_get_adj_list_directional_adjacency,
_reduce_clusters_directional_adjacency,
dedup_hierarchical,
dedup_adj,
dedup_dir_adj,
)


def test_breadth_first_search():
adj_list = {
Expand All @@ -15,6 +28,7 @@ def test_breadth_first_search():
result = _breadth_first_search("A", adj_list)
assert result == {"A", "B", "C", "D"}


def test_remove_umis():
adj_list = {
"A": ["B"],
Expand All @@ -26,6 +40,7 @@ def test_remove_umis():
result = _remove_umis(adj_list, cluster, nodes)
assert result == {"A"}


def test_get_connected_components_adjacency():
adj_list = {
"A": ["B"],
Expand All @@ -39,12 +54,14 @@ def test_get_connected_components_adjacency():
assert {"A", "B", "C"} in result
assert {"D"} in result


def test_get_adj_list_adjacency():
umis = ["AAAA", "AAAT", "AATT", "TTTT"]
allowed_mismatches = 1
result = _get_adj_list_adjacency(umis, allowed_mismatches)
assert "AAAA" in result and "AAAT" in result["AAAA"]
assert "AATT" not in result["AAAA"]
umis = ["AAAA", "AAAT", "AATT", "TTTT"]
allowed_mismatches = 1
result = _get_adj_list_adjacency(umis, allowed_mismatches)
assert "AAAA" in result and "AAAT" in result["AAAA"]
assert "AATT" not in result["AAAA"]


def test_get_best_adjacency():
adj_list = {
Expand All @@ -57,6 +74,7 @@ def test_get_best_adjacency():
result = _get_best_adjacency(cluster, adj_list, counts)
assert result == ["A"]


def test_reduce_clusters_adjacency():
adj_list = {
"A": ["B"],
Expand All @@ -68,33 +86,38 @@ def test_reduce_clusters_adjacency():
result = _reduce_clusters_adjacency(adj_list, clusters, counts)
assert result == ["A"]


def test_get_adj_list_directional_adjacency():
umis = ["AAAA", "AAAT", "AATT", "TTTT"]
counts = Counter({"AAAA": 4, "AAAT": 3, "AATT": 2, "TTTT": 1})
allowed_mismatches = 1
result = _get_adj_list_directional_adjacency(umis, counts, allowed_mismatches)
assert "AAAA" in result and "AAAT" in result["AAAA"]
assert "AATT" not in result["AAAA"]
umis = ["AAAA", "AAAT", "AATT", "TTTT"]
counts = Counter({"AAAA": 4, "AAAT": 3, "AATT": 2, "TTTT": 1})
allowed_mismatches = 1
result = _get_adj_list_directional_adjacency(umis, counts, allowed_mismatches)
assert "AAAA" in result and "AAAT" in result["AAAA"]
assert "AATT" not in result["AAAA"]


def test_reduce_clusters_directional_adjacency():
clusters = [{"A", "B", "C"}]
result = _reduce_clusters_directional_adjacency(clusters)
assert result == ["A"]


def test_dedup_hierarchical():
umis = ["AAAA", "AAAT", "AATT", "TTTT"]
allowed_mismatches = 1
result = dedup_hierarchical(umis, allowed_mismatches)
assert len(result) <= len(umis)


def test_dedup_adj():
umis = ["AAAA", "AAAT", "AATT", "TTTT"]
allowed_mismatches = 1
result = dedup_adj(umis, allowed_mismatches)
assert len(result) <= len(umis)


def test_dedup_dir_adj():
umis = ["AAAA", "AAAT", "AATT", "TTTT"]
allowed_mismatches = 1
result = dedup_dir_adj(umis, allowed_mismatches)
assert len(result) <= len(umis)
assert len(result) <= len(umis)
25 changes: 11 additions & 14 deletions tests/dataset_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#! /usr/bin/env python
"""
"""
Unit-test the package dataset
"""
import pytest
Expand All @@ -9,6 +9,7 @@
from unittest.mock import Mock
import pysam


@pytest.fixture
def mock_gff_file(tmp_path):
gff_content = (
Expand All @@ -21,6 +22,7 @@ def mock_gff_file(tmp_path):
f.write(gff_content)
return str(gff_file)


@pytest.fixture
def mock_bam_file(tmp_path):
bam_file = tmp_path / "mock.bam"
Expand All @@ -41,16 +43,11 @@ def mock_bam_file(tmp_path):
f.write(segment)
return str(bam_file)


# Test for Transcript Dataclass
def test_transcript_dataclass():
transcript = Transcript(
chrom="chr1",
start=100,
end=200,
clear_name="test_transcript",
mapping_quality=60,
strand="+",
umi="ATGC"
chrom="chr1", start=100, end=200, clear_name="test_transcript", mapping_quality=60, strand="+", umi="ATGC"
)

assert transcript.chrom == "chr1"
Expand All @@ -61,10 +58,12 @@ def test_transcript_dataclass():
assert transcript.strand == "+"
assert transcript.umi == "ATGC"


# Test for compute_unique_umis
def mock_group_umi_func(umis: List[str], mismatches: int) -> List[str]:
return umis[:1] # Simplified mock implementation for testing


def test_compute_unique_umis():
transcripts = [
Transcript("chr1", 100, 200, "t1", 60, "+", "UMI1"),
Expand All @@ -79,6 +78,7 @@ def test_compute_unique_umis():
assert len(unique_transcripts) == 1
assert unique_transcripts[0].umi == "UMI1"


# Test for createDataset with mocked dependencies
def test_create_dataset(tmp_path, monkeypatch, mock_bam_file, mock_gff_file):
# Mock inputs
Expand All @@ -88,10 +88,7 @@ def test_create_dataset(tmp_path, monkeypatch, mock_bam_file, mock_gff_file):
t1 = Transcript("chr1", 100, 200, "t1", 60, "+", "UMI1")
t2 = Transcript("chr2", 300, 400, "t2", 60, "-", "UMI2")
# Mock parse_unique_events
mock_parse_unique_events = Mock(return_value=[
("gene1", {(10, 10): [t1, t2]}),
("gene2", {(20, 20): [t1, t2]})
])
mock_parse_unique_events = Mock(return_value=[("gene1", {(10, 10): [t1, t2]}), ("gene2", {(20, 20): [t1, t2]})])
monkeypatch.setattr("stpipeline.common.dataset.parse_unique_events", mock_parse_unique_events)

# Mock dedup_hierarchical
Expand All @@ -107,8 +104,8 @@ def test_create_dataset(tmp_path, monkeypatch, mock_bam_file, mock_gff_file):
umi_counting_offset=10,
disable_umi=False,
output_template="output",
verbose=False
verbose=False,
)

assert stats["genes_found"] == 2
assert stats["reads_after_duplicates_removal"] == 2
assert stats["reads_after_duplicates_removal"] == 2
9 changes: 7 additions & 2 deletions tests/fastq_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#! /usr/bin/env python
"""
"""
Unit-test the package fastq_utils
"""
import pytest
Expand All @@ -8,9 +8,10 @@
quality_trim_index,
trim_quality,
check_umi_template,
has_sufficient_content
has_sufficient_content,
)


# Test for remove_adaptor
def test_remove_adaptor():
sequence = "AGCTTAGCTTAGCTA"
Expand All @@ -30,6 +31,7 @@ def test_remove_adaptor():
assert trimmed_seq == "AGCT"
assert trimmed_qual == "FFFF"


def test_quality_trim_index_basic():
sequence = "AGCTTAGCTTAGCTA"
quality = "FFFFFFFFFFFFFFF" # ASCII 'F' -> Phred score 40
Expand Down Expand Up @@ -93,6 +95,7 @@ def test_trim_quality_low_quality_g():
assert trimmed_seq == "AGCTTA"
assert trimmed_qual == "FFFFFF"


def test_trim_quality_short():
min_qual = 20
min_length = 10
Expand All @@ -102,6 +105,7 @@ def test_trim_quality_short():
assert trimmed_seq is None
assert trimmed_qual is None


# Test for check_umi_template
def test_check_umi_template():
umi = "ACGT1234"
Expand All @@ -111,6 +115,7 @@ def test_check_umi_template():
umi = "ACGT12"
assert check_umi_template(umi, template) is False


# Test for has_sufficient_content
def test_has_sufficient_content():
sequence = "ATATGGCCATAT"
Expand Down
13 changes: 8 additions & 5 deletions tests/filter_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#! /usr/bin/env python
"""
"""
Unit-test the package filter
"""
import pytest
import dnaio
from unittest.mock import Mock, patch
from stpipeline.common.filter import filter_input_data


def generate_test_fastq(filepath, records):
"""
Generates a mock FASTQ file for testing.
Expand All @@ -19,6 +20,7 @@ def generate_test_fastq(filepath, records):
for header, sequence, quality in records:
f.write(f"@{header}\n{sequence}\n+\n{quality}\n")


@pytest.fixture
def setup_fastq_files(tmp_path):
fw_records = [
Expand All @@ -27,15 +29,15 @@ def setup_fastq_files(tmp_path):
("read3", "GGGGGGGGGGGGGGGGGGGGGGGG", "IIIIIIIIIIIIIIIIIIIIIIII"),
("read4", "CCCCCCCCCCCCCCCCCCCCCCCC", "IIIIIIIIIIIIIIIIIIIIIIII"),
("read5", "ACTGACTGACTGACTGACTGACTG", "!!!!IIIIIIIIIIIIIIIIIIII"), # Low-quality UMI
("read6", "ACTGACTGACTGACTGACTGACTG", "!!!!!!!!!!!!!!IIIIIIIIII") # Too short after trimming
("read6", "ACTGACTGACTGACTGACTGACTG", "!!!!!!!!!!!!!!IIIIIIIIII"), # Too short after trimming
]
rv_records = [
("read1", "ACTGACTGACTGACTGACTGACTG", "IIIIIIIIIIIIIIIIIIIIIIII"),
("read2", "TTTTTTTTTTTTTTTTTTTTTTTT", "IIIIIIIIIIIIIIIIIIIIIIII"),
("read3", "GGGGGGGGGGGGGGGGGGGGGGGG", "IIIIIIIIIIIIIIIIIIIIIIII"),
("read4", "CCCCCCCCCCCCCCCCCCCCCCCC", "IIIIIIIIIIIIIIIIIIIIIIII"),
("read5", "ACTGACTGACTGACTGACTGACTG", "!!!!IIIIIIIIIIIIIIIIIIII"), # Low-quality UMI
("read6", "ACTGACTGACTGACTGACTGACTG", "!!!!!!!!!!!!!!IIIIIIIIII") # Too short after trimming
("read6", "ACTGACTGACTGACTGACTGACTG", "!!!!!!!!!!!!!!IIIIIIIIII"), # Too short after trimming
]
fw_file = tmp_path / "fw.fastq"
rv_file = tmp_path / "rv.fastq"
Expand All @@ -45,6 +47,7 @@ def setup_fastq_files(tmp_path):

return str(fw_file), str(rv_file)


@patch("stpipeline.common.filter.pysam.AlignmentFile")
def test_filter_input_data(mock_alignment_file, setup_fastq_files, tmp_path):
fw_file, rv_file = setup_fastq_files
Expand Down Expand Up @@ -78,9 +81,9 @@ def test_filter_input_data(mock_alignment_file, setup_fastq_files, tmp_path):
adaptor_missmatches=2,
overhang=2,
disable_umi=False,
disable_barcode=False
disable_barcode=False,
)

assert total_reads == 6
assert remaining_reads < total_reads
mock_alignment_file.assert_called_once_with(str(out_file), "wb")
mock_alignment_file.assert_called_once_with(str(out_file), "wb")
Loading

0 comments on commit a3a7d92

Please sign in to comment.