Skip to content

Commit

Permalink
feat, chore, fix: add global fail cache, add max retries, improve gro…
Browse files Browse the repository at this point in the history
…q prompt acc, fix tqdm AttributeError
  • Loading branch information
vTuanpham committed Sep 17, 2024
1 parent fefd981 commit b6e9891
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 57 deletions.
51 changes: 42 additions & 9 deletions providers/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@
)

try:
from .utils import hash_input
from .utils import hash_input, pop_half_dict
except ImportError:
from utils import hash_input
from utils import hash_input, pop_half_dict


# Cache the fail prompt to avoid running translation again for subsequent calls
GLOBAL_CACHE_FAIL_PROMPT = {}
GLOBAL_MAX_LIST_RETRIES = 20 # Global max retries for list translation
GLOBAL_MAX_STRING_RETRIES = 10 # Global max retries for string translation


class Provider(ABC):
Expand Down Expand Up @@ -58,8 +64,8 @@ def __get_hashable_key(self, input_data: Union[str, List[str]], src: str, dest:

return f"{src}-{dest}-{hash_input(input_data, hash=False)}-{fail_translation_code}"

@cached(max_size=10000, thread_safe=False, custom_key_maker=__get_hashable_key, algorithm=CachingAlgorithmFlag.LRU)
@retry(stop=(stop_after_attempt(6) | stop_after_delay(120)), wait=wait_random_exponential(multiplier=1, max=30), reraise=True)
@cached(max_size=5000, thread_safe=False, custom_key_maker=__get_hashable_key, algorithm=CachingAlgorithmFlag.LRU, ttl=7200)
@retry(stop=(stop_after_attempt(max(GLOBAL_MAX_LIST_RETRIES, GLOBAL_MAX_STRING_RETRIES)) | stop_after_delay(180)), wait=wait_random_exponential(multiplier=1, max=30), reraise=True)
def translate(self, input_data: Union[str, List[str]], src: str, dest: str, fail_translation_code: str="P1OP1_F") -> Union[str, List[str]]:
"""
Translates the input data from the source language to the destination language using the assigned translator object.
Expand All @@ -77,6 +83,8 @@ def translate(self, input_data: Union[str, List[str]], src: str, dest: str, fail
- The translation is performed by calling the _do_translate() method.
"""

global GLOBAL_CACHE_FAIL_PROMPT

# Type check for input_data
if not isinstance(input_data, (str, list)):
raise TypeError(f"input_data must be of type str or List[str], not {type(input_data).__name__}")
Expand All @@ -87,11 +95,36 @@ def translate(self, input_data: Union[str, List[str]], src: str, dest: str, fail
# Ensure the translator is set
assert self.translator, "Please assign the translator object instance to self.translator"

# Perform the translation
translated_instance = self._do_translate(input_data,
src=src,
dest=dest,
fail_translation_code=fail_translation_code)
if len(GLOBAL_CACHE_FAIL_PROMPT) > 5000:
_, GLOBAL_CACHE_FAIL_PROMPT = pop_half_dict(GLOBAL_CACHE_FAIL_PROMPT)

parametrized_hash = hash_input(self.__get_hashable_key(input_data, src, dest, fail_translation_code))

try:
# Perform the translation
translated_instance = self._do_translate(input_data,
src=src,
dest=dest,
fail_translation_code=fail_translation_code)
if parametrized_hash in GLOBAL_CACHE_FAIL_PROMPT:
GLOBAL_CACHE_FAIL_PROMPT.pop(parametrized_hash)
except Exception as e:
# Check if the exception is unavoidable by matching the prompt with the cache fail prompt key
if parametrized_hash in GLOBAL_CACHE_FAIL_PROMPT:
if isinstance(input_data, list) and GLOBAL_CACHE_FAIL_PROMPT[parametrized_hash] >= GLOBAL_MAX_LIST_RETRIES:
print(f"\nUnavoidable exception: {e}\nGlobal max retries reached for list translation")
return [fail_translation_code, fail_translation_code]
elif isinstance(input_data, str) and GLOBAL_CACHE_FAIL_PROMPT[parametrized_hash] >= GLOBAL_MAX_STRING_RETRIES:
print(f"\nUnavoidable exception: {e}\nGlobal max retries reached for string translation")
return fail_translation_code
else:
GLOBAL_CACHE_FAIL_PROMPT[parametrized_hash] += 1
else:
GLOBAL_CACHE_FAIL_PROMPT[parametrized_hash] = 1

print(f"\nCurrent global fail cache: {GLOBAL_CACHE_FAIL_PROMPT}\n")
raise e


assert type(input_data) == type(translated_instance),\
f" The function self._do_translate() return mismatch datatype from the input_data," \
Expand Down
Loading

0 comments on commit b6e9891

Please sign in to comment.