Skip to content

Commit

Permalink
added library loader
Browse files Browse the repository at this point in the history
  • Loading branch information
katz-s committed Oct 10, 2024
1 parent 925ef5f commit 572264e
Showing 1 changed file with 78 additions and 10 deletions.
88 changes: 78 additions & 10 deletions src/toolviper/dask/client.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,90 @@
import os
import dask
import psutil
import logging
import multiprocessing
import os
import pathlib
import distributed
from importlib import import_module
from importlib.util import find_spec
from typing import Dict, Union

import dask
import dask_jobqueue
import multiprocessing
import toolviper.dask.menrva
import distributed
import psutil

import toolviper.utils.parameter as parameter
import toolviper.utils.logger as logger
import toolviper.dask.menrva
import toolviper.utils.console as console

from typing import Union, Dict
import toolviper.utils.logger as logger
import toolviper.utils.parameter as parameter

colorize = console.Colorize()


def load_libraries(
name: str,
libs: Union[str, list[str]]
) -> dict[str, bool]:
"""Load libraries if they were installed and can be loaded.
Parameters
----------
name : library group name
A library group name based on a function of a distributed environment will be imported.
libs : Union[str, list[str]]
a library or a list of libraries to import
Returns
-------
an item of dict has the name and the flag whether all libraries were loaded successfully.
"""
def _load_library(_lib):
if find_spec(_lib) is not None:
import_module(_lib)
return [True, f" {colorize.blue(_lib)} is available"]
else:
return [False, f" {colorize.blue(_lib)} is unavailable"]

if isinstance(libs, list):
_tmp = list(map(_load_library, libs))
_avail = [all([x[0] for x in _tmp]), [x[1] for x in _tmp]]
elif isinstance(libs, str):
_tmp = _load_library(libs)
_avail = [_tmp[0], [_tmp[1]]]
else:
_avail = [False, " illegal module specification"]

_result = "Success" if _avail[0] else "Fail"
logger.info(f'Loading module: {name} -- {_result}')
[logger.info(x) for x in _avail[1]]

return {name: _avail[0]}


def print_libraries_availability(spec: dict[str, bool]):
"""Print contents of available_specs.
Parameters
----------
spec : dict[str, bool]
an instance of available_specs
"""
loaded_lib = [ k for k, v in spec.items() if v ]
logger.info(f"{colorize.green('Available functions of this environment')}: {', '.join(loaded_lib)}")


"""
load libraries related functions of a distributed environment
'available_specs' contains the function name and a flag that the function was loaded successfully
"""

logger.info(colorize.green("Checking functions availability:"))
available_specs = {
**load_libraries("slurm", "dask_jobqueue"),
**load_libraries("dask_ssh", ["asyncssh", "jupyter_server_proxy", "paramiko"]),
**load_libraries("CUDA", "dask_cuda")
}
print_libraries_availability(available_specs)


def get_thread_info() -> Dict[str, float]:
# This just brings the built-in thread info function into the client module.
return toolviper.dask.menrva.MenrvaClient.thread_info()
Expand Down

0 comments on commit 572264e

Please sign in to comment.