diff --git a/emodel_generalisation/adaptation.py b/emodel_generalisation/adaptation.py
index d0fa246..ad38cc9 100644
--- a/emodel_generalisation/adaptation.py
+++ b/emodel_generalisation/adaptation.py
@@ -29,6 +29,7 @@
 import numpy as np
 import pandas as pd
 import seaborn as sns
+from bluepyparallel import evaluate
 from matplotlib.backends.backend_pdf import PdfPages
 from numpy.polynomial import Polynomial
 from scipy.signal import cspline2d
@@ -41,7 +42,6 @@
 from emodel_generalisation.model.evaluation import rin_evaluation
 from emodel_generalisation.model.modifiers import remove_axon
 from emodel_generalisation.model.modifiers import remove_soma
-from emodel_generalisation.parallel import evaluate
 from emodel_generalisation.utils import FEATURE_FILTER
 from emodel_generalisation.utils import get_scores
 
diff --git a/emodel_generalisation/bluecellulab_evaluator.py b/emodel_generalisation/bluecellulab_evaluator.py
index 5951a75..4fbe8e8 100644
--- a/emodel_generalisation/bluecellulab_evaluator.py
+++ b/emodel_generalisation/bluecellulab_evaluator.py
@@ -6,9 +6,8 @@
 
 import bluecellulab
 import efel
-
-from emodel_generalisation.parallel.evaluator import evaluate
-from emodel_generalisation.parallel.parallel import NestedPool
+from bluepyparallel.evaluator import evaluate
+from bluepyparallel.parallel import NestedPool
 
 logger = logging.getLogger(__name__)
 AXON_LOC = "self.axonal[1](0.5)._ref_v"
diff --git a/emodel_generalisation/cli.py b/emodel_generalisation/cli.py
index d8760d7..e995523 100644
--- a/emodel_generalisation/cli.py
+++ b/emodel_generalisation/cli.py
@@ -13,6 +13,7 @@
 import pandas as pd
 import seaborn as sns
 import yaml
+from bluepyparallel import init_parallel_factory
 from datareuse import Reuse
 from matplotlib.backends.backend_pdf import PdfPages
 from morphio import Morphology
@@ -32,7 +33,6 @@
 from emodel_generalisation.model.evaluation import evaluate_rho_axon
 from emodel_generalisation.model.evaluation import feature_evaluation
 from emodel_generalisation.model.modifiers import get_replace_axon_hoc
-from emodel_generalisation.parallel import init_parallel_factory
 from emodel_generalisation.utils import FEATURE_FILTER
 from emodel_generalisation.utils import get_feature_df
 from emodel_generalisation.utils import get_score_df
diff --git a/emodel_generalisation/mcmc.py b/emodel_generalisation/mcmc.py
index 89eb3de..8d81077 100644
--- a/emodel_generalisation/mcmc.py
+++ b/emodel_generalisation/mcmc.py
@@ -31,6 +31,8 @@
 import numpy as np
 import pandas as pd
 import seaborn as sns
+from bluepyparallel import evaluate
+from bluepyparallel import init_parallel_factory
 from matplotlib.backends.backend_pdf import PdfPages
 from mpl_toolkits import axisartist
 from scipy.spatial import distance_matrix
@@ -45,8 +47,6 @@
 from emodel_generalisation.information import rsi_gaussian
 from emodel_generalisation.model.access_point import AccessPoint
 from emodel_generalisation.model.evaluation import get_evaluator_from_access_point
-from emodel_generalisation.parallel import evaluate
-from emodel_generalisation.parallel import init_parallel_factory
 from emodel_generalisation.utils import cluster_matrix
 
 # pylint: disable=too-many-lines,too-many-locals
diff --git a/emodel_generalisation/model/evaluation.py b/emodel_generalisation/model/evaluation.py
index 67b3a18..dfa3481 100644
--- a/emodel_generalisation/model/evaluation.py
+++ b/emodel_generalisation/model/evaluation.py
@@ -42,12 +42,12 @@
 from bluepyopt.ephys.objectives import SingletonObjective
 from bluepyopt.ephys.objectivescalculators import ObjectivesCalculator
 from bluepyopt.ephys.simulators import NrnSimulator
+from bluepyparallel import evaluate
+from bluepyparallel.parallel import NestedPool
 
 from emodel_generalisation.model import bpopt
 from emodel_generalisation.model import modifiers
 from emodel_generalisation.model.ecodes import eCodes
-from emodel_generalisation.parallel import evaluate
-from emodel_generalisation.parallel.parallel import NestedPool
 
 # pylint: disable=too-many-lines
 
diff --git a/emodel_generalisation/morph_utils.py b/emodel_generalisation/morph_utils.py
index fd15e19..226c83e 100644
--- a/emodel_generalisation/morph_utils.py
+++ b/emodel_generalisation/morph_utils.py
@@ -27,6 +27,7 @@
 import numpy as np
 import pandas as pd
 import yaml
+from bluepyparallel.evaluator import evaluate
 from diameter_synthesis import build_diameters
 from diameter_synthesis import build_models
 from diameter_synthesis.main import plot_models
@@ -38,8 +39,6 @@
 from neurom import view
 from tqdm import tqdm
 
-from emodel_generalisation.parallel.evaluator import evaluate
-
 
 def create_combos_df(
     morphology_dataset_path, generalisation_rule_path, emodel, n_min_per_mtype, n_morphs
diff --git a/emodel_generalisation/parallel/__init__.py b/emodel_generalisation/parallel/__init__.py
deleted file mode 100644
index 4e140fa..0000000
--- a/emodel_generalisation/parallel/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""bluepyparallel package (to be open sourced as standalone package).
-
-Provides an embarrassingly parallel tool with sql backend.
-"""
-from emodel_generalisation.parallel.evaluator import evaluate  # noqa
-from emodel_generalisation.parallel.parallel import init_parallel_factory  # noqa
diff --git a/emodel_generalisation/parallel/database.py b/emodel_generalisation/parallel/database.py
deleted file mode 100644
index 9da45da..0000000
--- a/emodel_generalisation/parallel/database.py
+++ /dev/null
@@ -1,158 +0,0 @@
-"""Module used to provide a simple API to a database in which the results are stored."""
-import re
-
-import pandas as pd
-from sqlalchemy import MetaData
-from sqlalchemy import Table
-from sqlalchemy import bindparam
-from sqlalchemy import create_engine
-from sqlalchemy import insert
-from sqlalchemy import schema
-from sqlalchemy import select
-from sqlalchemy.engine.reflection import Inspector
-from sqlalchemy.exc import OperationalError
-from sqlalchemy_utils import create_database
-from sqlalchemy_utils import database_exists
-
-try:  # pragma: no cover
-    import psycopg2
-    import psycopg2.extras
-
-    with_psycopg2 = True
-except ImportError:
-    with_psycopg2 = False
-
-
-class DataBase:
-    """A simple API to manage the database in which the results are inserted using SQLAlchemy.
-
-    Args:
-        url (str): The URL of the database following the RFC-1738 format (
-            https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls)
-        create (bool): If set to True, the database will be automatically created by the
-            constructor.
-        args and kwargs: They will be passed to the :func:`sqlalchemy.create_engine` function.
- """ - - index_col = "df_index" - _url_pattern = r"[a-zA-Z0-9_\-\+]+://.*" - - def __init__(self, url, *args, create=False, **kwargs): - if not re.match(self._url_pattern, str(url)): - url = "sqlite:///" + str(url) - - self.engine = create_engine(url, *args, **kwargs) - - if create and not self.db_exists(): - create_database(self.engine.url) - - self._connection = None - self.metadata = None - self.table = None - - def __del__(self): - """Close the connection and the engine to the database.""" - self.connection.close() - self.engine.dispose() - - @property - def connection(self): - """Get a connection to the database.""" - try: - if self._connection.connection.dbapi_connection is None: - self._connection.close() - self._connection = None - except AttributeError: - self._connection = None - - if self._connection is None: - self._connection = self.engine.connect() - - return self._connection - - def get_url(self): - """Get the URL of the database.""" - return self.engine.url - - def create(self, df, table_name=None, schema_name=None): - """Create a table in the database in which the results will be written.""" - if table_name is None: - table_name = "df" - if schema_name is not None and schema_name not in self.connection.dialect.get_schema_names( - self.connection - ): # pragma: no cover - self.connection.execute(schema.CreateSchema(schema_name)) - new_df = df.loc[[]] - new_df.to_sql( - name=table_name, - con=self.connection, - schema=schema_name, - if_exists="replace", - index_label=self.index_col, - ) - self.reflect(table_name, schema_name) - - def db_exists(self): - """Check that the server and the database exist.""" - if with_psycopg2: # pragma: no cover - exceptions = (OperationalError, psycopg2.OperationalError) - else: - exceptions = (OperationalError,) - - try: - return database_exists(self.engine.url) - except exceptions: # pragma: no cover - return False - - def exists(self, table_name, schema_name=None): - """Check that the table exists in the database.""" - inspector = Inspector.from_engine(self.engine) - return table_name in inspector.get_table_names(schema=schema_name) - - def reflect(self, table_name, schema_name=None): - """Reflect the table from the database.""" - self.metadata = MetaData() - self.table = Table( - table_name, - self.metadata, - schema=schema_name, - autoload_with=self.engine, - ) - - def load(self): - """Load the table data from the database.""" - query = select(self.table) - return pd.read_sql(query, self.connection, index_col=self.index_col) - - def write(self, row_id, result=None, exception=None, **input_values): - """Write a result entry or an exception into the table.""" - if result is not None: - vals = result - elif exception is not None: - vals = {"exception": exception} - else: - return - - query = insert(self.table).values({**{self.index_col: row_id}, **vals, **input_values}) - self.connection.execute(query) - self.connection.connection.commit() - - def write_batch(self, columns, data): - """Write entries from a list of lists into the table.""" - if not data: # pragma: no cover - return - assert len(columns) + 1 == len( - data[0] - ), "The columns list must have one less entry than each data element" - cursor = self.connection.connection.cursor() - cols = {col: bindparam(col) for col in [self.index_col] + columns} - # pylint: disable=no-value-for-parameter - compiled = self.table.insert().values(**cols).compile(dialect=self.engine.dialect) - - if hasattr(cursor, "mogrify") and with_psycopg2: # pragma: no cover - 
-            psycopg2.extras.execute_values(cursor, str(compiled), data)
-        else:
-            cursor.executemany(str(compiled), data)
-
-        self.connection.connection.commit()
-        self.connection.connection.close()
diff --git a/emodel_generalisation/parallel/evaluator.py b/emodel_generalisation/parallel/evaluator.py
deleted file mode 100644
index c068a59..0000000
--- a/emodel_generalisation/parallel/evaluator.py
+++ /dev/null
@@ -1,251 +0,0 @@
-"""Module to evaluate generic functions on rows of dataframe."""
-import logging
-import sys
-import traceback
-from functools import partial
-
-import pandas as pd
-from tqdm import tqdm
-
-from emodel_generalisation.parallel.database import DataBase
-from emodel_generalisation.parallel.parallel import DaskDataFrameFactory
-from emodel_generalisation.parallel.parallel import init_parallel_factory
-
-logger = logging.getLogger(__name__)
-
-
-def _try_evaluation(task, evaluation_function, func_args, func_kwargs):
-    """Encapsulate the evaluation function into a try/except and isolate to record exceptions."""
-    task_id, task_args = task
-
-    try:
-        result = evaluation_function(task_args, *func_args, **func_kwargs)
-        exception = None
-    except Exception:  # pylint: disable=broad-except
-        result = {}
-        exception = "".join(traceback.format_exception(*sys.exc_info()))
-        logger.exception("Exception for ID=%s: %s", task_id, exception)
-
-    return task_id, result, exception
-
-
-def _try_evaluation_df(task, evaluation_function, func_args, func_kwargs):
-    task_id, result, exception = _try_evaluation(
-        (task.name, task.to_dict()),
-        evaluation_function,
-        func_args,
-        func_kwargs,
-    )
-    res_cols = list(result.keys())
-    result["exception"] = exception
-    return pd.Series(result, name=task_id, dtype="object", index=["exception"] + res_cols)
-
-
-def _evaluate_dataframe(
-    to_evaluate,
-    input_cols,
-    evaluation_function,
-    func_args,
-    func_kwargs,
-    new_columns,
-    mapper,
-    task_ids,
-    db,
-):
-    """Internal evaluation function for dask.dataframe."""
-    # Setup the function to apply to the data
-    eval_func = partial(
-        _try_evaluation_df,
-        evaluation_function=evaluation_function,
-        func_args=func_args,
-        func_kwargs=func_kwargs,
-    )
-    meta = pd.DataFrame({col[0]: pd.Series(dtype="object") for col in new_columns})
-
-    res = []
-    try:
-        # Compute and collect the results
-        for batch in mapper(eval_func, to_evaluate.loc[task_ids, input_cols], meta=meta):
-            res.append(batch)
-
-            if db is not None:
-                batch_complete = to_evaluate[input_cols].join(batch, how="right")
-                data = batch_complete.to_records().tolist()
-                db.write_batch(batch_complete.columns.tolist(), data)
-    except (KeyboardInterrupt, SystemExit) as ex:  # pragma: no cover
-        # To save dataframe even if program is killed
-        logger.warning("Stopping mapper loop. Reason: %r", ex)
Reason: %r", ex) - return pd.concat(res) - - -def _evaluate_basic( - to_evaluate, input_cols, evaluation_function, func_args, func_kwargs, mapper, task_ids, db -): - res = [] - # Setup the function to apply to the data - eval_func = partial( - _try_evaluation, - evaluation_function=evaluation_function, - func_args=func_args, - func_kwargs=func_kwargs, - ) - - # Split the data into rows - arg_list = list(to_evaluate.loc[task_ids, input_cols].to_dict("index").items()) - - try: - # Compute and collect the results - for task_id, result, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)): - res.append(dict({"df_index": task_id, "exception": exception}, **result)) - - # Save the results into the DB - if db is not None: - db.write( - task_id, result, exception, **to_evaluate.loc[task_id, input_cols].to_dict() - ) - except (KeyboardInterrupt, SystemExit) as ex: - # To save dataframe even if program is killed - logger.warning("Stopping mapper loop. Reason: %r", ex) - - # Gather the results to the output DataFrame - return pd.DataFrame(res).set_index("df_index") - - -def _prepare_db(db_url, to_evaluate, df, resume, task_ids): - """Prepare db.""" - db = DataBase(db_url) - - if resume and db.exists("df"): - logger.info("Load data from SQL database") - db.reflect("df") - previous_results = db.load() - previous_idx = previous_results.index - bad_cols = [ - col - for col in df.columns - if not to_evaluate.loc[previous_idx, col].equals(previous_results[col]) - ] - if bad_cols: - raise ValueError( - f"The following columns have different values from the DataBase: {bad_cols}" - ) - to_evaluate.loc[previous_results.index] = previous_results.loc[previous_results.index] - task_ids = task_ids.difference(previous_results.index) - else: - logger.info("Create SQL database") - db.create(to_evaluate) - - return db, db.get_url(), task_ids - - -def evaluate( - df, - evaluation_function, - new_columns=None, - resume=False, - parallel_factory=None, - db_url=None, - func_args=None, - func_kwargs=None, - **mapper_kwargs, -): - """Evaluate and save results in a sqlite database on the fly and return dataframe. - - Args: - df (pandas.DataFrame): each row contains information for the computation. - evaluation_function (callable): function used to evaluate each row, - should have a single argument as list-like containing values of the rows of df, - and return a dict with keys corresponding to the names in new_columns. - new_columns (list): list of names of new column and empty value to save evaluation results, - i.e.: :code:`[['result', 0.0], ['valid', False]]`. - resume (bool): if :obj:`True` and ``db_url`` is provided, it will use only compute the - missing rows of the database. - parallel_factory (ParallelFactory or str): parallel factory name or instance. - db_url (str): should be DB URL that can be interpreted by :func:`sqlalchemy.create_engine` - or can be a file path that is interpreted as a SQLite database. If an URL is given, - the SQL backend will be enabled to store results and allowing future resume. Should - not be used when evaluations are numerous and fast, in order to avoid the overhead of - communication with the SQL database. - func_args (list): the arguments to pass to the evaluation_function. - func_kwargs (dict): the keyword arguments to pass to the evaluation_function. - **mapper_kwargs: the keyword arguments are passed to the get_mapper() method of the - :class:`ParallelFactory` instance. - - Return: - pandas.DataFrame: dataframe with new columns containing the computed results. 
- """ - # Initialize the parallel factory - if isinstance(parallel_factory, str) or parallel_factory is None: - parallel_factory = init_parallel_factory(parallel_factory) - # Set default args - if func_args is None: - func_args = [] - - # Set default kwargs - if func_kwargs is None: - func_kwargs = {} - - # Drop exception column if present - if "exception" in df.columns: - df = df.drop(columns=["exception"]) - - # Shallow copy the given DataFrame to add internal rows - to_evaluate = df.copy() - task_ids = to_evaluate.index - - # Set default new columns - if new_columns is None: - if isinstance(parallel_factory, DaskDataFrameFactory): - raise ValueError("The new columns must be provided when using 'DaskDataFrameFactory'") - new_columns = [] - - # Setup internal and new columns - if any(col[0] == "exception" for col in new_columns): - raise ValueError("The 'exception' column can not be one of the new columns") - new_columns = [["exception", None]] + new_columns # Don't use append to keep the input as is. - for new_column in new_columns: - to_evaluate[new_column[0]] = new_column[1] - - # Create the database if required and get the task ids to run - if db_url is None: - logger.debug("Not using SQL backend to save iterations") - db = None - else: - db, db_url, task_ids = _prepare_db(db_url, to_evaluate, df, resume, task_ids) - - # Log the number of tasks to run - if len(task_ids) > 0: - logger.info("%s rows to compute.", str(len(task_ids))) - else: - logger.warning("WARNING: No row to compute, something may be wrong") - return to_evaluate - - # Get the factory mapper - mapper = parallel_factory.get_mapper(**mapper_kwargs) - - if isinstance(parallel_factory, DaskDataFrameFactory): - res_df = _evaluate_dataframe( - to_evaluate, - df.columns, - evaluation_function, - func_args, - func_kwargs, - new_columns, - mapper, - task_ids, - db, - ) - else: - res_df = _evaluate_basic( - to_evaluate, - df.columns, - evaluation_function, - func_args, - func_kwargs, - mapper, - task_ids, - db, - ) - to_evaluate.loc[res_df.index, res_df.columns] = res_df - - return to_evaluate diff --git a/emodel_generalisation/parallel/parallel.py b/emodel_generalisation/parallel/parallel.py deleted file mode 100644 index c41dd4e..0000000 --- a/emodel_generalisation/parallel/parallel.py +++ /dev/null @@ -1,358 +0,0 @@ -"""Parallel helper.""" -import logging -import multiprocessing -import os -from abc import abstractmethod -from collections.abc import Iterator -from functools import partial -from multiprocessing.pool import Pool - -import numpy as np - -try: - import dask.distributed - import dask_mpi - - dask_available = True -except ImportError: # pragma: no cover - dask_available = False - -try: - import dask.dataframe as dd # pylint: disable=ungrouped-imports - import pandas as pd - from dask.distributed import progress - - dask_df_available = True -except ImportError: # pragma: no cover - dask_df_available = False - -try: - import ipyparallel - - ipyparallel_available = True -except ImportError: # pragma: no cover - ipyparallel_available = False - - -L = logging.getLogger(__name__) - - -def _func_wrapper(data, func, func_args, func_kwargs): - """Function wrapper used to pass args and kwargs.""" - return func(data, *func_args, **func_kwargs) - - -class ParallelFactory: - """Abstract class that should be subclassed to provide parallel functions.""" - - _BATCH_SIZE = "PARALLEL_BATCH_SIZE" - _CHUNK_SIZE = "PARALLEL_CHUNK_SIZE" - - # pylint: disable=unused-argument - def __init__(self, batch_size=None, chunk_size=None): - 
-        self.batch_size = batch_size or int(os.getenv(self._BATCH_SIZE, "0")) or None
-        L.debug("Using %s=%s", self._BATCH_SIZE, self.batch_size)
-
-        self.chunk_size = batch_size or int(os.getenv(self._CHUNK_SIZE, "0")) or None
-        L.debug("Using %s=%s", self._CHUNK_SIZE, self.chunk_size)
-
-        if not hasattr(self, "nb_processes"):
-            self.nb_processes = 1
-
-    def __del__(self):
-        """Call the shutdown method."""
-        self.shutdown()
-
-    @abstractmethod
-    def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
-        """Return a mapper function that can be used to execute functions in parallel."""
-
-    def shutdown(self):
-        """Can be used to cleanup."""
-
-    def mappable_func(self, func, *args, **kwargs):
-        """Can be used to add args and kwargs to a function before calling the mapper."""
-        return partial(_func_wrapper, func=func, func_args=args, func_kwargs=kwargs)
-
-    def _with_batches(self, mapper, func, iterable, batch_size=None):
-        """Wrapper on mapper function creating batches of iterable to give to mapper.
-
-        The batch_size is an int corresponding to the number of evaluation in each batch.
-        """
-        if isinstance(iterable, Iterator):
-            iterable = list(iterable)
-
-        batch_size = batch_size or self.batch_size
-        if batch_size is not None:
-            iterables = np.array_split(iterable, len(iterable) // min(batch_size, len(iterable)))
-            if not isinstance(iterable, (pd.DataFrame, pd.Series)):
-                iterables = [_iterable.tolist() for _iterable in iterables]
-        else:
-            iterables = [iterable]
-
-        for i, _iterable in enumerate(iterables):
-            if len(iterables) > 1:
-                L.info("Computing batch %s / %s", i + 1, len(iterables))
-            yield from mapper(func, _iterable)
-
-    def _chunksize_to_kwargs(self, chunk_size, kwargs, label="chunk_size"):
-        chunk_size = chunk_size or self.chunk_size
-        if chunk_size is not None:
-            kwargs[label] = chunk_size
-
-
-class NoDaemonProcess(multiprocessing.Process):
-    """Class that represents a non-daemon process."""
-
-    # pylint: disable=dangerous-default-value
-
-    def __init__(self, group=None, target=None, name=None, args=(), kwargs={}):
-        """Ensures group=None, for macosx."""
-        super().__init__(group=None, target=target, name=name, args=args, kwargs=kwargs)
-
-    def _get_daemon(self):
-        """Get daemon flag."""
-        return False  # pragma: no cover
-
-    def _set_daemon(self, value):
-        """Set daemon flag."""
-
-    daemon = property(_get_daemon, _set_daemon)
-
-
-class NestedPool(Pool):  # pylint: disable=abstract-method
-    """Class that represents a MultiProcessing nested pool."""
-
-    Process = NoDaemonProcess
-
-
-class SerialFactory(ParallelFactory):
-    """Factory that do not work in parallel."""
-
-    def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
-        """Get a map."""
-
-        def _mapper(func, iterable, *func_args, **func_kwargs):
-            mapped_func = self.mappable_func(func, *func_args, **func_kwargs)
-            return self._with_batches(map, mapped_func, iterable)
-
-        return _mapper
-
-
-class MultiprocessingFactory(ParallelFactory):
-    """Parallel helper class using multiprocessing."""
-
-    _CHUNKSIZE = "PARALLEL_CHUNKSIZE"
-
-    def __init__(self, batch_size=None, chunk_size=None, processes=None, **kwargs):
-        """Initialize multiprocessing factory."""
-        super().__init__(batch_size, chunk_size)
-
-        self.nb_processes = processes or os.cpu_count()
-        self.pool = NestedPool(processes=self.nb_processes, **kwargs)
-
-    def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
-        """Get a NestedPool."""
-        self._chunksize_to_kwargs(chunk_size, kwargs, label="chunksize")
-
-        def _mapper(func, iterable, *func_args, **func_kwargs):
-            mapped_func = self.mappable_func(func, *func_args, **func_kwargs)
-            return self._with_batches(
-                partial(self.pool.imap_unordered, **kwargs),
-                mapped_func,
-                iterable,
-            )
-
-        return _mapper
-
-
-class IPyParallelFactory(ParallelFactory):
-    """Parallel helper class using ipyparallel."""
-
-    _IPYTHON_PROFILE = "IPYTHON_PROFILE"
-
-    def __init__(self, batch_size=None, chunk_size=None, profile=None, **kwargs):
-        """Initialize the ipyparallel factory."""
-        profile = profile or os.getenv(self._IPYTHON_PROFILE, None)
-        L.debug("Using %s=%s", self._IPYTHON_PROFILE, profile)
-        self.rc = ipyparallel.Client(profile=profile, **kwargs)
-        self.nb_processes = len(self.rc.ids)
-        self.lview = self.rc.load_balanced_view()
-        super().__init__(batch_size, chunk_size)
-
-    def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
-        """Get an ipyparallel mapper using the profile name provided."""
-        if "ordered" not in kwargs:  # pragma: no cover
-            kwargs["ordered"] = False
-
-        self._chunksize_to_kwargs(chunk_size, kwargs)
-
-        def _mapper(func, iterable, *func_args, **func_kwargs):
-            mapped_func = self.mappable_func(func, *func_args, **func_kwargs)
-            return self._with_batches(
-                partial(self.lview.imap, **kwargs), mapped_func, iterable, batch_size=batch_size
-            )
-
-        return _mapper
-
-    def shutdown(self):
-        """Remove zmq."""
-        try:
-            self.rc.close()
-        except Exception:  # pragma: no cover ; pylint: disable=broad-except
-            pass
-
-
-class DaskFactory(ParallelFactory):
-    """Parallel helper class using dask."""
-
-    _SCHEDULER_PATH = "PARALLEL_DASK_SCHEDULER_PATH"
-
-    def __init__(
-        self, batch_size=None, chunk_size=None, scheduler_file=None, address=None, **kwargs
-    ):
-        """Initialize the dask factory."""
-        dask_scheduler_path = scheduler_file or os.getenv(self._SCHEDULER_PATH)
-        self.interactive = True
-        if dask_scheduler_path:  # pragma: no cover
-            L.debug("Connecting dask_mpi with scheduler %s", dask_scheduler_path)
-        if address:  # pragma: no cover
-            L.debug("Connecting dask_mpi with address %s", address)
-        if not dask_scheduler_path and not address:  # pragma: no cover
-            self.interactive = False
-            # local_directory is a fix for https://github.com/dask/dask-mpi/pull/114
-            dask_mpi.initialize(local_directory=None)
-            L.debug("Starting dask_mpi...")
-
-        self.client = dask.distributed.Client(
-            address=address,
-            scheduler_file=dask_scheduler_path,
-            **kwargs,
-        )
-
-        if self.interactive:
-            self.nb_processes = len(self.client.scheduler_info()["workers"])
-        else:  # pragma: no cover
-            from mpi4py import MPI  # pylint: disable=import-outside-toplevel,import-error
-
-            comm = MPI.COMM_WORLD  # pylint: disable=c-extension-no-member
-            self.nb_processes = comm.Get_size()
-
-        super().__init__(batch_size, chunk_size)
-
-    def shutdown(self):
-        """Close the scheduler and the cluster if it was created by the factory."""
-        try:
-            self.client.close()
-        except Exception:  # pylint: disable=broad-except ; # pragma: no cover
-            pass
-
-    def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
-        """Get a Dask mapper."""
-        self._chunksize_to_kwargs(chunk_size, kwargs, label="batch_size")
-
-        def _mapper(func, iterable, *func_args, **func_kwargs):
-            def _dask_mapper(in_dask_func, iterable):
-                futures = self.client.map(in_dask_func, iterable, **kwargs)
-                for _future, result in dask.distributed.as_completed(futures, with_results=True):
-                    yield result
-
-            mapped_func = self.mappable_func(func, *func_args, **func_kwargs)
-            return self._with_batches(_dask_mapper, mapped_func, iterable, batch_size=batch_size)
-
-        return _mapper
-
-
-class DaskDataFrameFactory(DaskFactory):
-    """Parallel helper class using dask.dataframe."""
-
-    _SCHEDULER_PATH = "PARALLEL_DASK_SCHEDULER_PATH"
-
-    def __init__(
-        self,
-        batch_size=None,
-        chunk_size=None,
-        scheduler_file=None,
-        address=None,
-        dask_config=None,
-        **kwargs,
-    ):
-        super().__init__(
-            batch_size, chunk_size, scheduler_file=scheduler_file, address=address, **kwargs
-        )
-        if dask_config is None:  # pragma: no cover
-            dask_config = {
-                "distributed.worker.use-file-locking": False,
-                "distributed.worker.memory.target": False,
-                "distributed.worker.memory.spill": False,
-                "distributed.worker.memory.pause": 0.8,
-                "distributed.worker.memory.terminate": 0.95,
-                "distributed.worker.profile.interval": "10000ms",
-                "distributed.worker.profile.cycle": "1000000ms",
-                "distributed.admin.tick.limit": "1h",
-            }
-
-        dask.config.set(dask_config)
-
-    def _with_batches(self, *args, **kwargs):
-        """Specific process for batches."""
-        for tmp in super()._with_batches(*args, **kwargs):
-            if isinstance(tmp, pd.Series):
-                tmp = tmp.to_frame()
-            yield tmp
-
-    def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
-        """Get a Dask mapper."""
-        self._chunksize_to_kwargs(chunk_size, kwargs, label="chunksize")
-        if not kwargs.get("chunksize"):
-            kwargs["npartitions"] = self.nb_processes or 1
-
-        def _mapper(func, iterable, *func_args, meta, **func_kwargs):
-            def _dask_df_mapper(func, iterable):
-                df = pd.DataFrame(iterable)
-                ddf = dd.from_pandas(df, **kwargs)
-                future = ddf.apply(func, meta=meta, axis=1).persist()
-                if not os.environ.get("NO_PROGRESS", False):
-                    progress(future)
-                # Put into a list because of the 'yield from' in _with_batches
-                return [future.compute()]
-
-            func = self.mappable_func(func, *func_args, **func_kwargs)
-            return self._with_batches(_dask_df_mapper, func, iterable, batch_size=batch_size)
-
-        return _mapper
-
-
-def init_parallel_factory(parallel_lib, *args, **kwargs):
-    """Return the desired instance of the parallel factory.
-
-    The main factories are:
-
-    * None: return a serial mapper (the standard :func:`map` function).
-    * multiprocessing: return a mapper using the standard :mod:`multiprocessing`.
-    * dask: return a mapper using the :class:`distributed.Client`.
-    * ipyparallel: return a mapper using the :mod:`ipyparallel` library.
- """ - parallel_factories = { - None: SerialFactory, - "multiprocessing": MultiprocessingFactory, - } - if dask_available: # pragma: no cover - parallel_factories["dask"] = DaskFactory - if dask_df_available: # pragma: no cover - parallel_factories["dask_dataframe"] = DaskDataFrameFactory - if ipyparallel_available: # pragma: no cover - parallel_factories["ipyparallel"] = IPyParallelFactory - - try: - parallel_factory = parallel_factories[parallel_lib](*args, **kwargs) - except KeyError: - L.critical( - "The %s factory is not available, maybe the required libraries are not properly " - "installed.", - parallel_lib, - ) - raise - L.info("Initialized %s factory", parallel_lib) - - return parallel_factory diff --git a/emodel_generalisation/tasks/utils.py b/emodel_generalisation/tasks/utils.py index 84f39b1..f8f7b50 100644 --- a/emodel_generalisation/tasks/utils.py +++ b/emodel_generalisation/tasks/utils.py @@ -23,10 +23,10 @@ from pathlib import Path import luigi +from bluepyparallel import init_parallel_factory from luigi_tools.target import OutputLocalTarget from emodel_generalisation.model.access_point import AccessPoint -from emodel_generalisation.parallel import init_parallel_factory class EmodelAPIConfig(luigi.Config): diff --git a/setup.py b/setup.py index c9e14ad..21bd3a2 100644 --- a/setup.py +++ b/setup.py @@ -43,12 +43,8 @@ "xgboost>=1.7.5,<2", "pyyaml>=6", "datareuse>=0.0.3", - "ipyparallel>=6.3,<7", - "dask[dataframe, distributed]>=2023.3.2", - "dask-mpi>=2022.4", - "sqlalchemy>=1.4.24", - "sqlalchemy-utils>=0.37.2", - "bluecellulab>=1.7.6", + "bluepyparallel>=0.2.1", + "bluecellulab>=1.7.6,<=2.3.1", "voxcell>=3.1.5", "efel>=5.5.5", ] diff --git a/tests/test_cli.py b/tests/test_cli.py index 951663f..c81609f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -23,7 +23,6 @@ def test_compute_currents(cli_runner, tmpdir): "--morphology-path", str(DATA / "morphologies"), "--protocol-config-path", str(DATA / "protocol_config.yaml"), "--hoc-path", str(DATA / "hoc"), - "--parallel-lib", None, ], ) # fmt: on @@ -53,7 +52,6 @@ def test_compute_currents(cli_runner, tmpdir): "--morphology-path", str(DATA / "morphologies"), "--protocol-config-path", str(DATA / "protocol_config.yaml"), "--hoc-path", str(DATA / "hoc"), - "--parallel-lib", None, "--only-rin", ], ) @@ -79,7 +77,6 @@ def test_evaluate(cli_runner, tmpdir): "--morphology-path", str(DATA / "morphologies"), "--config-path", str(DATA / "config"), "--final-path", str(DATA / "final.json"), - "--parallel-lib", None, "--evaluate-all", ], ) @@ -125,7 +122,6 @@ def test_adapt(cli_runner, tmpdir): "--final-path", str(DATA / "final.json"), "--local-dir", str(tmpdir / 'local'), "--output-hoc-path", str(tmpdir / "hoc"), - "--parallel-lib", None, "--min-scale", 0.9, "--max-scale", 1.1, ], @@ -153,7 +149,6 @@ def test_adapt(cli_runner, tmpdir): "--config-path", str(DATA / "config"), "--local-dir", str(tmpdir / 'local'), "--final-path", str(DATA / "final.json"), - "--parallel-lib", None, ], ) # fmt: on @@ -192,7 +187,6 @@ def test_adapt(cli_runner, tmpdir): "--morphology-path", str(DATA / "morphologies"), "--protocol-config-path", str(DATA / "protocol_config.yaml"), "--hoc-path", str(tmpdir / "hoc"), - "--parallel-lib", None, ], ) # fmt: on @@ -204,6 +198,7 @@ def test_adapt(cli_runner, tmpdir): [-72.841806, -71.32893], rtol=1e-5, ) + df.to_csv("test.csv") npt.assert_allclose( df["@dynamics:input_resistance"].to_list(), [105.342194, 1863.809101],