From 2ef47e426be201e84983a14d1c0f3d3362020717 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 3 Feb 2025 14:56:28 +0000 Subject: [PATCH] Throw 500 errors on HEAD too --- python/kvikio/tests/test_http_io.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/python/kvikio/tests/test_http_io.py b/python/kvikio/tests/test_http_io.py index 61ed14b377..b084f81bcf 100644 --- a/python/kvikio/tests/test_http_io.py +++ b/python/kvikio/tests/test_http_io.py @@ -4,6 +4,7 @@ import http from http.server import SimpleHTTPRequestHandler +from typing import Literal import numpy as np import pytest @@ -53,16 +54,24 @@ def __init__( self.error_counter = error_counter super().__init__(*args, directory=directory, **kwargs) - def do_GET(self): + def _do_with_error_count(self, method: Literal["GET", "HEAD"]) -> None: if self.error_counter.value < self.max_error_count: self.error_counter.value += 1 self.send_error(http.HTTPStatus.SERVICE_UNAVAILABLE) - self.send_header("CurrentErrorCount", self.error_counter.value) - self.send_header("MaxErrorCount", self.max_error_count) + self.send_header("CurrentErrorCount", str(self.error_counter.value)) + self.send_header("MaxErrorCount", str(self.max_error_count)) return None else: - self.error_counter.value += 1 - super().do_GET() + if method == "GET": + return super().do_GET() + else: + return super().do_HEAD() + + def do_GET(self) -> None: + return self._do_with_error_count("GET") + + def do_HEAD(self) -> None: + return self._do_with_error_count("HEAD") @pytest.fixture @@ -193,6 +202,7 @@ def test_set_http_status_code(tmpdir, xp): http_server = server.url with kvikio.defaults.set_http_status_codes([429]): # this raises on the first 503 error, since it's not in the list. + assert kvikio.defaults.http_status_codes() == [429] with pytest.raises(RuntimeError, match="503"): with kvikio.RemoteFile.open_http(f"{http_server}/a"): pass