Skip to content

Commit

Permalink
modified: yuisub/llm.py
Browse files Browse the repository at this point in the history
	modified:   yuisub/prompt.py
	modified:   yuisub/sub.py
  • Loading branch information
NULL204 committed Feb 6, 2025
1 parent d4a9078 commit 520d620
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 13 deletions.
12 changes: 6 additions & 6 deletions yuisub/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tenacity import retry, stop_after_attempt, wait_random

from yuisub.bangumi import BGM
from yuisub.prompt import ORIGIN, ZH, anime_prompt, summary_prompt
from yuisub.prompt import ZH, anime_prompt, summary_prompt


class Translator:
Expand All @@ -21,19 +21,19 @@ def __init__(
self.corner_case = True

@retry(wait=wait_random(min=3, max=5), stop=stop_after_attempt(5))
async def ask(self, question: ORIGIN) -> ZH:
async def ask(self, question: str) -> ZH:
if self.corner_case:
# blank question
if question.origin == "":
if question == "":
return ZH(zh="")

# too long question, return directly
if len(question.origin) > 100:
return ZH(zh=question.origin)
if len(question) > 100:
return ZH(zh=question)

messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": question.origin},
{"role": "user", "content": question},
]

try:
Expand Down
4 changes: 0 additions & 4 deletions yuisub/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
from yuisub.bangumi import BGM


class ORIGIN(BaseModel):
origin: str


class ZH(BaseModel):
zh: str

Expand Down
5 changes: 2 additions & 3 deletions yuisub/sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from yuisub.bangumi import bangumi
from yuisub.llm import Summarizer, Translator
from yuisub.prompt import ORIGIN

PRESET_STYLES: dict[str, SSAStyle] = {
"zh": SSAStyle(
Expand Down Expand Up @@ -115,7 +114,7 @@ async def translate(
print(summarizer.system_prompt)

# get summary
summary = await summarizer.ask(ORIGIN(origin="\n".join(trans_list)))
summary = await summarizer.ask("\n".join(trans_list))

# initialize translator
translator = Translator(
Expand All @@ -130,7 +129,7 @@ async def translate(
# create translate text task
async def _translate(index: int) -> None:
nonlocal trans_list
translated_text = await translator.ask(ORIGIN(origin=trans_list[index]))
translated_text = await translator.ask(trans_list[index])
print(f"Translated: {trans_list[index]} ---> {translated_text.zh}")
trans_list[index] = translated_text.zh

Expand Down

0 comments on commit 520d620

Please sign in to comment.