diff --git a/setup.py b/setup.py index a3f82fd..661db18 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,8 @@ 'pandas', f'vantage6-common=={version_ns["__version__"]}', 'pyfiglet==0.8.post1', - 'SPARQLWrapper==1.8.5' + 'SPARQLWrapper==1.8.5', + 'rich', ], tests_require=["pytest"], package_data={ diff --git a/vantage6/client/__init__.py b/vantage6/client/__init__.py index 29055ff..57d8c57 100644 --- a/vantage6/client/__init__.py +++ b/vantage6/client/__init__.py @@ -4,14 +4,25 @@ This module is contains a base client. From this base client the container client (client used by master algorithms) and the user client are derived. """ +from typing import Optional + +import sys import logging import pickle import time +from tempfile import TemporaryFile import typing import jwt +from numpy import isin import requests import pyfiglet import json as json_lib +import itertools + +import dateutil +import time +from datetime import datetime + from pathlib import Path from typing import Tuple @@ -622,6 +633,109 @@ def authenticate(self, username: str, password: str) -> None: self.log.info(f'--> Retrieving additional user info failed!') self.log.debug(e) + def wait_for_results(self, task_or_id, sleep=1): + + # Disable logging + if isinstance(self.log, logging.Logger): + prev_level = self.log.level + self.log.setLevel(logging.WARN) + elif isinstance(self.log, UserClient.Log): + prev_level = self.log.enabled + + # Retrieve task details if necesary. + if isinstance(task_or_id, int): + task_id = task_or_id + task = self.task.get(task_id) + else: + task = task_or_id + task_id = task['id'] + + # Determine when the task was first started. We'll use the 1st result + # to determine this, since task itself doesn't record this. + if task['results']: + result_id = task['results'][0]['id'] + result = self.result.get(result_id) + start = dateutil.parser.isoparse(result['assigned_at']) + else: + start = None + result = None + + try: + from rich.progress import Progress, ProgressColumn, SpinnerColumn, TimeElapsedColumn + from rich.text import Text + from rich.table import Column + + class TrueTimeElapsedColumn(ProgressColumn): + """Renders time elapsed.""" + def __init__(self, start: datetime, table_column: Optional[Column] = None): + super().__init__(table_column) + self.start = start + + def render(self, task: "Task") -> Text: + """Show time remaining.""" + if self.start is None: + return Text("-:--:--", style="progress.elapsed") + + # elapsed = task.finished_time if task.finished else task.elapsed + if task.fields.get('my_finished_date'): + now = dateutil.parser.isoparse(task.fields.get('my_finished_date')) + else: + now = datetime.now(start.tzinfo) + + delta = now - self.start + d = str(delta).split('.')[0] + return Text(d, style="progress.elapsed") + + cols = [ + "[progress.description]{task.description}", + SpinnerColumn(), + "Time elapsed:", + TrueTimeElapsedColumn(start), + ", Last check at: {task.fields[last_check]}" + ] + + if result: + finished_at = result['finished_at'] + else: + finished_at = None + + with Progress(*cols) as progress: + ptask = progress.add_task( + f"Waiting for task {task_id}", + last_check=time.strftime('%H:%M:%S'), + my_finished_date=finished_at, + ) + + while not self.task.get(task_id)['complete']: + time.sleep(sleep) + + progress.update( + ptask, + last_check=time.strftime('%H:%M:%S'), + my_finished_date=None, + ) + + except ImportError: + animation = itertools.cycle(['|', '/', '-', '\\']) + t = time.time() + + while not self.task.get(task_id)['complete']: + frame = next(animation) + sys.stdout.write(f'\r{frame} Waiting for task {task_id} ({int(time.time()-t)}s)') + sys.stdout.flush() + time.sleep(sleep) + + sys.stdout.write('\rDone! ') + + + # Re-enable logging + if isinstance(self.log, logging.Logger): + self.log.setLevel(prev_level) + elif isinstance(self.log, UserClient.Log): + self.log.enabled = prev_level + + return self.get_results(task_id=task_id) + class Util(ClientBase.SubClient): """Collection of general utilities"""