Skip to content

Commit

Permalink
fix, chore: fix tqdm bar for total chunk process, clean up parsed file in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vTuanpham committed Dec 10, 2023
1 parent 078e32d commit 39172ba
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 25 deletions.
6 changes: 5 additions & 1 deletion tests/eli5_qaconfig_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ def step3(self):
def step4(self):
    """Trigger the parser's save step, then assert both output files exist.

    Builds the expected paths for the parsed file and its translated
    counterpart under ``self.output_dir`` and fails the test with a
    descriptive message if either file is missing.
    """
    # NOTE(review): bare attribute access — presumably `save` is a
    # decorated property that performs the write; confirm it is not a
    # missing call (`self.parser.save()`).
    self.parser.save

    # Removed a dead assignment: output_path was first set to the
    # translated filename and immediately overwritten below.
    self.output_path = os.path.join(self.output_dir, "ELI5_val_QAConfig.json")
    self.output_path_translated = os.path.join(self.output_dir, "ELI5_val_QAConfig_translated_ru.json")

    self.assertTrue(os.path.exists(self.output_path), f"File '{self.output_path}' does not exist")
    self.assertTrue(os.path.exists(self.output_path_translated), f"File '{self.output_path_translated}' does not exist")

def step5(self):
try:
Expand All @@ -45,6 +47,8 @@ def step6(self):
def step7(self):
    """Remove the parsed and translated output files, if they were created."""
    for target in (self.output_path, self.output_path_translated):
        if os.path.exists(target):
            os.remove(target)

def _steps(self):
for name in dir(self): # dir() result is implicitly sorted
Expand Down
6 changes: 5 additions & 1 deletion tests/eli5_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ def step3(self):
def step4(self):
    """Trigger the parser's save step, then assert both output files exist.

    Builds the expected paths for the parsed file and its German-translated
    counterpart under ``self.output_dir`` and fails the test with a
    descriptive message if either file is missing.
    """
    # NOTE(review): bare attribute access — presumably `save` is a
    # decorated property that performs the write; confirm it is not a
    # missing call (`self.parser.save()`).
    self.parser.save

    # Removed a dead assignment: output_path was first set to the
    # translated filename and immediately overwritten below.
    self.output_path = os.path.join(self.output_dir, "ELI5_val.json")
    self.output_path_translated = os.path.join(self.output_dir, "ELI5_val_translated_de.json")

    self.assertTrue(os.path.exists(self.output_path), f"File '{self.output_path}' does not exist")
    self.assertTrue(os.path.exists(self.output_path_translated), f"File '{self.output_path_translated}' does not exist")

def step5(self):
try:
Expand All @@ -46,6 +48,8 @@ def step6(self):
def step7(self):
    """Clean up any output artifacts produced by the earlier steps."""
    candidates = [self.output_path, self.output_path_translated]
    for path in candidates:
        if os.path.exists(path):
            os.remove(path)

def _steps(self):
for name in dir(self): # dir() result is implicitly sorted
Expand Down
35 changes: 12 additions & 23 deletions translator/data_parser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import re
import json
import os
Expand Down Expand Up @@ -169,12 +170,13 @@ def multithread_list_str_translate(self, list_str: List[str],
futures = []
finished_task = 0
manager = multiprocessing.Manager()
lock = manager.Lock()

def callback_list_done(future):
nonlocal translated_list_data
nonlocal finished_task
nonlocal manager
lock = manager.Lock()
nonlocal lock
if not future.exception():
with lock:
translated_list_data.append(future.result())
Expand All @@ -194,18 +196,8 @@ def callback_list_done(future):
}
futures.append(future_dict)

# Progress bar
# desc = f"Translating sub-list of {field_name} of chunk {progress_idx} sub-chunk"
# progress_bar = tqdm(total=len(futures), desc=desc, colour="red")
# # Manually refresh the progress bar to display it
# progress_bar.refresh()

tmp_finished_task = 0
# Wait for all threads to complete
while finished_task < len(futures):
# if finished_task != tmp_finished_task:
# progress_bar.update(1)
# tmp_finished_task = finished_task
for future_dict in futures:
# If exception occurs in one of the thread, restart the thread with its specific chunk
if future_dict['future'].exception():
Expand Down Expand Up @@ -310,20 +302,27 @@ def translate_converted(self,
range(0, len(converted_data), self.max_example_per_thread)]
tqdm.write(f"\n Data too large, splitting data into {num_threads} chunk, each chunk is {len(chunks[0])}"
f" Processing with multithread...\n")

# Progress bar
desc = "Translating total converted large chunk data" if large_chunk else "Translating total converted data"
progress_bar = tqdm(total=math.ceil(num_threads), desc=desc)

with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
finished_task = 0
manager = multiprocessing.Manager()
lock = manager.Lock()

def callback_done(future):
nonlocal translated_data
nonlocal finished_task
nonlocal manager
lock = manager.Lock()
nonlocal lock
nonlocal progress_bar
if not future.exception():
with lock:
translated_data += future.result()
finished_task += 1
progress_bar.update(1)
tqdm.write("\nTask finished, adding translated data to result...\n")
else:
tqdm.write(f"\nTask failed with the following error: {future.exception()}."
Expand All @@ -342,18 +341,8 @@ def callback_done(future):
"idx": idx}
futures.append(future_dict)

# Progress bar
desc = "Translating total converted large chunk data" if large_chunk else "Translating total converted data"
progress_bar = tqdm(total=len(futures), desc=desc)
# Manually refresh the progress bar to display it
progress_bar.refresh()

tmp_finished_task = 0
# Wait for all threads to complete
while finished_task < len(futures):
if tmp_finished_task != finished_task:
progress_bar.update(1)
tmp_finished_task = finished_task
for future_dict in futures:
# If exception occurs in one of the thread, restart the thread with its specific chunk
if future_dict['future'].exception():
Expand Down

0 comments on commit 39172ba

Please sign in to comment.