diff --git a/src/toolviper/dask/client.py b/src/toolviper/dask/client.py index 6fb3619..ce7fb85 100644 --- a/src/toolviper/dask/client.py +++ b/src/toolviper/dask/client.py @@ -1,22 +1,90 @@ -import os -import dask -import psutil import logging +import multiprocessing +import os import pathlib -import distributed +from importlib import import_module +from importlib.util import find_spec +from typing import Dict, Union + +import dask import dask_jobqueue -import multiprocessing -import toolviper.dask.menrva +import distributed +import psutil -import toolviper.utils.parameter as parameter -import toolviper.utils.logger as logger +import toolviper.dask.menrva import toolviper.utils.console as console - -from typing import Union, Dict +import toolviper.utils.logger as logger +import toolviper.utils.parameter as parameter colorize = console.Colorize() +def load_libraries( + name: str, + libs: Union[str, list[str]] +) -> dict[str, bool]: + """Load libraries if they were installed and can be loaded. + + Parameters + ---------- + name : library group name + A library group name based on a function of a distributed environment will be imported. + libs : Union[str, list[str]] + a library or a list of libraries to import + + Returns + ------- + an item of dict has the name and the flag whether all libraries were loaded successfully. + """ + def _load_library(_lib): + if find_spec(_lib) is not None: + import_module(_lib) + return [True, f" {colorize.blue(_lib)} is available"] + else: + return [False, f" {colorize.blue(_lib)} is unavailable"] + + if isinstance(libs, list): + _tmp = list(map(_load_library, libs)) + _avail = [all([x[0] for x in _tmp]), [x[1] for x in _tmp]] + elif isinstance(libs, str): + _tmp = _load_library(libs) + _avail = [_tmp[0], [_tmp[1]]] + else: + _avail = [False, " illegal module specification"] + + _result = "Success" if _avail[0] else "Fail" + logger.info(f'Loading module: {name} -- {_result}') + [logger.info(x) for x in _avail[1]] + + return {name: _avail[0]} + + +def print_libraries_availability(spec: dict[str, bool]): + """Print contents of available_specs. + + Parameters + ---------- + spec : dict[str, bool] + an instance of available_specs + """ + loaded_lib = [ k for k, v in spec.items() if v ] + logger.info(f"{colorize.green('Available functions of this environment')}: {', '.join(loaded_lib)}") + + +""" +load libraries related functions of a distributed environment +'available_specs' contains the function name and a flag that the function was loaded successfully +""" + +logger.info(colorize.green("Checking functions availability:")) +available_specs = { + **load_libraries("slurm", "dask_jobqueue"), + **load_libraries("dask_ssh", ["asyncssh", "jupyter_server_proxy", "paramiko"]), + **load_libraries("CUDA", "dask_cuda") +} +print_libraries_availability(available_specs) + + def get_thread_info() -> Dict[str, float]: # This just brings the built-in thread info function into the client module. return toolviper.dask.menrva.MenrvaClient.thread_info()