diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml index 8471e29a1e..43dd9a4764 100644 --- a/conda/environments/all_cuda-118_arch-aarch64.yaml +++ b/conda/environments/all_cuda-118_arch-aarch64.yaml @@ -29,6 +29,7 @@ dependencies: - pre-commit - pytest - pytest-cov +- pytest-timeout - python>=3.10,<3.13 - rangehttpserver - rapids-build-backend>=0.3.0,<0.4.0.dev0 diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index a42d41f746..5d42a932a3 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -31,6 +31,7 @@ dependencies: - pre-commit - pytest - pytest-cov +- pytest-timeout - python>=3.10,<3.13 - rangehttpserver - rapids-build-backend>=0.3.0,<0.4.0.dev0 diff --git a/conda/environments/all_cuda-128_arch-aarch64.yaml b/conda/environments/all_cuda-128_arch-aarch64.yaml index 4f184129bc..16a54ffc36 100644 --- a/conda/environments/all_cuda-128_arch-aarch64.yaml +++ b/conda/environments/all_cuda-128_arch-aarch64.yaml @@ -29,6 +29,7 @@ dependencies: - pre-commit - pytest - pytest-cov +- pytest-timeout - python>=3.10,<3.13 - rangehttpserver - rapids-build-backend>=0.3.0,<0.4.0.dev0 diff --git a/conda/environments/all_cuda-128_arch-x86_64.yaml b/conda/environments/all_cuda-128_arch-x86_64.yaml index d3df1235b6..df936afae7 100644 --- a/conda/environments/all_cuda-128_arch-x86_64.yaml +++ b/conda/environments/all_cuda-128_arch-x86_64.yaml @@ -29,6 +29,7 @@ dependencies: - pre-commit - pytest - pytest-cov +- pytest-timeout - python>=3.10,<3.13 - rangehttpserver - rapids-build-backend>=0.3.0,<0.4.0.dev0 diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index ebc6b51f10..69c7891403 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -139,6 +139,7 @@ set(SOURCES "src/bounce_buffer.cpp" "src/buffer.cpp" "src/compat_mode.cpp" + "src/http_status_codes.cpp" "src/cufile/config.cpp" 
"src/cufile/driver.cpp" "src/defaults.cpp" diff --git a/cpp/doxygen/main_page.md b/cpp/doxygen/main_page.md index a5e9e9162d..aa1175bced 100644 --- a/cpp/doxygen/main_page.md +++ b/cpp/doxygen/main_page.md @@ -107,10 +107,30 @@ To improve performance of small IO requests, `.pread()` and `.pwrite()` implemen This setting can also be controlled by `defaults::gds_threshold()` and `defaults::gds_threshold_reset()`. #### Size of the Bounce Buffer (KVIKIO_GDS_THRESHOLD) -KvikIO might have to use intermediate host buffers (one per thread) when copying between files and device memory. Set the environment variable ``KVIKIO_BOUNCE_BUFFER_SIZE`` to the size (in bytes) of these "bounce" buffers. If not set, the default value is 16777216 (16 MiB). +KvikIO might have to use intermediate host buffers (one per thread) when copying between files and device memory. Set the environment variable `KVIKIO_BOUNCE_BUFFER_SIZE` to the size (in bytes) of these "bounce" buffers. If not set, the default value is 16777216 (16 MiB). This setting can also be controlled by `defaults::bounce_buffer_size()` and `defaults::bounce_buffer_size_reset()`. +#### HTTP Retries + +The behavior when a remote IO read returns an error can be controlled through the `KVIKIO_HTTP_STATUS_CODES` and `KVIKIO_HTTP_MAX_ATTEMPTS` environment variables. +`KVIKIO_HTTP_STATUS_CODES` controls the status codes to retry, and `KVIKIO_HTTP_MAX_ATTEMPTS` controls the maximum number of attempts to make before throwing an exception. + +When a response with a status code in the list of retryable codes is received, KvikIO will wait for some period of time before retrying the request. +It will keep retrying until reaching the maximum number of attempts. + +By default, KvikIO will retry responses with the following status codes: + +- 429 +- 500 +- 502 +- 503 +- 504 + +KvikIO will, by default, make three attempts per read. 
+Note that if you're reading a large file that has been split into multiple reads through KvikIO's task size setting, then *each* task will be retried up to the maximum number of attempts. + +These settings can also be controlled by `defaults::http_max_attempts()`, `defaults::http_max_attempts_reset()`, `defaults::http_status_codes()`, and `defaults::http_status_codes_reset()`. ## Example diff --git a/cpp/include/kvikio/defaults.hpp b/cpp/include/kvikio/defaults.hpp index 999de3fb90..4334549d23 100644 --- a/cpp/include/kvikio/defaults.hpp +++ b/cpp/include/kvikio/defaults.hpp @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -53,6 +54,9 @@ bool getenv_or(std::string_view env_var_name, bool default_val); template <> CompatMode getenv_or(std::string_view env_var_name, CompatMode default_val); +template <> +std::vector getenv_or(std::string_view env_var_name, std::vector default_val); + /** * @brief Singleton class of default values used throughout KvikIO. * @@ -64,6 +68,8 @@ class defaults { std::size_t _task_size; std::size_t _gds_threshold; std::size_t _bounce_buffer_size; + std::size_t _http_max_attempts; + std::vector _http_status_codes; static unsigned int get_num_threads_from_env(); @@ -153,7 +159,7 @@ class defaults { * always use the same thread pool however it is possible to change number of * threads in the pool (see `kvikio::default::thread_pool_nthreads_reset()`). * - * @return The the default thread pool instance. + * @return The default thread pool instance. */ [[nodiscard]] static BS_thread_pool& thread_pool(); @@ -230,6 +236,47 @@ class defaults { * @param nbytes The bounce buffer size in bytes. */ static void bounce_buffer_size_reset(std::size_t nbytes); + + /** + * @brief Get the maximum number of attempts per remote IO read. + * + * Set the value using `kvikio::default::http_max_attempts_reset()` or by setting + * the `KVIKIO_HTTP_MAX_ATTEMPTS` environment variable. If not set, the value is 3. 
+ * + * @return The maximum number of remote IO reads to attempt before raising an + * error. + */ + [[nodiscard]] static std::size_t http_max_attempts(); + + /** + * @brief Reset the maximum number of attempts per remote IO read. + * + * @param attempts The maximum number of attempts to try before raising an error. + */ + static void http_max_attempts_reset(std::size_t attempts); + + /** + * @brief The list of HTTP status codes to retry. + * + * Set the value using `kvikio::default::http_status_codes_reset()` or by setting the + * `KVIKIO_HTTP_STATUS_CODES` environment variable. If not set, the default value is + * + * - 429 + * - 500 + * - 502 + * - 503 + * - 504 + * + * @return The list of HTTP status codes to retry. + */ + [[nodiscard]] static std::vector const& http_status_codes(); + + /** + * @brief Reset the list of HTTP status codes to retry. + * + * @param status_codes The HTTP status codes to retry. + */ + static void http_status_codes_reset(std::vector status_codes); }; } // namespace kvikio diff --git a/cpp/include/kvikio/http_status_codes.hpp b/cpp/include/kvikio/http_status_codes.hpp new file mode 100644 index 0000000000..98ffb52324 --- /dev/null +++ b/cpp/include/kvikio/http_status_codes.hpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +namespace kvikio { +namespace detail { +/** + * @brief Parse a comma-separated string of HTTP status codes. + * + * @param env_var_name The environment variable holding the string. + * Used to report errors. + * @param status_codes The comma-separated string of HTTP status + * codes. Each code should be a 3-digit integer. + * + * @return The vector with the parsed, integer HTTP status codes. + */ +std::vector parse_http_status_codes(std::string_view env_var_name, + std::string const& status_codes); +} // namespace detail + +} // namespace kvikio diff --git a/cpp/src/defaults.cpp b/cpp/src/defaults.cpp index fccc89e64d..e0a908cf4d 100644 --- a/cpp/src/defaults.cpp +++ b/cpp/src/defaults.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -24,10 +25,10 @@ #include #include +#include #include namespace kvikio { - template <> bool getenv_or(std::string_view env_var_name, bool default_val) { @@ -68,6 +69,17 @@ CompatMode getenv_or(std::string_view env_var_name, CompatMode default_val) return detail::parse_compat_mode_str(env_val); } +template <> +std::vector getenv_or(std::string_view env_var_name, std::vector default_val) +{ + auto* const env_val = std::getenv(env_var_name.data()); + if (env_val == nullptr) { return std::move(default_val); } + std::string const int_str(env_val); + if (int_str.empty()) { return std::move(default_val); } + + return detail::parse_http_status_codes(env_var_name, int_str); +} + unsigned int defaults::get_num_threads_from_env() { int const ret = getenv_or("KVIKIO_NTHREADS", 1); @@ -109,6 +121,19 @@ defaults::defaults() } _bounce_buffer_size = env; } + // Determine the default value of `http_max_attempts` + { + ssize_t const env = getenv_or("KVIKIO_HTTP_MAX_ATTEMPTS", 3); + if (env <= 0) { + throw std::invalid_argument("KVIKIO_HTTP_MAX_ATTEMPTS has to be a positive integer"); + } + _http_max_attempts = env; + } + // Determine the default value of 
`http_status_codes` + { + _http_status_codes = + getenv_or("KVIKIO_HTTP_STATUS_CODES", std::vector{429, 500, 502, 503, 504}); + } } defaults* defaults::instance() @@ -177,4 +202,19 @@ void defaults::bounce_buffer_size_reset(std::size_t nbytes) instance()->_bounce_buffer_size = nbytes; } +std::size_t defaults::http_max_attempts() { return instance()->_http_max_attempts; } + +void defaults::http_max_attempts_reset(std::size_t attempts) +{ + if (attempts == 0) { throw std::invalid_argument("attempts must be a positive integer"); } + instance()->_http_max_attempts = attempts; +} + +std::vector const& defaults::http_status_codes() { return instance()->_http_status_codes; } + +void defaults::http_status_codes_reset(std::vector status_codes) +{ + instance()->_http_status_codes = std::move(status_codes); +} + } // namespace kvikio diff --git a/cpp/src/http_status_codes.cpp b/cpp/src/http_status_codes.cpp new file mode 100644 index 0000000000..108df81265 --- /dev/null +++ b/cpp/src/http_status_codes.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +namespace kvikio { + +namespace detail { +std::vector parse_http_status_codes(std::string_view env_var_name, + std::string const& status_codes) +{ + // Ensure `status_codes` consists only of 3-digit integers separated by commas, allowing spaces. 
+ std::regex const check_pattern(R"(^\s*\d{3}\s*(\s*,\s*\d{3}\s*)*$)"); + if (!std::regex_match(status_codes, check_pattern)) { + throw std::invalid_argument(std::string{env_var_name} + + ": invalid format, expected comma-separated integers."); + } + + // Match every integer in `status_codes`. + std::regex const number_pattern(R"(\d+)"); + + // For each match, we push_back `std::stoi(match.str())` into `ret`. + std::vector ret; + std::transform(std::sregex_iterator(status_codes.begin(), status_codes.end(), number_pattern), + std::sregex_iterator(), + std::back_inserter(ret), + [](std::smatch const& match) -> int { return std::stoi(match.str()); }); + return ret; +} + +} // namespace detail + +} // namespace kvikio diff --git a/cpp/src/shim/libcurl.cpp b/cpp/src/shim/libcurl.cpp index 655a7f70fc..05b6e02d10 100644 --- a/cpp/src/shim/libcurl.cpp +++ b/cpp/src/shim/libcurl.cpp @@ -14,12 +14,15 @@ * limitations under the License. */ +#include #include #include +#include #include #include #include #include +#include #include #include @@ -116,19 +119,62 @@ CURL* CurlHandle::handle() noexcept { return _handle.get(); } void CurlHandle::perform() { - // Perform the curl operation and check for errors. - CURLcode err = curl_easy_perform(handle()); - if (err != CURLE_OK) { - std::string msg(_errbuf); // We can do this because we always initialize `_errbuf` as empty. 
- std::stringstream ss; - ss << "curl_easy_perform() error near " << _source_file << ":" << _source_line; - if (msg.empty()) { - ss << "(" << curl_easy_strerror(err) << ")"; + long http_code = 0; + auto attempt_count = 0; + auto base_delay = 500; // milliseconds + auto max_delay = 4000; // milliseconds + auto http_max_attempts = kvikio::defaults::http_max_attempts(); + auto& http_status_codes = kvikio::defaults::http_status_codes(); + + while (attempt_count++ < http_max_attempts) { + auto err = curl_easy_perform(handle()); + + if (err == CURLE_OK) { + // We set CURLE_HTTP_RETURNED_ERROR, so >= 400 status codes are considered + // errors, so anything less than this is considered a success and we're + // done. + return; + } + // We had an error. Is it retryable? + curl_easy_getinfo(handle(), CURLINFO_RESPONSE_CODE, &http_code); + auto const is_retryable_response = + (std::find(http_status_codes.begin(), http_status_codes.end(), http_code) != + http_status_codes.end()); + + if (is_retryable_response) { + // backoff and retry again. With a base value of 500ms, we retry after + // 500ms, 1s, 2s, 4s, ... + auto const backoff_delay = base_delay * (1 << std::min(attempt_count - 1, 4)); + // up to a maximum of `max_delay` milliseconds. + auto const delay = std::min(max_delay, backoff_delay); + + // Only print this message out and sleep if we're actually going to retry again. + if (attempt_count < http_max_attempts) { + std::cout << "KvikIO: Got HTTP code " << http_code << ". Retrying after " << delay + << "ms (attempt " << attempt_count << " of " << http_max_attempts << ")." + << std::endl; + std::this_thread::sleep_for(std::chrono::milliseconds(delay)); + } } else { - ss << "(" << msg << ")"; + // We had some kind of fatal error, or we got some status code we don't retry. + // We want to exit immediately. + std::string msg(_errbuf); // We can do this because we always initialize `_errbuf` as empty. 
+ std::stringstream ss; + ss << "curl_easy_perform() error near " << _source_file << ":" << _source_line; + if (msg.empty()) { + ss << "(" << curl_easy_strerror(err) << ")"; + } else { + ss << "(" << msg << ")"; + } + throw std::runtime_error(ss.str()); } - throw std::runtime_error(ss.str()); } -} + // We've exceeded the maximum number of requests. Fail with a good error + // message. + std::stringstream ss; + ss << "KvikIO: HTTP request reached maximum number of attempts (" << http_max_attempts + << "). Got HTTP code " << http_code << "."; + throw std::runtime_error(ss.str()); +} } // namespace kvikio diff --git a/cpp/tests/test_defaults.cpp b/cpp/tests/test_defaults.cpp index c95e9d1d11..99a597032d 100644 --- a/cpp/tests/test_defaults.cpp +++ b/cpp/tests/test_defaults.cpp @@ -51,3 +51,24 @@ TEST(Defaults, parse_compat_mode_str) } } } + +TEST(Defaults, parse_http_status_codes) +{ + { + std::vector inputs{ + "429,500", "429, 500", " 429,500", "429, 500", "429 ,500", "429,500 "}; + std::vector expected = {429, 500}; + for (const auto& input : inputs) { + EXPECT_EQ(kvikio::detail::parse_http_status_codes("KVIKIO_HTTP_STATUS_CODES", input), + expected); + } + } + + { + std::vector inputs{"429,", ",429", "a,b", "429,,500", "429,1000"}; + for (const auto& input : inputs) { + EXPECT_THROW(kvikio::detail::parse_http_status_codes("KVIKIO_HTTP_STATUS_CODES", input), + std::invalid_argument); + } + } +} diff --git a/dependencies.yaml b/dependencies.yaml index e951fe78cd..3da0c3fdc2 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -421,6 +421,7 @@ dependencies: - rapids-dask-dependency==25.4.*,>=0.0.0a0 - pytest - pytest-cov + - pytest-timeout - rangehttpserver - boto3>=1.21.21 - output_types: [requirements, pyproject] diff --git a/docs/source/runtime_settings.rst b/docs/source/runtime_settings.rst index 0ce1ab7972..5847c1ffbe 100644 --- a/docs/source/runtime_settings.rst +++ b/docs/source/runtime_settings.rst @@ -42,3 +42,12 @@ Size of the Bounce Buffer 
``KVIKIO_BOUNCE_BUFFER_SIZE`` KvikIO might have to use intermediate host buffers (one per thread) when copying between files and device memory. Set the environment variable ``KVIKIO_BOUNCE_BUFFER_SIZE`` to the size (in bytes) of these "bounce" buffers. If not set, the default value is 16777216 (16 MiB). This setting can also be controlled by :py:func:`kvikio.defaults.bounce_buffer_size`, :py:func:`kvikio.defaults.bounce_buffer_size_reset`, and :py:func:`kvikio.defaults.set_bounce_buffer_size`. + +HTTP Retries +------------ + +The behavior when a remote IO read returns an error can be controlled through the ``KVIKIO_HTTP_STATUS_CODES`` and ``KVIKIO_HTTP_MAX_ATTEMPTS`` environment variables. + +``KVIKIO_HTTP_STATUS_CODES`` controls the status codes to retry and can be controlled by :py:func:`kvikio.defaults.http_status_codes`, :py:func:`kvikio.defaults.http_status_codes_reset`, and :py:func:`kvikio.defaults.set_http_status_codes`. + +``KVIKIO_HTTP_MAX_ATTEMPTS`` controls the maximum number of attempts to make before throwing an exception and can be controlled by :py:func:`kvikio.defaults.http_max_attempts`, :py:func:`kvikio.defaults.http_max_attempts_reset`, and :py:func:`kvikio.defaults.set_http_max_attempts`. diff --git a/python/kvikio/kvikio/_lib/defaults.pyx b/python/kvikio/kvikio/_lib/defaults.pyx index 9042069b74..0770cb557a 100644 --- a/python/kvikio/kvikio/_lib/defaults.pyx +++ b/python/kvikio/kvikio/_lib/defaults.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. 
# distutils: language = c++ @@ -6,6 +6,7 @@ from libc.stdint cimport uint8_t from libcpp cimport bool +from libcpp.vector cimport vector cdef extern from "" namespace "kvikio" nogil: @@ -28,6 +29,12 @@ cdef extern from "" namespace "kvikio" nogil: size_t cpp_bounce_buffer_size "kvikio::defaults::bounce_buffer_size"() except + void cpp_bounce_buffer_size_reset \ "kvikio::defaults::bounce_buffer_size_reset"(size_t nbytes) except + + size_t cpp_http_max_attempts "kvikio::defaults::http_max_attempts"() except + + void cpp_http_max_attempts_reset \ + "kvikio::defaults::http_max_attempts_reset"(size_t attempts) except + + vector[int] cpp_http_status_codes "kvikio::defaults::http_status_codes"() except + + void cpp_http_status_codes_reset \ + "kvikio::defaults::http_status_codes_reset"(vector[int] status_codes) except + def compat_mode() -> CompatMode: @@ -68,3 +75,19 @@ def bounce_buffer_size() -> int: def bounce_buffer_size_reset(nbytes: int) -> None: cpp_bounce_buffer_size_reset(nbytes) + + +def http_max_attempts() -> int: + return cpp_http_max_attempts() + + +def http_max_attempts_reset(attempts: int) -> None: + cpp_http_max_attempts_reset(attempts) + + +def http_status_codes() -> list[int]: + return cpp_http_status_codes() + + +def http_status_codes_reset(status_codes: list[int]) -> None: + return cpp_http_status_codes_reset(status_codes) diff --git a/python/kvikio/kvikio/defaults.py b/python/kvikio/kvikio/defaults.py index 9e959c1f74..4201cc29a3 100644 --- a/python/kvikio/kvikio/defaults.py +++ b/python/kvikio/kvikio/defaults.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021-2025, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. @@ -214,8 +214,8 @@ def bounce_buffer_size() -> int: `KVIKIO_BOUNCE_BUFFER_SIZE` environment variable. If not set, the value is 16 MiB. - Return - ------ + Returns + ------- nbytes : int The bounce buffer size in bytes. 
""" @@ -235,7 +235,7 @@ def bounce_buffer_size_reset(nbytes: int) -> None: @contextlib.contextmanager def set_bounce_buffer_size(nbytes: int): - """Context for resetting the the size of the bounce buffer. + """Context for resetting the size of the bounce buffer. Parameters ---------- @@ -248,3 +248,99 @@ def set_bounce_buffer_size(nbytes: int): yield finally: bounce_buffer_size_reset(old_value) + + +def http_max_attempts() -> int: + """Get the maximum number of attempts per remote IO read. + + Reads are retried up until ``http_max_attempts`` when the response has certain + HTTP status codes. + + Set the value using `http_max_attempts_reset()` or by setting the + ``KVIKIO_HTTP_MAX_ATTEMPTS`` environment variable. If not set, the + value is 3. + + Returns + ------- + max_attempts : int + The maximum number of remote IO reads to attempt before raising an + error. + """ + return kvikio._lib.defaults.http_max_attempts() + + +def http_max_attempts_reset(attempts: int) -> None: + """Reset the maximum number of attempts per remote IO read. + + Parameters + ---------- + attempts : int + The maximum number of attempts to try before raising an error. + """ + kvikio._lib.defaults.http_max_attempts_reset(attempts) + + +@contextlib.contextmanager +def set_http_max_attempts(attempts: int): + """Context for resetting the maximum number of HTTP attempts. + + Parameters + ---------- + attempts : int + The maximum number of attempts to try before raising an error. + """ + old_value = http_max_attempts() + try: + http_max_attempts_reset(attempts) + yield + finally: + http_max_attempts_reset(old_value) + + +def http_status_codes() -> list[int]: + """Get the list of HTTP status codes to retry. + + Set the value using ``set_http_status_codes`` or by setting the + ``KVIKIO_HTTP_STATUS_CODES`` environment variable. If not set, the + default value is + + - 429 + - 500 + - 502 + - 503 + - 504 + + Returns + ------- + status_codes : list[int] + The HTTP status codes to retry. 
+ """ + return kvikio._lib.defaults.http_status_codes() + + +def http_status_codes_reset(status_codes: list[int]) -> None: + """Reset the list of HTTP status codes to retry. + + Parameters + ---------- + status_codes : list[int] + The HTTP status codes to retry. + """ + kvikio._lib.defaults.http_status_codes_reset(status_codes) + + +@contextlib.contextmanager +def set_http_status_codes(status_codes: list[int]): + """Context for resetting the HTTP status codes to retry. + + Parameters + ---------- + status_codes : list[int] + The HTTP status codes to retry. + """ + old_value = http_status_codes() + try: + http_status_codes_reset(status_codes) + yield + finally: + http_status_codes_reset(old_value) diff --git a/python/kvikio/kvikio/utils.py b/python/kvikio/kvikio/utils.py index 09a9f2062a..fc88e321a5 100644 --- a/python/kvikio/kvikio/utils.py +++ b/python/kvikio/kvikio/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. 
import functools @@ -6,7 +6,12 @@ import pathlib import threading import time -from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer +from http.server import ( + BaseHTTPRequestHandler, + SimpleHTTPRequestHandler, + ThreadingHTTPServer, +) +from typing import Any class LocalHttpServer: @@ -15,18 +20,12 @@ class LocalHttpServer: @staticmethod def _server( queue: multiprocessing.Queue, - root_path: str, - range_support: bool, + handler: type[BaseHTTPRequestHandler], + handler_options: dict[str, Any], max_lifetime: int, ): - if range_support: - from RangeHTTPServer import RangeRequestHandler - - handler = RangeRequestHandler - else: - handler = SimpleHTTPRequestHandler httpd = ThreadingHTTPServer( - ("127.0.0.1", 0), functools.partial(handler, directory=root_path) + ("127.0.0.1", 0), functools.partial(handler, **handler_options) ) thread = threading.Thread(target=httpd.serve_forever) thread.start() @@ -41,6 +40,8 @@ def __init__( root_path: str | pathlib.Path, range_support: bool = True, max_lifetime: int = 120, + handler: type[BaseHTTPRequestHandler] | None = None, + handler_options: dict[str, Any] | None = None, ) -> None: """Create a context that starts a local http server. 
@@ -63,12 +64,26 @@ def __init__( self.root_path = root_path self.range_support = range_support self.max_lifetime = max_lifetime + self.handler = handler + self.handler_options = handler_options or {} def __enter__(self): queue = multiprocessing.Queue() + + if self.handler is not None: + handler = self.handler + elif self.range_support: + from RangeHTTPServer import RangeRequestHandler + + handler = RangeRequestHandler + else: + handler = SimpleHTTPRequestHandler + + handler_options = {**self.handler_options, **{"directory": self.root_path}} + self.process = multiprocessing.Process( target=LocalHttpServer._server, - args=(queue, str(self.root_path), self.range_support, self.max_lifetime), + args=(queue, handler, handler_options, self.max_lifetime), ) self.process.start() ip, port = queue.get() diff --git a/python/kvikio/pyproject.toml b/python/kvikio/pyproject.toml index 53963b4ba5..5548e579e3 100644 --- a/python/kvikio/pyproject.toml +++ b/python/kvikio/pyproject.toml @@ -45,6 +45,7 @@ test = [ "moto[server]>=4.0.8", "pytest", "pytest-cov", + "pytest-timeout", "rangehttpserver", "rapids-dask-dependency==25.4.*,>=0.0.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. diff --git a/python/kvikio/tests/test_benchmarks.py b/python/kvikio/tests/test_benchmarks.py index 307b0b258d..6707a86efc 100644 --- a/python/kvikio/tests/test_benchmarks.py +++ b/python/kvikio/tests/test_benchmarks.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. 
import os @@ -28,6 +28,7 @@ "zarr", ], ) +@pytest.mark.timeout(30, method="thread") def test_single_node_io(run_cmd, tmp_path, api): """Test benchmarks/single_node_io.py""" @@ -59,6 +60,7 @@ def test_single_node_io(run_cmd, tmp_path, api): "posix", ], ) +@pytest.mark.timeout(30, method="thread") def test_zarr_io(run_cmd, tmp_path, api): """Test benchmarks/zarr_io.py""" @@ -89,6 +91,7 @@ def test_zarr_io(run_cmd, tmp_path, api): "numpy", ], ) +@pytest.mark.timeout(30, method="thread") def test_http_io(run_cmd, api): """Test benchmarks/http_io.py""" @@ -118,6 +121,7 @@ def test_http_io(run_cmd, api): "numpy", ], ) +@pytest.mark.timeout(30, method="thread") def test_s3_io(run_cmd, api): """Test benchmarks/s3_io.py""" diff --git a/python/kvikio/tests/test_defaults.py b/python/kvikio/tests/test_defaults.py index d7048c418d..5e2d6675f8 100644 --- a/python/kvikio/tests/test_defaults.py +++ b/python/kvikio/tests/test_defaults.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. 
@@ -84,3 +84,34 @@ def test_bounce_buffer_size(): kvikio.defaults.bounce_buffer_size_reset(0) with pytest.raises(OverflowError, match="negative value"): kvikio.defaults.bounce_buffer_size_reset(-1) + + +def test_http_max_attempts(): + before = kvikio.defaults.http_max_attempts() + + with kvikio.defaults.set_http_max_attempts(5): + assert kvikio.defaults.http_max_attempts() == 5 + kvikio.defaults.http_max_attempts_reset(4) + assert kvikio.defaults.http_max_attempts() == 4 + assert kvikio.defaults.http_max_attempts() == before + + with pytest.raises(ValueError, match="positive integer"): + kvikio.defaults.http_max_attempts_reset(0) + with pytest.raises(OverflowError, match="negative value"): + kvikio.defaults.http_max_attempts_reset(-1) + + +def test_http_status_codes(): + before = kvikio.defaults.http_status_codes() + + with kvikio.defaults.set_http_status_codes([500]): + assert kvikio.defaults.http_status_codes() == [500] + kvikio.defaults.http_status_codes_reset([429, 500]) + assert kvikio.defaults.http_status_codes() == [429, 500] + assert kvikio.defaults.http_status_codes() == before + + with pytest.raises(TypeError): + kvikio.defaults.http_status_codes_reset(0) + + with pytest.raises(TypeError): + kvikio.defaults.http_status_codes_reset(["a"]) diff --git a/python/kvikio/tests/test_http_io.py b/python/kvikio/tests/test_http_io.py index 5c2c3888cd..e62dbb81af 100644 --- a/python/kvikio/tests/test_http_io.py +++ b/python/kvikio/tests/test_http_io.py @@ -1,7 +1,11 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. +import http +from http.server import SimpleHTTPRequestHandler +from typing import Literal + import numpy as np import pytest @@ -18,6 +22,58 @@ ) +class ErrorCounter: + # ThreadingHTTPServer creates a new handler per request. + # This lets us share some state between requests. 
+ def __init__(self): + self.value = 0 + + +class HTTP503Handler(SimpleHTTPRequestHandler): + """ + An HTTP handler that initially responds with a 503 before responding normally. + + Parameters + ---------- + error_counter : ErrorCounter + A class with a mutable `value` for the number of 503 errors that have + been returned. + max_error_count : int + The number of times to respond with a 503 before responding normally. + """ + + def __init__( + self, + *args, + directory=None, + error_counter: ErrorCounter = ErrorCounter(), + max_error_count: int = 1, + **kwargs, + ): + self.max_error_count = max_error_count + self.error_counter = error_counter + super().__init__(*args, directory=directory, **kwargs) + + def _do_with_error_count(self, method: Literal["GET", "HEAD"]) -> None: + if self.error_counter.value < self.max_error_count: + self.error_counter.value += 1 + self.send_error(http.HTTPStatus.SERVICE_UNAVAILABLE) + self.send_header("CurrentErrorCount", str(self.error_counter.value)) + self.send_header("MaxErrorCount", str(self.max_error_count)) + return None + else: + if method == "GET": + return super().do_GET() + else: + return super().do_HEAD() + + def do_GET(self) -> None: + return self._do_with_error_count("GET") + + def do_HEAD(self) -> None: + return self._do_with_error_count("HEAD") + + @pytest.fixture def http_server(request, tmpdir): """Fixture to set up http server in separate process""" @@ -100,3 +156,80 @@ def test_no_range_support(http_server, tmpdir, xp): OverflowError, match="maybe the server doesn't support file ranges?" 
): f.read(b, size=10, file_offset=10) + + +def test_retry_http_503_ok(tmpdir, xp): + a = xp.arange(100, dtype="uint8") + a.tofile(tmpdir / "a") + + with LocalHttpServer( + tmpdir, + max_lifetime=60, + handler=HTTP503Handler, + handler_options={"error_counter": ErrorCounter()}, + ) as server: + http_server = server.url + b = xp.empty_like(a) + with kvikio.RemoteFile.open_http(f"{http_server}/a") as f: + assert f.nbytes() == a.nbytes + assert f"{http_server}/a" in str(f) + f.read(b) + + +def test_retry_http_503_fails(tmpdir, xp, capfd): + with LocalHttpServer( + tmpdir, + max_lifetime=60, + handler=HTTP503Handler, + handler_options={"error_counter": ErrorCounter(), "max_error_count": 100}, + ) as server: + a = xp.arange(100, dtype="uint8") + a.tofile(tmpdir / "a") + b = xp.empty_like(a) + + with pytest.raises(RuntimeError) as m, kvikio.defaults.set_http_max_attempts(2): + with kvikio.RemoteFile.open_http(f"{server.url}/a") as f: + f.read(b) + + assert m.match(r"KvikIO: HTTP request reached maximum number of attempts \(2\)") + assert m.match("Got HTTP code 503") + captured = capfd.readouterr() + + records = captured.out.strip().split("\n") + assert len(records) == 1 + assert records[0] == ( + "KvikIO: Got HTTP code 503. Retrying after 500ms (attempt 1 of 2)." 
+ ) + + +def test_no_retries_ok(tmpdir): + a = np.arange(100, dtype="uint8") + a.tofile(tmpdir / "a") + + with LocalHttpServer( + tmpdir, + max_lifetime=60, + ) as server: + http_server = server.url + b = np.empty_like(a) + with kvikio.defaults.set_http_max_attempts(1): + with kvikio.RemoteFile.open_http(f"{http_server}/a") as f: + assert f.nbytes() == a.nbytes + assert f"{http_server}/a" in str(f) + f.read(b) + + +def test_set_http_status_code(tmpdir): + with LocalHttpServer( + tmpdir, + max_lifetime=60, + handler=HTTP503Handler, + handler_options={"error_counter": ErrorCounter()}, + ) as server: + http_server = server.url + with kvikio.defaults.set_http_status_codes([429]): + # this raises on the first 503 error, since it's not in the list. + assert kvikio.defaults.http_status_codes() == [429] + with pytest.raises(RuntimeError, match="503"): + with kvikio.RemoteFile.open_http(f"{http_server}/a"): + pass