From 00c14805b229f3b2fe3e917c2d99064e25d2fc29 Mon Sep 17 00:00:00 2001 From: morrisnein Date: Thu, 14 Dec 2023 12:12:45 +0300 Subject: [PATCH] fix typing --- .../dataset_models_fitness_scaler.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/meta_automl/data_preparation/model_fitness_scalers/dataset_models_fitness_scaler.py b/meta_automl/data_preparation/model_fitness_scalers/dataset_models_fitness_scaler.py index 48400ca4..591b94d1 100644 --- a/meta_automl/data_preparation/model_fitness_scalers/dataset_models_fitness_scaler.py +++ b/meta_automl/data_preparation/model_fitness_scalers/dataset_models_fitness_scaler.py @@ -1,18 +1,26 @@ from copy import copy -from typing import Any, Dict, Sequence +from typing import Dict, Protocol, Sequence, Type +import numpy as np from sklearn.preprocessing import MinMaxScaler from typing_extensions import Self -from meta_automl.data_preparation.dataset import DatasetIDType from meta_automl.data_preparation.dataset.dataset_base import DatasetType_co from meta_automl.data_preparation.evaluated_model import EvaluatedModel +class ScalerType(Protocol): + def fit(self, x) -> Self: pass + + def transform(self, x) -> np.ndarray: pass + + def fit_transform(self, x) -> np.ndarray: pass + + class DatasetModelsFitnessScaler: - def __init__(self, scaler_class=MinMaxScaler): + def __init__(self, scaler_class: Type[ScalerType] = MinMaxScaler): self.scaler_class = scaler_class - self.scalers: Dict[DatasetIDType, Any] = {} + self.scalers: Dict[str, ScalerType] = {} def fit(self, models: Sequence[Sequence[EvaluatedModel]], datasets: Sequence[DatasetType_co]) -> Self: dataset_representations = map(repr, datasets)