Skip to content

Commit

Permalink
Add context manager for cufile driver properties
Browse files Browse the repository at this point in the history
  • Loading branch information
kingcrimsontianyu committed Feb 26, 2025
1 parent 2d143d8 commit 0f681dd
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 72 deletions.
2 changes: 1 addition & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ Defaults

.. autofunction:: http_status_codes

.. autofunction:: kvikio.defaults.http_max_attempts
.. autofunction:: http_max_attempts

.. autofunction:: set
109 changes: 105 additions & 4 deletions python/kvikio/kvikio/cufile_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,114 @@
# See file LICENSE for terms.

import atexit
from typing import Tuple
from typing import Tuple, Any, overload

from kvikio._lib import cufile_driver # type: ignore
import kvikio.utils

# TODO: Wrap nicely, maybe as a dataclass?
# <https://github.com/rapidsai/kvikio/issues/526>
DriverProperties = cufile_driver.DriverProperties

properties = cufile_driver.DriverProperties()


class ConfigContextManager:
def __init__(self, config: dict[str, str]):
(
self._property_getters,
self._property_setters,
) = self._property_getter_and_setter()
self._old_properties = {}

for key, value in config.items():
self._old_properties[key] = self._get_property(key)
self._set_property(key, value)

def __enter__(self):
return None

def __exit__(self, type_unused, value, traceback_unused):
for key, value in self._old_properties.items():
self._set_property(key, value)

def _get_property(self, property: str) -> Any:
func = self._property_getters[property]

# getter signature: object.__get__(self, instance, owner=None)
return func(properties)

def _set_property(self, property: str, value: Any):
func = self._property_setters[property]

# setter signature: object.__set__(self, instance, value)
func(properties, value)

@kvikio.utils.call_once
def _property_getter_and_setter(self) -> tuple[dict[str, Any], dict[str, Any]]:
class_dict = vars(cufile_driver.DriverProperties)

property_getter_names = ["poll_mode",
"poll_thresh_size",
"max_device_cache_size",
"max_pinned_memory_size"]

property_getters = {}
property_setters = {}

for name in property_getter_names:
property_getters[name] = class_dict[name].__get__
property_setters[name] = class_dict[name].__set__
return property_getters, property_setters


@overload
def set(config: dict[str, Any], /) -> ConfigContextManager:
...


@overload
def set(key: str, value: Any, /) -> ConfigContextManager:
...


def set(*config) -> ConfigContextManager:
"""Set cuFile driver configurations.
Examples:
- To set one or more properties
.. code-block:: python
kvikio.cufile_driver.properties.set({"prop1": value1, "prop2": value2})
- To set a single property
.. code-block:: python
kvikio.cufile_driver.properties.set("prop", value)
Parameters
----------
config
The configurations. Can either be a single parameter (dict) consisting of one
or more properties, or two parameters key (string) and value (Any)
indicating a single property.
"""

err_msg = (
"Valid arguments are kvikio.cufile_driver.properties.set(config: dict) or "
"kvikio.cufile_driver.properties.set(key: str, value: Any)"
)

if len(config) == 1:
if not isinstance(config[0], dict):
raise ValueError(err_msg)
return ConfigContextManager(config[0])
elif len(config) == 2:
if not isinstance(config[0], str):
raise ValueError(err_msg)
return ConfigContextManager({config[0]: config[1]})
else:
raise ValueError(err_msg)


def libcufile_version() -> Tuple[int, int]:
Expand Down
106 changes: 41 additions & 65 deletions python/kvikio/kvikio/defaults.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,19 @@
# Copyright (c) 2021-2025, NVIDIA CORPORATION. All rights reserved.
# See file LICENSE for terms.

import re
import warnings
from typing import Any, Callable, overload

import kvikio._lib.defaults


def call_once(func: Callable):
"""Decorate a function such that it is only called once
Examples:
.. code-block:: python
@call_once
foo(args)
Parameters
----------
func: Callable
The function to be decorated.
"""
once_flag = True
cached_result = None

def wrapper(*args, **kwargs):
nonlocal once_flag
nonlocal cached_result
if once_flag:
once_flag = False
cached_result = func(*args, **kwargs)
return cached_result

return wrapper
import kvikio.utils


class ConfigContextManager:
def __init__(self, config: dict[str, str]):
(
self._all_getter_property_functions,
self._all_setter_property_functions,
) = self._all_property_functions()
self._property_getters,
self._property_setters,
) = self._property_getter_and_setter()
self._old_properties = {}

for key, value in config.items():
Expand All @@ -59,29 +30,34 @@ def __exit__(self, type_unused, value, traceback_unused):
def _get_property(self, property: str) -> Any:
if property == "num_threads":
property = "thread_pool_nthreads"
func = self._all_getter_property_functions[property]
func = self._property_getters[property]
return func()

def _set_property(self, property: str, value: Any):
if property == "num_threads":
property = "thread_pool_nthreads"
func = self._all_setter_property_functions[property]
func = self._property_setters[property]
func(value)

@call_once
def _all_property_functions(self) -> tuple[dict[str, Any], dict[str, Any]]:
getter_properties = {}
setter_properties = {}
# Among all attributes of the `kvikio._lib.defaults` module,
# get those whose name start with `set_`.
# Remove the `set_` prefix to obtain the property name.
module_dict = kvikio._lib.defaults.__dict__
for attr_name, attr_obj in module_dict.items():
if re.match("set_", attr_name):
property_name = re.sub("set_", "", attr_name)
getter_properties[property_name] = module_dict[property_name]
setter_properties[property_name] = attr_obj
return getter_properties, setter_properties
@kvikio.utils.call_once
def _property_getter_and_setter(self) -> tuple[dict[str, Any], dict[str, Any]]:
module_dict = vars(kvikio._lib.defaults)

property_getter_names = ["compat_mode",
"thread_pool_nthreads",
"task_size",
"gds_threshold",
"bounce_buffer_size",
"http_max_attempts",
"http_status_codes"]

property_getters = {}
property_setters = {}

for name in property_getter_names:
property_getters[name] = module_dict[name]
property_setters[name] = module_dict["set_" + name]
return property_getters, property_setters


@overload
Expand Down Expand Up @@ -265,19 +241,19 @@ def http_status_codes() -> list[int]:


def kvikio_deprecation_notice(msg: str):
def decorator_imp(func: Callable):
def decorator(func: Callable):
def wrapper(*args, **kwargs):
warnings.warn(msg, category=FutureWarning, stacklevel=2)
return func(*args, **kwargs)

return wrapper

return decorator_imp
return decorator


@kvikio_deprecation_notice('Use kvikio.defaults.set("compat_mode", value) instead')
def compat_mode_reset(compatmode: kvikio.CompatMode) -> None:
"""(deprecated) Reset the compatibility mode.
"""(Deprecated) Reset the compatibility mode.
Use this function to enable/disable compatibility mode explicitly.
Expand All @@ -293,13 +269,13 @@ def compat_mode_reset(compatmode: kvikio.CompatMode) -> None:

@kvikio_deprecation_notice('Use kvikio.defaults.set("compat_mode", value) instead')
def set_compat_mode(compatmode: kvikio.CompatMode):
"""(deprecated) Same with compat_mode_reset."""
"""(Deprecated) Same with compat_mode_reset."""
compat_mode_reset(compatmode)


@kvikio_deprecation_notice('Use kvikio.defaults.set("num_threads", value) instead')
def num_threads_reset(nthreads: int) -> None:
"""(deprecated) Reset the number of threads in the default thread pool.
"""(Deprecated) Reset the number of threads in the default thread pool.
Waits for all currently running tasks to be completed, then destroys all threads
in the pool and creates a new thread pool with the new number of threads. Any
Expand All @@ -319,13 +295,13 @@ def num_threads_reset(nthreads: int) -> None:

@kvikio_deprecation_notice('Use kvikio.defaults.set("num_threads", value) instead')
def set_num_threads(nthreads: int):
"""(deprecated) Same with num_threads_reset."""
"""(Deprecated) Same with num_threads_reset."""
set("num_threads", nthreads)


@kvikio_deprecation_notice('Use kvikio.defaults.set("task_size", value) instead')
def task_size_reset(nbytes: int) -> None:
"""(deprecated) Reset the default task size used for parallel IO operations.
"""(Deprecated) Reset the default task size used for parallel IO operations.
Parameters
----------
Expand All @@ -337,13 +313,13 @@ def task_size_reset(nbytes: int) -> None:

@kvikio_deprecation_notice('Use kvikio.defaults.set("task_size", value) instead')
def set_task_size(nbytes: int):
"""(deprecated) Same with task_size_reset."""
"""(Deprecated) Same with task_size_reset."""
set("task_size", nbytes)


@kvikio_deprecation_notice('Use kvikio.defaults.set("gds_threshold", value) instead')
def gds_threshold_reset(nbytes: int) -> None:
"""(deprecated) Reset the default GDS threshold, which is the minimum size to
"""(Deprecated) Reset the default GDS threshold, which is the minimum size to
use GDS.
Parameters
Expand All @@ -356,15 +332,15 @@ def gds_threshold_reset(nbytes: int) -> None:

@kvikio_deprecation_notice('Use kvikio.defaults.set("gds_threshold", value) instead')
def set_gds_threshold(nbytes: int):
"""(deprecated) Same with gds_threshold_reset."""
"""(Deprecated) Same with gds_threshold_reset."""
set("gds_threshold", nbytes)


@kvikio_deprecation_notice(
'Use kvikio.defaults.set("bounce_buffer_size", value) instead'
)
def bounce_buffer_size_reset(nbytes: int) -> None:
"""(deprecated) Reset the size of the bounce buffer used to stage data in host
"""(Deprecated) Reset the size of the bounce buffer used to stage data in host
memory.
Parameters
Expand All @@ -379,15 +355,15 @@ def bounce_buffer_size_reset(nbytes: int) -> None:
'Use kvikio.defaults.set("bounce_buffer_size", value) instead'
)
def set_bounce_buffer_size(nbytes: int):
"""(deprecated) Same with bounce_buffer_size_reset."""
"""(Deprecated) Same with bounce_buffer_size_reset."""
set("bounce_buffer_size", nbytes)


@kvikio_deprecation_notice(
'Use kvikio.defaults.set("http_max_attempts", value) instead'
)
def http_max_attempts_reset(attempts: int) -> None:
"""(deprecated) Reset the maximum number of attempts per remote IO read.
"""(Deprecated) Reset the maximum number of attempts per remote IO read.
Parameters
----------
Expand All @@ -401,15 +377,15 @@ def http_max_attempts_reset(attempts: int) -> None:
'Use kvikio.defaults.set("http_max_attempts", value) instead'
)
def set_http_max_attempts(attempts: int):
"""(deprecated) Same with http_max_attempts_reset."""
"""(Deprecated) Same with http_max_attempts_reset."""
set("http_max_attempts", attempts)


@kvikio_deprecation_notice(
'Use kvikio.defaults.set("http_status_codes", value) instead'
)
def http_status_codes_reset(status_codes: list[int]) -> None:
"""(deprecated) Reset the list of HTTP status codes to retry.
"""(Deprecated) Reset the list of HTTP status codes to retry.
Parameters
----------
Expand All @@ -423,5 +399,5 @@ def http_status_codes_reset(status_codes: list[int]) -> None:
'Use kvikio.defaults.set("http_status_codes", value) instead'
)
def set_http_status_codes(status_codes: list[int]):
"""(deprecated) Same with http_status_codes_reset."""
"""(Deprecated) Same with http_status_codes_reset."""
set("http_status_codes", status_codes)
Loading

0 comments on commit 0f681dd

Please sign in to comment.