diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..a07b7bf
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,14 @@
+[run]
+source_pkgs = aligner
+omit =
+    *tmp*
+    */run_tests.py
+    */tests/*
+    */__main__.py
+
+[report]
+precision = 2
+exclude_lines =
+    pragma: no cover
+    if 0:
+    if __name__ == .__main__.:
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index ebc6b46..2fde716 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -13,19 +13,38 @@ jobs:
       run:
         shell: bash -l {0}
     steps:
-      - name: Checkout repository
-        uses: actions/checkout@v4
-      - name: Set up Python
-        uses: actions/setup-python@v5
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
         with:
           python-version: "3.8"
           cache: "pip"
+
+      - uses: FedericoCarboni/setup-ffmpeg@v2
+
       - name: Install dependencies and package
-        run: pip install -e . mypy
-      - name: Minimal test, --help should work
-        run: ctc-segmenter --help
-      - name: Code quality test, mypy should pass
+        run: pip install -e . mypy coverage
+
+      - name: Minimal code quality test, mypy should pass
         run: mypy aligner
+
+      - uses: actions/cache@v4
+        with:
+          path: /home/runner/.cache/torch
+          key: torch-cache
+
+      - name: Run unit tests
+        run: |
+          coverage run -m unittest discover aligner.tests -v
+          coverage xml
+
+      - name: Upload coverage report to Codecov
+        uses: codecov/codecov-action@v4
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+
+      - run: coverage report
+
       - name: Make sure the CLI stays fast
         id: cli-load-time
         run: |
@@ -44,6 +63,7 @@ jobs:
             echo "Please run 'PYTHONPROFILEIMPORTTIME=1 ctc-segmenter -h 2> importtime.txt; tuna importtime.txt' and tuck away expensive imports so that the CLI doesn't load them until it uses them."; \
             false; \
           fi
+
       - name: Report help speed in PR
         if: github.event_name == 'pull_request'
         uses: mshick/add-pr-comment@v2
diff --git a/.gitignore b/.gitignore
index 46eeb8a..7f381c5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 *.egg-info
 __pycache__
 *.wav
-*.TextGrid
\ No newline at end of file
+*.TextGrid
+.coverage
diff --git a/aligner/__main__.py b/aligner/__main__.py
new file mode 100644
index 0000000..f8cd76f
--- /dev/null
+++ b/aligner/__main__.py
@@ -0,0 +1,3 @@
+from .cli import app
+
+app()
diff --git a/aligner/cli.py b/aligner/cli.py
index 214accc..37d4268 100644
--- a/aligner/cli.py
+++ b/aligner/cli.py
@@ -134,7 +134,14 @@ def align_single(
     ),
     debug: bool = typer.Option(False, help="Print debug statements"),
 ):
-    print("loading model...")
+    # Do fast error checking before loading expensive dependencies
+    sentence_list = read_text(text_path)
+    if not sentence_list or not any(sentence_list):
+        raise typer.BadParameter(
+            f"TEXT_PATH file '{text_path}' is empty; it should contain sentences to align.",
+        )
+
+    print("loading pytorch...")
     import torch
     import torchaudio
 
@@ -155,7 +162,6 @@ def align_single(
         audio_path = Path(fn + f"-{sample_rate}-mono" + ext)
         torchaudio.save(str(audio_path), wav, sample_rate)
     print("processing text")
-    sentence_list = read_text(text_path)
     transducer = create_transducer("".join(sentence_list), labels, debug)
     text_hash = TextHash(sentence_list, transducer)
     print("performing alignment")
@@ -180,7 +186,3 @@ def align_single(
     tg_path = audio_path.with_suffix(".TextGrid")
     print(f"writing file to {tg_path}")
     tg.to_file(tg_path)
-
-
-if __name__ == "__main__":
-    align_single()
diff --git a/aligner/heavy.py b/aligner/heavy.py
index b99a4b7..7d597f8 100644
--- a/aligner/heavy.py
+++ b/aligner/heavy.py
@@ -79,7 +79,7 @@ def compute_alignments(
             token_index += 1
         frames.append(Frame(token_index, i, score))
         prev_hyp = ali
-    words_to_match = [v | {"key": k} for k, v in transcript_hash.items() if "w" in k]
+    words_to_match = [{**v, "key": k} for k, v in transcript_hash.items() if "w" in k]
     i1, i2 = 0, 0
     segments = []
     while i1 < len(frames):
diff --git a/aligner/tests/__init__.py b/aligner/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/aligner/tests/test_cli.py b/aligner/tests/test_cli.py
new file mode 100644
index 0000000..a957c84
--- /dev/null
+++ b/aligner/tests/test_cli.py
@@ -0,0 +1,117 @@
+"""
+Run wav2vec2aligner unit tests.
+How to run this test suite:
+If you installed wav2vec2aligner:
+    python -m unittest aligner.tests.test_cli
+If you installed everyvoice:
+    python -m unittest everyvoice.model.aligner.wav2vec2aligner.aligner.tests.test_cli
+"""
+
+import os
+import subprocess
+import tempfile
+from pathlib import Path
+from unittest import TestCase
+
+from typer.testing import CliRunner
+
+from ..classes import Segment
+from ..cli import app, complete_path
+
+
+class CLITest(TestCase):
+    def setUp(self) -> None:
+        self.runner = CliRunner()
+
+    def test_main_help(self):
+        for help in "-h", "--help":
+            with self.subTest(help=help):
+                result = self.runner.invoke(app, [help])
+                self.assertEqual(result.exit_code, 0)
+                self.assertIn("align", result.stdout)
+                self.assertIn("extract", result.stdout)
+
+    def test_sub_help(self):
+        for cmd in "align", "extract":
+            for help in "-h", "--help":
+                with self.subTest(cmd=cmd, help=help):
+                    result = self.runner.invoke(app, [cmd, help])
+                    self.assertEqual(result.exit_code, 0)
+                    self.assertIn("Usage:", result.stdout)
+                    self.assertIn(cmd, result.stdout)
+
+    def test_align_empty_file(self):
+        """align must exit non-zero with an "is empty" message for empty or blank-only text."""
+        with self.subTest("empty file"):
+            result = self.runner.invoke(app, ["align", os.devnull, os.devnull])
+            self.assertNotEqual(result.exit_code, 0)
+            self.assertIn("is empty", result.stdout)
+
+        with self.subTest("file with only empty lines"):
+            with tempfile.TemporaryDirectory() as tmpdir:
+                textfile = os.path.join(tmpdir, "emptylines.txt")
+                with open(textfile, "w", encoding="utf8") as f:
+                    f.write("\n \n \n")
+                result = self.runner.invoke(app, ["align", textfile, os.devnull])
+                self.assertNotEqual(result.exit_code, 0)
+                self.assertIn("is empty", result.stdout)
+
+    def fetch_ras_test_file(self, filename, outputdir):
+        """Download filename from the ReadAlongs/Studio test data into outputdir."""
+        from urllib.request import Request, urlopen
+
+        repo, path = "https://github.com/ReadAlongs/Studio/", "/test/data/"
+        request = Request(repo + "raw/refs/heads/main" + path + filename)
+        request.add_header("Referer", repo + "blob/main" + path + filename)
+        with urlopen(request) as response:
+            with open(os.path.join(outputdir, filename), "wb") as f:
+                f.write(response.read())
+
+    def test_align_something(self):
+        """Align a downloaded sample, then extract segments from the resulting TextGrid."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            tmppath = Path(tmpdir)
+            self.fetch_ras_test_file("ej-fra.txt", tmpdir)
+            txt = tmppath / "ej-fra.txt"
+            self.fetch_ras_test_file("ej-fra.m4a", tmpdir)
+            m4a = tmppath / "ej-fra.m4a"
+            wav = tmppath / "ej-fra.wav"
+            # Under most circumstances, align can take a .m4a input file, but not
+            # in CI. Since it's not a hard requirement, just convert to .wav.
+            subprocess.run(["ffmpeg", "-i", m4a, wav], capture_output=True)
+            # os.system("ls -la " + tmpdir)
+            textgrid = tmppath / "ej-fra-16000.TextGrid"
+            wav_out = tmppath / "ej-fra-16000.wav"
+
+            with self.subTest("ctc-segmenter align"):
+                result = self.runner.invoke(app, ["align", str(txt), str(wav)])
+                if result.exit_code != 0:
+                    os.system("ls -la " + tmpdir)
+                    print(result.stdout)
+                self.assertEqual(result.exit_code, 0)
+                self.assertTrue(textgrid.exists())
+                self.assertTrue(wav_out.exists())
+
+            with self.subTest("ctc-segmenter extract"):
+                result = self.runner.invoke(
+                    app, ["extract", str(textgrid), str(wav_out), str(tmppath / "out")]
+                )
+                if result.exit_code != 0:
+                    print(result.stdout)
+                self.assertEqual(result.exit_code, 0)
+                self.assertTrue((tmppath / "out/metadata.psv").exists())
+                with open(txt, encoding="utf8") as txt_f:
+                    non_blank_line_count = sum(1 for line in txt_f if line.strip())
+                for i in range(non_blank_line_count):
+                    self.assertTrue((tmppath / f"out/wavs/segment{i}.wav").exists())
+
+
+class MiscTests(TestCase):
+    def test_shell_complete(self):
+        self.assertEqual(complete_path(), [])
+        self.assertEqual(complete_path(None, None, None), [])
+
+    def test_segment(self):
+        segment = Segment("text", 500, 700, 0.42)
+        self.assertEqual(len(segment), 200)
+        self.assertEqual(repr(segment), "text (0.42): [ 500, 700)")
diff --git a/aligner/utils.py b/aligner/utils.py
index 9c3282c..77d2591 100644
--- a/aligner/utils.py
+++ b/aligner/utils.py
@@ -32,7 +32,7 @@ def create_transducer(text, labels_dictionary, debug=False):
         if char not in allowable_chars and char not in fallback_mapping:
             fallback_mapping[char] = ""
     for k in fallback_mapping.keys():
-        if debug:
+        if debug:  # pragma: no cover
             print(
                 f"Found {k} which is not modelled by Wav2Vec2; skipping for alignment"
             )
@@ -49,7 +49,7 @@ def create_transducer(text, labels_dictionary, debug=False):
 
 
 def read_text(text_path):
-    with open(text_path) as f:
+    with open(text_path, encoding="utf8") as f:
         return [x.strip() for x in f]
 
 
diff --git a/requirements.txt b/requirements.txt
index d609e27..39b44d1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,9 @@
-torch>=2.1.0
-torchaudio>=2.1.0
 g2p>=1.0.20230417
+pydub>=0.23.1
 pympi-ling
-typer>=0.9.0
 rich>=10.11.0
 shellingham>=1.3.0
+soundfile>=0.10.2
+torch>=2.1.0
+torchaudio>=2.1.0
+typer>=0.9.0