From b48dec847caec065fd8bd02736386338902483d2 Mon Sep 17 00:00:00 2001 From: Tohru <65994850+Tohrusky@users.noreply.github.com> Date: Mon, 19 Aug 2024 22:18:26 +0800 Subject: [PATCH] feat: use pysub2 to gen ass subtitle (#4) * remove pkg pysrt --- .github/workflows/CI-test.yml | 2 +- .gitignore | 4 +- README.md | 29 +++++----- pyproject.toml | 4 +- tests/test_srt.py | 36 ------------ tests/test_sub.py | 36 ++++++++++++ yuisub/__init__.py | 2 +- yuisub/__main__.py | 32 +++++----- yuisub/a2t.py | 45 ++------------- yuisub/srt.py | 80 ------------------------- yuisub/sub.py | 106 ++++++++++++++++++++++++++++++++++ 11 files changed, 186 insertions(+), 190 deletions(-) delete mode 100644 tests/test_srt.py create mode 100644 tests/test_sub.py delete mode 100644 yuisub/srt.py create mode 100644 yuisub/sub.py diff --git a/.github/workflows/CI-test.yml b/.github/workflows/CI-test.yml index 56a1398..f429331 100644 --- a/.github/workflows/CI-test.yml +++ b/.github/workflows/CI-test.yml @@ -48,7 +48,7 @@ jobs: - name: Test run: | pip install numpy==1.26.4 - pip install pre-commit pytest mypy ruff types-requests pytest-cov coverage pydantic openai openai-whisper requests beautifulsoup4 tenacity pysrt + pip install pre-commit pytest mypy ruff types-requests pytest-cov coverage pydantic openai openai-whisper requests beautifulsoup4 tenacity pysubs2 make lint make test diff --git a/.gitignore b/.gitignore index ed51e75..00c0b3f 100644 --- a/.gitignore +++ b/.gitignore @@ -161,5 +161,7 @@ cython_debug/ .idea/ /.ruff_cache/ -/assets/*.srt +/assets/*.mkv /assets/*.mp3 +/assets/*.srt +/assets/*.ass diff --git a/README.md b/README.md index 6a78ddc..b30684b 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ pip install openai-whisper ### Command Line Usage -`yuisub` can be used from the command line to generate bilingual SRT files. Here's how to use it: +`yuisub` can be used from the command line to generate bilingual ASS files. Here's how to use it: ```bash yuisub -h # Displays help message @@ -38,29 +38,32 @@ yuisub -h # Displays help message ### Example ```python3 -from yuisub import bilingual, from_file +from yuisub import translate, bilingual, load from yuisub.a2t import WhisperModel -# srt from audio +# sub from audio model = WhisperModel(name="medium", device="cuda") -segs = model.transcribe(audio="path/to/audio.mp3") -srt = model.gen_srt(segs) +sub = model.transcribe(audio="path/to/audio.mp3") -# srt from file -# srt = from_file("path/to/input.srt") +# sub from file +# sub = from_file("path/to/input.srt") -# Generate bilingual SRT -srt_zh, srt_bilingual = bilingual( - srt=srt, +# generate bilingual subtitle +sub_zh = translate( + sub=sub, model="gpt_model_name", api_key="your_openai_api_key", base_url="api_url", bangumi_url="https://bangumi.tv/subject/424883/" ) +sub_bilingual = bilingual( + sub_origin=sub, + sub_zh=sub_zh +) -# Save the SRT files -srt_zh.save("path/to/output.zh.srt") -srt_bilingual.save("path/to/output.bilingual.srt") +# save the ASS files +sub_zh.save("path/to/output.zh.ass") +sub_bilingual.save("path/to/output.bilingual.ass") ``` ### License diff --git a/pyproject.toml b/pyproject.toml index 3208479..522e878 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,14 +39,14 @@ license = "GPL-3.0-only" name = "yuisub" readme = "README.md" repository = "https://github.com/TensoRaws/yuisub" -version = "0.0.3" +version = "0.0.4" # Requirements [tool.poetry.dependencies] beautifulsoup4 = "*" openai = "*" pydantic = "*" -pysrt = "*" +pysubs2 = "*" python = "^3.9" requests = "*" tenacity = "*" diff --git a/tests/test_srt.py b/tests/test_srt.py deleted file mode 100644 index 4197002..0000000 --- a/tests/test_srt.py +++ /dev/null @@ -1,36 +0,0 @@ -import os - -import pytest - -from tests import util -from yuisub.a2t import WhisperModel -from yuisub.srt import bilingual, from_file - - -def test_srt() -> None: - srt = from_file(util.TEST_ENG_SRT) - srt.save(util.projectPATH / "assets" / "test.en.srt") - - -def test_srt_audio() -> None: - model = WhisperModel(name=util.MODEL_NAME, device=util.DEVICE) - - segs = model.transcribe(audio=str(util.TEST_AUDIO)) - srt = model.gen_srt(segs) - srt.save(util.projectPATH / "assets" / "test.audio.srt") - - -@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI") -def test_bilingual() -> None: - srt = from_file(util.TEST_ENG_SRT) - - srt_zh, srt_zh_jp = bilingual( - srt=srt, - model=util.OPENAI_MODEL, - api_key=util.OPENAI_API_KEY, - base_url=util.OPENAI_BASE_URL, - bangumi_url=util.BANGUMI_URL, - ) - - srt_zh.save(util.projectPATH / "assets" / "test.zh.srt") - srt_zh_jp.save(util.projectPATH / "assets" / "test.bilingual.srt") diff --git a/tests/test_sub.py b/tests/test_sub.py new file mode 100644 index 0000000..04bc644 --- /dev/null +++ b/tests/test_sub.py @@ -0,0 +1,36 @@ +import os + +import pytest + +from tests import util +from yuisub.a2t import WhisperModel +from yuisub.sub import bilingual, load, translate + + +def test_sub() -> None: + sub = load(util.TEST_ENG_SRT) + sub.save(util.projectPATH / "assets" / "test.en.ass") + + +def test_audio() -> None: + model = WhisperModel(name=util.MODEL_NAME, device=util.DEVICE) + + sub = model.transcribe(audio=str(util.TEST_AUDIO)) + sub.save(util.projectPATH / "assets" / "test.audio.ass") + + +@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI") +def test_bilingual() -> None: + sub = load(util.TEST_ENG_SRT) + + sub_zh = translate( + sub=sub, + model=util.OPENAI_MODEL, + api_key=util.OPENAI_API_KEY, + base_url=util.OPENAI_BASE_URL, + bangumi_url=util.BANGUMI_URL, + ) + sub_bilingual = bilingual(sub_origin=sub, sub_zh=sub_zh) + + sub_zh.save(util.projectPATH / "assets" / "test.zh.ass") + sub_bilingual.save(util.projectPATH / "assets" / "test.bilingual.ass") diff --git a/yuisub/__init__.py b/yuisub/__init__.py index 4488d00..93f55ed 100644 --- a/yuisub/__init__.py +++ b/yuisub/__init__.py @@ -1,4 +1,4 @@ from yuisub.bangumi import bangumi # noqa: F401 from yuisub.llm import Translator # noqa: F401 from yuisub.prompt import ORIGIN, ZH # noqa: F401 -from yuisub.srt import bilingual, from_file # noqa: F401 +from yuisub.sub import bilingual, load, translate # noqa: F401 diff --git a/yuisub/__main__.py b/yuisub/__main__.py index 4d58c1f..07faa35 100644 --- a/yuisub/__main__.py +++ b/yuisub/__main__.py @@ -1,19 +1,19 @@ import argparse import sys -from yuisub.srt import bilingual, from_file +from yuisub.sub import bilingual, load, translate # ffmpeg -i test.mkv -c:a mp3 -map 0:a:0 test.mp3 # ffmpeg -i test.mkv -map 0:s:0 eng.srt parser = argparse.ArgumentParser() -parser.description = "Generate bilingual SRT files from audio or SRT input." +parser.description = "Generate Bilingual Subtitle from audio or subtitle file" # input parser.add_argument("-a", "--AUDIO", type=str, help="Path to the audio file", required=False) -parser.add_argument("-s", "--SRT", type=str, help="Path to the input SRT file", required=False) -# srt output -parser.add_argument("-oz", "--OUTPUT_ZH", type=str, help="Path to save the Chinese SRT file", required=False) -parser.add_argument("-ob", "--OUTPUT_BILINGUAL", type=str, help="Path to save the bilingual SRT file", required=False) +parser.add_argument("-s", "--SUB", type=str, help="Path to the input Subtitle file", required=False) +# subtitle output +parser.add_argument("-oz", "--OUTPUT_ZH", type=str, help="Path to save the Chinese ASS file", required=False) +parser.add_argument("-ob", "--OUTPUT_BILINGUAL", type=str, help="Path to save the bilingual ASS file", required=False) # openai gpt parser.add_argument("-om", "--OPENAI_MODEL", type=str, help="Openai model name", required=True) parser.add_argument("-api", "--OPENAI_API_KEY", type=str, help="Openai API key", required=True) @@ -28,8 +28,8 @@ def main() -> None: - if args.AUDIO and args.SRT: - raise ValueError("Please provide only one input file, either audio or SRT.") + if args.AUDIO and args.SUB: + raise ValueError("Please provide only one input file, either audio or subtitle file") if not args.OUTPUT_ZH and not args.OUTPUT_BILINGUAL: raise ValueError("Please provide output paths for the subtitles.") @@ -53,26 +53,26 @@ def main() -> None: model = WhisperModel(name=_MODEL, device=_DEVICE) - segs = model.transcribe(audio=args.AUDIO) - - srt = model.gen_srt(segs=segs) + sub = model.transcribe(audio=args.AUDIO) else: - srt = from_file(args.SRT) + sub = load(args.SUB) - srt_zh, srt_bilingual = bilingual( - srt=srt, + sub_zh = translate( + sub=sub, model=args.OPENAI_MODEL, api_key=args.OPENAI_API_KEY, base_url=args.OPENAI_BASE_URL, bangumi_url=args.BANGUMI_URL, ) + sub_bilingual = bilingual(sub_origin=sub, sub_zh=sub_zh) + if args.OUTPUT_ZH: - srt_zh.save(args.OUTPUT_ZH) + sub_zh.save(args.OUTPUT_ZH) if args.OUTPUT_BILINGUAL: - srt_bilingual.save(args.OUTPUT_BILINGUAL) + sub_bilingual.save(args.OUTPUT_BILINGUAL) if __name__ == "__main__": diff --git a/yuisub/a2t.py b/yuisub/a2t.py index ebb537e..1e65781 100644 --- a/yuisub/a2t.py +++ b/yuisub/a2t.py @@ -1,24 +1,10 @@ -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np -import pysrt +import pysubs2 import torch import whisper -from pydantic import BaseModel -from pysrt import SubRipFile - - -class Segment(BaseModel): - id: int - seek: int - start: float - end: float - text: str - tokens: List[int] - temperature: float - avg_logprob: float - compression_ratio: float - no_speech_prob: float +from pysubs2 import SSAFile class WhisperModel: @@ -40,7 +26,7 @@ def transcribe( word_timestamps: bool = False, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", - ) -> List[Segment]: + ) -> SSAFile: result = self.model.transcribe( audio=audio, verbose=verbose, @@ -54,25 +40,4 @@ def transcribe( prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, ) - segments: List[Segment] = [Segment(**seg) for seg in result["segments"]] - return segments - - @staticmethod - def gen_srt(segs: List[Segment]) -> SubRipFile: - line_out: str = "" - for s in segs: - segment_id = s.id + 1 - start_time = format_time(s.start) - end_time = format_time(s.end) - text = s.text - - line_out += f"{segment_id}\n{start_time} --> {end_time}\n{text.lstrip()}\n\n" - subs = pysrt.from_string(line_out) - return subs - - -def format_time(seconds: float) -> str: - minutes, seconds = divmod(seconds, 60) - hours, minutes = divmod(minutes, 60) - milliseconds = (seconds - int(seconds)) * 1000 - return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}" + return pysubs2.load_from_whisper(result) diff --git a/yuisub/srt.py b/yuisub/srt.py deleted file mode 100644 index 082faf6..0000000 --- a/yuisub/srt.py +++ /dev/null @@ -1,80 +0,0 @@ -import asyncio -from copy import deepcopy -from pathlib import Path -from typing import Any, List, Tuple - -import pysrt -from bs4 import BeautifulSoup -from pysrt import SubRipFile -from tenacity import retry, stop_after_attempt, wait_random - -from yuisub.llm import Translator -from yuisub.prompt import ORIGIN - - -def from_file(srt_path: Path | str, encoding: Any = None) -> SubRipFile: - """ - Load srt file from file path, auto remove html tags - - :param srt_path: srt file path - :param encoding: srt file encoding, default is utf-8 - :return: - """ - srt = pysrt.open(path=str(srt_path), encoding=encoding) - for sub in srt: - try: - soup = BeautifulSoup(sub.text, "html.parser") - text = soup.get_text() - sub.text = text - except Exception as e: - print(e) - print(sub.text) - return srt - - -@retry(wait=wait_random(min=3, max=5), stop=stop_after_attempt(5)) -def bilingual( - srt: SubRipFile, model: str, api_key: str, base_url: str, bangumi_url: str | None = None -) -> Tuple[SubRipFile, SubRipFile]: - """ - Generate bilingual srt file, first return is the Chinese subtitle, second return is the Bilingual subtitle - - :param srt: origin srt file - :param model: llm model - :param api_key: llm api_key - :param base_url: llm base_url - :param bangumi_url: anime bangumi url - :return: - """ - - # pending translation - trans_list: List[str] = [s.text for s in srt] - - tr = Translator(model=model, api_key=api_key, base_url=base_url, bangumi_url=bangumi_url) - print(tr.system_prompt) - - async def translate(index: int) -> None: - nonlocal trans_list - - translated_text = await tr.ask(ORIGIN(origin=trans_list[index])) - print(f"Translated: {trans_list[index]} ---> {translated_text.zh}") - trans_list[index] = translated_text.zh - - # wait for all tasks to finish - async def wait_tasks() -> None: - tasks = [translate(index) for index in range(len(srt))] - await asyncio.gather(*tasks) - - asyncio.run(wait_tasks()) - - # generate bilingual srt - srt_zh: SubRipFile = deepcopy(srt) - srt_bilingual: SubRipFile = deepcopy(srt) - for i, s in enumerate(srt): - text_zh = trans_list[i] - text_bilingual = trans_list[i] + "\n" + s.text - - srt_zh[i].text = text_zh - srt_bilingual[i].text = text_bilingual - - return srt_zh, srt_bilingual diff --git a/yuisub/sub.py b/yuisub/sub.py new file mode 100644 index 0000000..6765d22 --- /dev/null +++ b/yuisub/sub.py @@ -0,0 +1,106 @@ +import asyncio +from copy import deepcopy +from pathlib import Path +from typing import List + +import pysubs2 +from pysubs2 import Alignment, Color, SSAFile, SSAStyle +from tenacity import retry, stop_after_attempt, wait_random + +from yuisub.llm import Translator +from yuisub.prompt import ORIGIN + + +def load(sub_path: Path | str, encoding: str = "utf-8") -> SSAFile: + """ + Load subtitle from file path, default encoding is utf-8 and remove style + + :param sub_path: subtitle file path + :param encoding: subtitle file encoding, default is utf-8 + :return: + """ + sub = pysubs2.load(str(sub_path), encoding=encoding) + return sub + + +@retry(wait=wait_random(min=3, max=5), stop=stop_after_attempt(5)) +def translate(sub: SSAFile, model: str, api_key: str, base_url: str, bangumi_url: str | None = None) -> SSAFile: + """ + Translate subtitle file to Chinese + + :param sub: origin subtitle + :param model: llm model + :param api_key: llm api_key + :param base_url: llm base_url + :param bangumi_url: anime bangumi url + :return: + """ + + # pending translation + trans_list: List[str] = [s.text for s in sub] + + tr = Translator(model=model, api_key=api_key, base_url=base_url, bangumi_url=bangumi_url) + print(tr.system_prompt) + + async def _translate(index: int) -> None: + nonlocal trans_list + + translated_text = await tr.ask(ORIGIN(origin=trans_list[index])) + print(f"Translated: {trans_list[index]} ---> {translated_text.zh}") + trans_list[index] = translated_text.zh + + # wait for all tasks to finish + async def _wait_tasks() -> None: + tasks = [_translate(index) for index in range(len(sub))] + await asyncio.gather(*tasks) + + asyncio.run(_wait_tasks()) + + # generate Chinese subtitle + sub_zh = deepcopy(sub) + for i, _ in enumerate(sub): + sub_zh[i].text = trans_list[i] + + return sub_zh + + +def bilingual(sub_origin: SSAFile, sub_zh: SSAFile) -> SSAFile: + """ + Generate bilingual subtitle file + + :param sub_origin: Origin subtitle + :param sub_zh: Chinese subtitle + :return: + """ + + # generate bilingual subtitle + sub_bilingual = SSAFile() + sub_bilingual.styles = { + "zh": SSAStyle( + alignment=Alignment.BOTTOM_CENTER, + primarycolor=Color(255, 192, 203), + fontsize=16, + fontname="Microsoft YaHei", + bold=True, + shadow=0, + outline=0.2, + ), + "origin": SSAStyle( + alignment=Alignment.BOTTOM_CENTER, + primarycolor=pysubs2.Color(249, 246, 240), + fontsize=12, + fontname="Microsoft YaHei", + shadow=0, + outline=0.5, + ), + } + + for e in sub_origin: + e.style = "origin" + sub_bilingual.append(e) + + for e in sub_zh: + e.style = "zh" + sub_bilingual.append(e) + + return sub_bilingual