Skip to content

Commit e6edd5d

Browse files
committed
lint: solve mypy linting errors
1 parent 6cc73c9 commit e6edd5d

File tree

7 files changed

+25
-16
lines changed

7 files changed

+25
-16
lines changed

black_it/calibrator.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,11 @@ def __init__( # noqa: PLR0913
127127
# initialize arrays
128128
self.params_samp = np.zeros((0, self.param_grid.dims))
129129
self.losses_samp = np.zeros(0)
130-
self.batch_num_samp = np.zeros(0, dtype=int)
131-
self.method_samp = np.zeros(0, dtype=int)
132-
self.series_samp = np.zeros((0, self.ensemble_size, self.N, self.D))
130+
self.batch_num_samp: NDArray[np.int64] = np.zeros(0, dtype=int)
131+
self.method_samp: NDArray[np.int64] = np.zeros(0, dtype=int)
132+
self.series_samp: NDArray[np.float64] = np.zeros(
133+
(0, self.ensemble_size, self.N, self.D),
134+
)
133135

134136
# initialize variables before calibration
135137
self.n_sampled_params = 0

black_it/loss_functions/gsl_div.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def get_words(time_series: NDArray[np.float64], length: int) -> NDArray:
270270
"the chosen word length is too high",
271271
exception_class=ValueError,
272272
)
273-
tsw = np.zeros(shape=(tswlen,), dtype=np.int32)
273+
tsw: NDArray[np.float64] = np.zeros(shape=(tswlen,), dtype=np.int32)
274274

275275
for i in range(length):
276276
k = 10 ** (length - i - 1)

black_it/loss_functions/likelihood.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from __future__ import annotations
1919

2020
import warnings
21-
from typing import TYPE_CHECKING, Callable
21+
from typing import TYPE_CHECKING, Callable, cast
2222

2323
import numpy as np
2424

@@ -82,9 +82,13 @@ def compute_loss(
8282
Returns:
8383
The loss value.
8484
"""
85-
r = sim_data_ensemble.shape[0] # number of repetitions
86-
s = sim_data_ensemble.shape[1] # simulation length
87-
d = sim_data_ensemble.shape[2] # number of dimensions
85+
sim_data_ensemble_shape: tuple[int, int, int] = cast(
86+
tuple[int, int, int],
87+
sim_data_ensemble.shape,
88+
)
89+
r = sim_data_ensemble_shape[0] # number of repetitions
90+
s = sim_data_ensemble_shape[1] # simulation length
91+
d = sim_data_ensemble_shape[2] # time series dimension
8892

8993
if self.coordinate_weights is not None:
9094
warnings.warn( # noqa: B028

black_it/plot/plot_results.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
if TYPE_CHECKING:
3333
import os
3434

35+
from numpy.typing import NDArray
36+
3537

3638
def _get_samplers_id_table(saving_folder: str | os.PathLike) -> dict[str, int]:
3739
"""Get the id table of the samplers from the checkpoint.
@@ -298,7 +300,7 @@ def plot_sampling_interact(saving_folder: str | os.PathLike) -> None:
298300
data_frame = pd.read_csv(calibration_results_file)
299301

300302
max_bn = int(max(data_frame["batch_num_samp"]))
301-
all_bns = np.arange(max_bn + 1, dtype=int)
303+
all_bns: NDArray[np.int64] = np.arange(max_bn + 1, dtype=int)
302304
indices_bns = np.array_split(all_bns, min(max_bn, 3))
303305

304306
dict_bns = {}

black_it/samplers/xgboost.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
if TYPE_CHECKING:
2929
from numpy.typing import NDArray
3030

31-
MAX_FLOAT32 = np.finfo(np.float32).max
32-
MIN_FLOAT32 = np.finfo(np.float32).min
33-
EPS_FLOAT32 = np.finfo(np.float32).eps
31+
MAX_FLOAT32: float = cast(float, np.finfo(np.float32).max)
32+
MIN_FLOAT32: float = cast(float, np.finfo(np.float32).min)
33+
EPS_FLOAT32: float = cast(float, np.finfo(np.float32).eps)
3434

3535

3636
class XGBoostSampler(MLSurrogateSampler):

black_it/search_space.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(
7272
self._param_grid: list[NDArray[np.float64]] = []
7373
self._space_size = 1
7474
for i in range(self.dims):
75-
new_col = np.arange(
75+
new_col: NDArray[np.float64] = np.arange(
7676
parameters_bounds[0][i],
7777
parameters_bounds[1][i] + 0.0000001,
7878
parameters_precision[i],

tests/test_samplers/test_xgboost.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# along with this program. If not, see <http://www.gnu.org/licenses/>.
1616
"""This module contains tests for the xgboost sampler."""
1717
import sys
18+
from typing import cast
1819

1920
import numpy as np
2021

@@ -34,9 +35,9 @@
3435
else:
3536
expected_params = np.array([[0.24, 0.26], [0.37, 0.21], [0.43, 0.14], [0.11, 0.04]])
3637

37-
MAX_FLOAT32 = np.finfo(np.float32).max
38-
MIN_FLOAT32 = np.finfo(np.float32).min
39-
EPS_FLOAT32 = np.finfo(np.float32).eps
38+
MAX_FLOAT32: float = cast(float, np.finfo(np.float32).max)
39+
MIN_FLOAT32: float = cast(float, np.finfo(np.float32).min)
40+
EPS_FLOAT32: float = cast(float, np.finfo(np.float32).eps)
4041

4142

4243
def test_xgboost_2d() -> None:

0 commit comments

Comments
 (0)