diff --git a/ci/run_cudf_polars_pytests.sh b/ci/run_cudf_polars_pytests.sh index bf5a3ccee8e..e881055e9e3 100755 --- a/ci/run_cudf_polars_pytests.sh +++ b/ci/run_cudf_polars_pytests.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. set -euo pipefail @@ -13,3 +13,9 @@ python -m pytest --cache-clear "$@" tests # Test the "dask-experimental" executor python -m pytest --cache-clear "$@" tests --executor dask-experimental + +# Test the "dask-experimental" executor with Distributed cluster +# Not all tests pass yet, deselecting by name those that are failing. +python -m pytest --cache-clear "$@" tests --executor dask-experimental --dask-cluster \ + -k "not test_groupby_maintain_order_random and not test_scan_csv_multi and not test_select_literal_series" \ + --cov-fail-under=89 # Override coverage, Distributed cluster coverage not yet 100% diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index 16290fdb663..e81866e68e4 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -7,7 +7,7 @@ import itertools import operator from functools import reduce -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar import cudf_polars.experimental.io import cudf_polars.experimental.join @@ -24,10 +24,38 @@ if TYPE_CHECKING: from collections.abc import MutableMapping + from distributed import Client + from cudf_polars.containers import DataFrame from cudf_polars.experimental.dispatch import LowerIRTransformer +class SerializerManager: + """Manager to ensure ensure serializer is only registered once.""" + + _serializer_registered: bool = False + _client_run_executed: ClassVar[set[str]] = set() + + @classmethod + def register_serialize(cls) -> None: + """Register Dask/cudf-polars serializers in calling process.""" + if not cls._serializer_registered: + from cudf_polars.experimental.dask_serialize import register + + register() + cls._serializer_registered = True + + @classmethod + def run_on_cluster(cls, client: Client) -> None: + """Run serializer registration on the workers and scheduler.""" + if ( + client.id not in cls._client_run_executed + ): # pragma: no cover; Only executes with Distributed scheduler + client.run(cls.register_serialize) + client.run_on_scheduler(cls.register_serialize) + cls._client_run_executed.add(client.id) + + @lower_ir_node.register(IR) def _(ir: IR, rec: LowerIRTransformer) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: # Default logic - Requires single partition @@ -127,12 +155,32 @@ def task_graph( return graph, (key_name, 0) +def get_client(): + """Get appropriate Dask client or scheduler.""" + SerializerManager.register_serialize() + + try: # pragma: no cover; block depends on executor type and Distributed cluster + from distributed import get_client + + client = get_client() + SerializerManager.run_on_cluster(client) + except ( + ImportError, + ValueError, + ): # pragma: no cover; block depends on Dask local scheduler + from dask import get + + return get + else: # pragma: no cover; block depends on executor type and Distributed cluster + return client.get + + def evaluate_dask(ir: IR) -> DataFrame: """Evaluate an IR graph with Dask.""" - from dask import get - ir, partition_info = lower_ir_graph(ir) + get = get_client() + graph, key = task_graph(ir, partition_info) return get(graph, key) diff --git a/python/cudf_polars/tests/conftest.py b/python/cudf_polars/tests/conftest.py index 6338bf0cae1..dbd0989a8b2 100644 --- a/python/cudf_polars/tests/conftest.py +++ b/python/cudf_polars/tests/conftest.py @@ -1,9 +1,11 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import pytest +DISTRIBUTED_CLUSTER_KEY = pytest.StashKey[dict]() + @pytest.fixture(params=[False, True], ids=["no_nulls", "nulls"], scope="session") def with_nulls(request): @@ -19,8 +21,50 @@ def pytest_addoption(parser): help="Executor to use for GPUEngine.", ) + parser.addoption( + "--dask-cluster", + action="store_true", + help="Executor to use for GPUEngine.", + ) + def pytest_configure(config): import cudf_polars.testing.asserts + if ( + config.getoption("--dask-cluster") + and config.getoption("--executor") != "dask-experimental" + ): + raise pytest.UsageError( + "--dask-cluster requires --executor='dask-experimental'" + ) + cudf_polars.testing.asserts.Executor = config.getoption("--executor") + + +def pytest_sessionstart(session): + if ( + session.config.getoption("--dask-cluster") + and session.config.getoption("--executor") == "dask-experimental" + ): + from dask import config + from dask.distributed import Client, LocalCluster + + # Avoid "Sending large graph of size ..." warnings + # (We expect these for tests using literal/random arrays) + config.set({"distributed.admin.large-graph-warning-threshold": "20MB"}) + + cluster = LocalCluster() + client = Client(cluster) + session.stash[DISTRIBUTED_CLUSTER_KEY] = {"cluster": cluster, "client": client} + + +def pytest_sessionfinish(session): + if DISTRIBUTED_CLUSTER_KEY in session.stash: + cluster_info = session.stash[DISTRIBUTED_CLUSTER_KEY] + client = cluster_info.get("client") + cluster = cluster_info.get("cluster") + if client is not None: + client.shutdown() + if cluster is not None: + cluster.close()