diff --git a/tests/eli5_qaconfig_test.py b/tests/eli5_qaconfig_test.py index 851c275..f6a3c3d 100644 --- a/tests/eli5_qaconfig_test.py +++ b/tests/eli5_qaconfig_test.py @@ -27,9 +27,11 @@ def step3(self): def step4(self): self.parser.save - self.output_path = os.path.join(self.output_dir, "ELI5_val_QAConfig_translated_ru.json") + self.output_path = os.path.join(self.output_dir, "ELI5_val_QAConfig.json") + self.output_path_translated = os.path.join(self.output_dir, "ELI5_val_QAConfig_translated_ru.json") self.assertTrue(os.path.exists(self.output_path), f"File '{self.output_path}' does not exist") + self.assertTrue(os.path.exists(self.output_path_translated), f"File '{self.output_path_translated}' does not exist") def step5(self): try: @@ -45,6 +47,8 @@ def step6(self): def step7(self): if os.path.exists(self.output_path): os.remove(self.output_path) + if os.path.exists(self.output_path_translated): + os.remove(self.output_path_translated) def _steps(self): for name in dir(self): # dir() result is implicitly sorted diff --git a/tests/eli5_test.py b/tests/eli5_test.py index 0deebb6..9db66d1 100644 --- a/tests/eli5_test.py +++ b/tests/eli5_test.py @@ -28,9 +28,11 @@ def step3(self): def step4(self): self.parser.save - self.output_path = os.path.join(self.output_dir, "ELI5_val_translated_de.json") + self.output_path = os.path.join(self.output_dir, "ELI5_val.json") + self.output_path_translated = os.path.join(self.output_dir, "ELI5_val_translated_de.json") self.assertTrue(os.path.exists(self.output_path), f"File '{self.output_path}' does not exist") + self.assertTrue(os.path.exists(self.output_path_translated), f"File '{self.output_path_translated}' does not exist") def step5(self): try: @@ -46,6 +48,8 @@ def step6(self): def step7(self): if os.path.exists(self.output_path): os.remove(self.output_path) + if os.path.exists(self.output_path_translated): + os.remove(self.output_path_translated) def _steps(self): for name in dir(self): # dir() result is implicitly sorted diff --git a/translator/data_parser.py b/translator/data_parser.py index 1a1b465..0f64755 100644 --- a/translator/data_parser.py +++ b/translator/data_parser.py @@ -1,3 +1,4 @@ +import math import re import json import os @@ -169,12 +170,13 @@ def multithread_list_str_translate(self, list_str: List[str], futures = [] finished_task = 0 manager = multiprocessing.Manager() + lock = manager.Lock() def callback_list_done(future): nonlocal translated_list_data nonlocal finished_task nonlocal manager - lock = manager.Lock() + nonlocal lock if not future.exception(): with lock: translated_list_data.append(future.result()) @@ -194,18 +196,8 @@ def callback_list_done(future): } futures.append(future_dict) - # Progress bar - # desc = f"Translating sub-list of {field_name} of chunk {progress_idx} sub-chunk" - # progress_bar = tqdm(total=len(futures), desc=desc, colour="red") - # # Manually refresh the progress bar to display it - # progress_bar.refresh() - - tmp_finished_task = 0 # Wait for all threads to complete while finished_task < len(futures): - # if finished_task != tmp_finished_task: - # progress_bar.update(1) - # tmp_finished_task = finished_task for future_dict in futures: # If exception occurs in one of the thread, restart the thread with its specific chunk if future_dict['future'].exception(): @@ -310,20 +302,27 @@ def translate_converted(self, range(0, len(converted_data), self.max_example_per_thread)] tqdm.write(f"\n Data too large, splitting data into {num_threads} chunk, each chunk is {len(chunks[0])}" f" Processing with multithread...\n") + + # Progress bar + desc = "Translating total converted large chunk data" if large_chunk else "Translating total converted data" + progress_bar = tqdm(total=math.ceil(num_threads), desc=desc) + with ThreadPoolExecutor(max_workers=num_threads) as executor: futures = [] finished_task = 0 manager = multiprocessing.Manager() + lock = manager.Lock() def callback_done(future): nonlocal translated_data nonlocal finished_task - nonlocal manager - lock = manager.Lock() + nonlocal lock + nonlocal progress_bar if not future.exception(): with lock: translated_data += future.result() finished_task += 1 + progress_bar.update(1) tqdm.write("\nTask finished, adding translated data to result...\n") else: tqdm.write(f"\nTask failed with the following error: {future.exception()}." @@ -342,18 +341,8 @@ def callback_done(future): "idx": idx} futures.append(future_dict) - # Progress bar - desc = "Translating total converted large chunk data" if large_chunk else "Translating total converted data" - progress_bar = tqdm(total=len(futures), desc=desc) - # Manually refresh the progress bar to display it - progress_bar.refresh() - - tmp_finished_task = 0 # Wait for all threads to complete while finished_task < len(futures): - if tmp_finished_task != finished_task: - progress_bar.update(1) - tmp_finished_task = finished_task for future_dict in futures: # If exception occurs in one of the thread, restart the thread with its specific chunk if future_dict['future'].exception():