From 0f681dd794efabceac68efec1f28646b71adc9d1 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 26 Feb 2025 11:00:31 -0500 Subject: [PATCH] Add context manager for cufile driver properties --- docs/source/api.rst | 2 +- python/kvikio/kvikio/cufile_driver.py | 109 +++++++++++++++++++++- python/kvikio/kvikio/defaults.py | 106 ++++++++------------- python/kvikio/kvikio/utils.py | 34 ++++++- python/kvikio/tests/test_cufile_driver.py | 30 ++++++ python/kvikio/tests/test_defaults.py | 12 +++ 6 files changed, 221 insertions(+), 72 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 2588e909ff..b7907fae9e 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -41,6 +41,6 @@ Defaults .. autofunction:: http_status_codes -.. autofunction:: kvikio.defaults.http_max_attempts +.. autofunction:: http_max_attempts .. autofunction:: set diff --git a/python/kvikio/kvikio/cufile_driver.py b/python/kvikio/kvikio/cufile_driver.py index fb32be347a..59b9e1168f 100644 --- a/python/kvikio/kvikio/cufile_driver.py +++ b/python/kvikio/kvikio/cufile_driver.py @@ -2,13 +2,114 @@ # See file LICENSE for terms. import atexit -from typing import Tuple +from typing import Tuple, Any, overload from kvikio._lib import cufile_driver # type: ignore +import kvikio.utils -# TODO: Wrap nicely, maybe as a dataclass? -# -DriverProperties = cufile_driver.DriverProperties + +properties = cufile_driver.DriverProperties() + + +class ConfigContextManager: + def __init__(self, config: dict[str, str]): + ( + self._property_getters, + self._property_setters, + ) = self._property_getter_and_setter() + self._old_properties = {} + + for key, value in config.items(): + self._old_properties[key] = self._get_property(key) + self._set_property(key, value) + + def __enter__(self): + return None + + def __exit__(self, type_unused, value, traceback_unused): + for key, value in self._old_properties.items(): + self._set_property(key, value) + + def _get_property(self, property: str) -> Any: + func = self._property_getters[property] + + # getter signature: object.__get__(self, instance, owner=None) + return func(properties) + + def _set_property(self, property: str, value: Any): + func = self._property_setters[property] + + # setter signature: object.__set__(self, instance, value) + func(properties, value) + + @kvikio.utils.call_once + def _property_getter_and_setter(self) -> tuple[dict[str, Any], dict[str, Any]]: + class_dict = vars(cufile_driver.DriverProperties) + + property_getter_names = ["poll_mode", + "poll_thresh_size", + "max_device_cache_size", + "max_pinned_memory_size"] + + property_getters = {} + property_setters = {} + + for name in property_getter_names: + property_getters[name] = class_dict[name].__get__ + property_setters[name] = class_dict[name].__set__ + return property_getters, property_setters + + +@overload +def set(config: dict[str, Any], /) -> ConfigContextManager: + ... + + +@overload +def set(key: str, value: Any, /) -> ConfigContextManager: + ... + + +def set(*config) -> ConfigContextManager: + """Set cuFile driver configurations. + + Examples: + + - To set one or more properties + + .. code-block:: python + + kvikio.cufile_driver.properties.set({"prop1": value1, "prop2": value2}) + + - To set a single property + + .. code-block:: python + + kvikio.cufile_driver.properties.set("prop", value) + + Parameters + ---------- + config + The configurations. Can either be a single parameter (dict) consisting of one + or more properties, or two parameters key (string) and value (Any) + indicating a single property. + """ + + err_msg = ( + "Valid arguments are kvikio.cufile_driver.properties.set(config: dict) or " + "kvikio.cufile_driver.properties.set(key: str, value: Any)" + ) + + if len(config) == 1: + if not isinstance(config[0], dict): + raise ValueError(err_msg) + return ConfigContextManager(config[0]) + elif len(config) == 2: + if not isinstance(config[0], str): + raise ValueError(err_msg) + return ConfigContextManager({config[0]: config[1]}) + else: + raise ValueError(err_msg) def libcufile_version() -> Tuple[int, int]: diff --git a/python/kvikio/kvikio/defaults.py b/python/kvikio/kvikio/defaults.py index 726d39ae17..68e49891b4 100644 --- a/python/kvikio/kvikio/defaults.py +++ b/python/kvikio/kvikio/defaults.py @@ -1,48 +1,19 @@ # Copyright (c) 2021-2025, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. -import re import warnings from typing import Any, Callable, overload import kvikio._lib.defaults - - -def call_once(func: Callable): - """Decorate a function such that it is only called once - - Examples: - - .. code-block:: python - - @call_once - foo(args) - - Parameters - ---------- - func: Callable - The function to be decorated. - """ - once_flag = True - cached_result = None - - def wrapper(*args, **kwargs): - nonlocal once_flag - nonlocal cached_result - if once_flag: - once_flag = False - cached_result = func(*args, **kwargs) - return cached_result - - return wrapper +import kvikio.utils class ConfigContextManager: def __init__(self, config: dict[str, str]): ( - self._all_getter_property_functions, - self._all_setter_property_functions, - ) = self._all_property_functions() + self._property_getters, + self._property_setters, + ) = self._property_getter_and_setter() self._old_properties = {} for key, value in config.items(): @@ -59,29 +30,34 @@ def __exit__(self, type_unused, value, traceback_unused): def _get_property(self, property: str) -> Any: if property == "num_threads": property = "thread_pool_nthreads" - func = self._all_getter_property_functions[property] + func = self._property_getters[property] return func() def _set_property(self, property: str, value: Any): if property == "num_threads": property = "thread_pool_nthreads" - func = self._all_setter_property_functions[property] + func = self._property_setters[property] func(value) - @call_once - def _all_property_functions(self) -> tuple[dict[str, Any], dict[str, Any]]: - getter_properties = {} - setter_properties = {} - # Among all attributes of the `kvikio._lib.defaults` module, - # get those whose name start with `set_`. - # Remove the `set_` prefix to obtain the property name. - module_dict = kvikio._lib.defaults.__dict__ - for attr_name, attr_obj in module_dict.items(): - if re.match("set_", attr_name): - property_name = re.sub("set_", "", attr_name) - getter_properties[property_name] = module_dict[property_name] - setter_properties[property_name] = attr_obj - return getter_properties, setter_properties + @kvikio.utils.call_once + def _property_getter_and_setter(self) -> tuple[dict[str, Any], dict[str, Any]]: + module_dict = vars(kvikio._lib.defaults) + + property_getter_names = ["compat_mode", + "thread_pool_nthreads", + "task_size", + "gds_threshold", + "bounce_buffer_size", + "http_max_attempts", + "http_status_codes"] + + property_getters = {} + property_setters = {} + + for name in property_getter_names: + property_getters[name] = module_dict[name] + property_setters[name] = module_dict["set_" + name] + return property_getters, property_setters @overload @@ -265,19 +241,19 @@ def http_status_codes() -> list[int]: def kvikio_deprecation_notice(msg: str): - def decorator_imp(func: Callable): + def decorator(func: Callable): def wrapper(*args, **kwargs): warnings.warn(msg, category=FutureWarning, stacklevel=2) return func(*args, **kwargs) return wrapper - return decorator_imp + return decorator @kvikio_deprecation_notice('Use kvikio.defaults.set("compat_mode", value) instead') def compat_mode_reset(compatmode: kvikio.CompatMode) -> None: - """(deprecated) Reset the compatibility mode. + """(Deprecated) Reset the compatibility mode. Use this function to enable/disable compatibility mode explicitly. @@ -293,13 +269,13 @@ def compat_mode_reset(compatmode: kvikio.CompatMode) -> None: @kvikio_deprecation_notice('Use kvikio.defaults.set("compat_mode", value) instead') def set_compat_mode(compatmode: kvikio.CompatMode): - """(deprecated) Same with compat_mode_reset.""" + """(Deprecated) Same with compat_mode_reset.""" compat_mode_reset(compatmode) @kvikio_deprecation_notice('Use kvikio.defaults.set("num_threads", value) instead') def num_threads_reset(nthreads: int) -> None: - """(deprecated) Reset the number of threads in the default thread pool. + """(Deprecated) Reset the number of threads in the default thread pool. Waits for all currently running tasks to be completed, then destroys all threads in the pool and creates a new thread pool with the new number of threads. Any @@ -319,13 +295,13 @@ def num_threads_reset(nthreads: int) -> None: @kvikio_deprecation_notice('Use kvikio.defaults.set("num_threads", value) instead') def set_num_threads(nthreads: int): - """(deprecated) Same with num_threads_reset.""" + """(Deprecated) Same with num_threads_reset.""" set("num_threads", nthreads) @kvikio_deprecation_notice('Use kvikio.defaults.set("task_size", value) instead') def task_size_reset(nbytes: int) -> None: - """(deprecated) Reset the default task size used for parallel IO operations. + """(Deprecated) Reset the default task size used for parallel IO operations. Parameters ---------- @@ -337,13 +313,13 @@ def task_size_reset(nbytes: int) -> None: @kvikio_deprecation_notice('Use kvikio.defaults.set("task_size", value) instead') def set_task_size(nbytes: int): - """(deprecated) Same with task_size_reset.""" + """(Deprecated) Same with task_size_reset.""" set("task_size", nbytes) @kvikio_deprecation_notice('Use kvikio.defaults.set("gds_threshold", value) instead') def gds_threshold_reset(nbytes: int) -> None: - """(deprecated) Reset the default GDS threshold, which is the minimum size to + """(Deprecated) Reset the default GDS threshold, which is the minimum size to use GDS. Parameters @@ -356,7 +332,7 @@ def gds_threshold_reset(nbytes: int) -> None: @kvikio_deprecation_notice('Use kvikio.defaults.set("gds_threshold", value) instead') def set_gds_threshold(nbytes: int): - """(deprecated) Same with gds_threshold_reset.""" + """(Deprecated) Same with gds_threshold_reset.""" set("gds_threshold", nbytes) @@ -364,7 +340,7 @@ def set_gds_threshold(nbytes: int): 'Use kvikio.defaults.set("bounce_buffer_size", value) instead' ) def bounce_buffer_size_reset(nbytes: int) -> None: - """(deprecated) Reset the size of the bounce buffer used to stage data in host + """(Deprecated) Reset the size of the bounce buffer used to stage data in host memory. Parameters @@ -379,7 +355,7 @@ def bounce_buffer_size_reset(nbytes: int) -> None: 'Use kvikio.defaults.set("bounce_buffer_size", value) instead' ) def set_bounce_buffer_size(nbytes: int): - """(deprecated) Same with bounce_buffer_size_reset.""" + """(Deprecated) Same with bounce_buffer_size_reset.""" set("bounce_buffer_size", nbytes) @@ -387,7 +363,7 @@ def set_bounce_buffer_size(nbytes: int): 'Use kvikio.defaults.set("http_max_attempts", value) instead' ) def http_max_attempts_reset(attempts: int) -> None: - """(deprecated) Reset the maximum number of attempts per remote IO read. + """(Deprecated) Reset the maximum number of attempts per remote IO read. Parameters ---------- @@ -401,7 +377,7 @@ def http_max_attempts_reset(attempts: int) -> None: 'Use kvikio.defaults.set("http_max_attempts", value) instead' ) def set_http_max_attempts(attempts: int): - """(deprecated) Same with http_max_attempts_reset.""" + """(Deprecated) Same with http_max_attempts_reset.""" set("http_max_attempts", attempts) @@ -409,7 +385,7 @@ def set_http_max_attempts(attempts: int): 'Use kvikio.defaults.set("http_status_codes", value) instead' ) def http_status_codes_reset(status_codes: list[int]) -> None: - """(deprecated) Reset the list of HTTP status codes to retry. + """(Deprecated) Reset the list of HTTP status codes to retry. Parameters ---------- @@ -423,5 +399,5 @@ def http_status_codes_reset(status_codes: list[int]) -> None: 'Use kvikio.defaults.set("http_status_codes", value) instead' ) def set_http_status_codes(status_codes: list[int]): - """(deprecated) Same with http_status_codes_reset.""" + """(Deprecated) Same with http_status_codes_reset.""" set("http_status_codes", status_codes) diff --git a/python/kvikio/kvikio/utils.py b/python/kvikio/kvikio/utils.py index fc88e321a5..6a03421d61 100644 --- a/python/kvikio/kvikio/utils.py +++ b/python/kvikio/kvikio/utils.py @@ -11,7 +11,7 @@ SimpleHTTPRequestHandler, ThreadingHTTPServer, ) -from typing import Any +from typing import Any, Callable class LocalHttpServer: @@ -79,7 +79,8 @@ def __enter__(self): else: handler = SimpleHTTPRequestHandler - handler_options = {**self.handler_options, **{"directory": self.root_path}} + handler_options = {**self.handler_options, + **{"directory": self.root_path}} self.process = multiprocessing.Process( target=LocalHttpServer._server, @@ -94,3 +95,32 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.process.kill() + + +def call_once(func: Callable): + """Decorate a function such that it is only called once + + Examples: + + .. code-block:: python + + @call_once + foo(args) + + Parameters + ---------- + func: Callable + The function to be decorated. + """ + once_flag = True + cached_result = None + + def wrapper(*args, **kwargs): + nonlocal once_flag + nonlocal cached_result + if once_flag: + once_flag = False + cached_result = func(*args, **kwargs) + return cached_result + + return wrapper diff --git a/python/kvikio/tests/test_cufile_driver.py b/python/kvikio/tests/test_cufile_driver.py index a1dc3a6454..4602a8a2ba 100644 --- a/python/kvikio/tests/test_cufile_driver.py +++ b/python/kvikio/tests/test_cufile_driver.py @@ -16,3 +16,33 @@ def test_version(): def test_open_and_close(): kvikio.cufile_driver.driver_open() kvikio.cufile_driver.driver_close() + + +def test_property_setter(): + """Test the method `set`""" + + # Attempt to set a nonexistent property + with pytest.raises(KeyError): + kvikio.cufile_driver.set("nonexistent_property", 123) + + # Nested context managers + poll_thresh_size_default = kvikio.cufile_driver.properties.poll_thresh_size + with kvikio.cufile_driver.set("poll_thresh_size", 1024): + assert kvikio.cufile_driver.properties.poll_thresh_size == 1024 + with kvikio.cufile_driver.set("poll_thresh_size", 2048): + assert kvikio.cufile_driver.properties.poll_thresh_size == 2048 + with kvikio.cufile_driver.set("poll_thresh_size", 4096): + assert kvikio.cufile_driver.properties.poll_thresh_size == 4096 + assert kvikio.cufile_driver.properties.poll_thresh_size == 2048 + assert kvikio.cufile_driver.properties.poll_thresh_size == 1024 + assert kvikio.cufile_driver.properties.poll_thresh_size == poll_thresh_size_default + + # Multiple context managers + poll_mode_default = kvikio.cufile_driver.properties.poll_mode + max_device_cache_size_default = kvikio.cufile_driver.properties.max_device_cache_size + with kvikio.cufile_driver.set({"poll_mode": True, "max_device_cache_size": 2048}): + assert kvikio.cufile_driver.properties.poll_mode and\ + (kvikio.cufile_driver.properties.max_device_cache_size == 2048) + assert (kvikio.cufile_driver.properties.poll_mode == poll_mode_default) and\ + (kvikio.cufile_driver.properties.max_device_cache_size == + max_device_cache_size_default) diff --git a/python/kvikio/tests/test_defaults.py b/python/kvikio/tests/test_defaults.py index 7b6854a6c1..098d2657e5 100644 --- a/python/kvikio/tests/test_defaults.py +++ b/python/kvikio/tests/test_defaults.py @@ -31,6 +31,18 @@ def test_property_setter(): assert kvikio.defaults.task_size() == 1024 assert kvikio.defaults.task_size() == task_size_default + # Multiple context managers + task_size_default = kvikio.defaults.task_size() + num_threads_default = kvikio.defaults.num_threads() + bounce_buffer_size_default = kvikio.defaults.bounce_buffer_size() + with kvikio.defaults.set({"task_size": 1024, "num_threads": 16, "bounce_buffer_size": 1024}): + assert (kvikio.defaults.task_size() == 1024) and\ + (kvikio.defaults.num_threads() == 16) and\ + (kvikio.defaults.bounce_buffer_size() == 1024) + assert (kvikio.defaults.task_size() == task_size_default) and\ + (kvikio.defaults.num_threads() == num_threads_default) and\ + (kvikio.defaults.bounce_buffer_size() == bounce_buffer_size_default) + @pytest.mark.skipif( kvikio.defaults.compat_mode() == kvikio.CompatMode.ON,