From 72b52a912148c8db9232cfa42e1fd83b7350289c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Mon, 18 Dec 2023 11:54:52 +0100 Subject: [PATCH] Ugly hack --- src/ert/analysis/configuration.py | 13 ++++++-- .../analysismodulevariablespanel.py | 33 +++++++++++-------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/src/ert/analysis/configuration.py b/src/ert/analysis/configuration.py index 952e4436364..b0ef6f9d5cf 100644 --- a/src/ert/analysis/configuration.py +++ b/src/ert/analysis/configuration.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union from pydantic import ( BaseModel, @@ -75,7 +75,15 @@ def check_parameters(self) -> Self: return self @field_validator("observations", mode="before") - def check_arguments(cls, observations) -> Any: + def check_arguments( + cls, + observations: Union[ + str, + Dict[str, Union[str, List[int]]], + List[Union[str, List[int]]], + Tuple[Union[str, List[int]]], + ], + ) -> Dict[str, Union[str, List[int]]]: """Because most of the time the users will configure observations as only a name we convert positional arguments to named arguments""" values = [] @@ -85,6 +93,7 @@ def check_arguments(cls, observations) -> Any: elif isinstance(observation, dict): values.append(observation) else: + assert isinstance(observation, (tuple, list)) if len(observation) == 1: values.append({"name": observation[0]}) elif len(observation) == 2: diff --git a/src/ert/gui/ertwidgets/analysismodulevariablespanel.py b/src/ert/gui/ertwidgets/analysismodulevariablespanel.py index 07f3dfd3b5b..cfec0958c18 100644 --- a/src/ert/gui/ertwidgets/analysismodulevariablespanel.py +++ b/src/ert/gui/ertwidgets/analysismodulevariablespanel.py @@ -1,5 +1,6 @@ from functools import partial +from annotated_types import Ge, Gt, Le from qtpy.QtCore import Qt from qtpy.QtWidgets import ( QButtonGroup, @@ -39,16 +40,16 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int): if isinstance(analysis_module, IESSettings): for variable_name in ( - name for name in analysis_module.__fields__ if "steplength" in name + name for name in analysis_module.model_fields if "steplength" in name ): - metadata = analysis_module.__fields__[variable_name] + metadata = analysis_module.model_fields[variable_name] layout.addRow( metadata.field_info.title, self.createDoubleSpinBox( metadata.name, analysis_module.__getattribute__(variable_name), - metadata.field_info.ge, - metadata.field_info.le, + [val for val in metadata.metadata if isinstance(val, Ge)][0].ge, + [val for val in metadata.metadata if isinstance(val, Le)][0].le, 0.1, ), ) @@ -74,12 +75,12 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int): b.setObjectName("IES_INVERSION_" + str(button_id)) bg.addButton(b, button_id) layout.addRow(b) - metadata = analysis_module.__fields__["enkf_truncation"] + metadata = analysis_module.model_fields["enkf_truncation"] self.truncation_spinner = self.createDoubleSpinBox( - metadata.name, + metadata.title, analysis_module.enkf_truncation, - metadata.field_info.gt + 0.001, - metadata.field_info.le, + [val for val in metadata.metadata if isinstance(val, Gt)][0].gt + 0.001, + [val for val in metadata.metadata if isinstance(val, Le)][0].le, 0.01, ) self.truncation_spinner.setEnabled(False) @@ -96,8 +97,10 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int): localization_frame.setLayout(QHBoxLayout()) localization_frame.layout().setContentsMargins(0, 0, 0, 0) - metadata = analysis_module.__fields__["localization_correlation_threshold"] - local_checkbox = QCheckBox(metadata.field_info.title) + metadata = analysis_module.model_fields[ + "localization_correlation_threshold" + ] + local_checkbox = QCheckBox(metadata.title) local_checkbox.setObjectName("localization") local_checkbox.clicked.connect( partial( @@ -108,12 +111,14 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int): ) ) - metadata = analysis_module.__fields__["localization_correlation_threshold"] + metadata = analysis_module.model_fields[ + "localization_correlation_threshold" + ] self.local_spinner = self.createDoubleSpinBox( - metadata.name, + metadata.title, analysis_module.correlation_threshold(ensemble_size), - metadata.field_info.ge, - metadata.field_info.le, + [val for val in metadata.metadata if isinstance(val, Ge)][0].ge, + [val for val in metadata.metadata if isinstance(val, Le)][0].le, 0.1, ) self.local_spinner.setObjectName("localization_threshold")