Skip to content

Commit

Permalink
Ugly hack
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Dec 18, 2023
1 parent e6c2504 commit 72b52a9
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
13 changes: 11 additions & 2 deletions src/ert/analysis/configuration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union

from pydantic import (
BaseModel,
Expand Down Expand Up @@ -75,7 +75,15 @@ def check_parameters(self) -> Self:
return self

@field_validator("observations", mode="before")
def check_arguments(cls, observations) -> Any:
def check_arguments(
cls,
observations: Union[
str,
Dict[str, Union[str, List[int]]],
List[Union[str, List[int]]],
Tuple[Union[str, List[int]]],
],
) -> Dict[str, Union[str, List[int]]]:
"""Because most of the time the users will configure observations as only a name
we convert positional arguments to named arguments"""
values = []
Expand All @@ -85,6 +93,7 @@ def check_arguments(cls, observations) -> Any:
elif isinstance(observation, dict):
values.append(observation)
else:
assert isinstance(observation, (tuple, list))
if len(observation) == 1:
values.append({"name": observation[0]})
elif len(observation) == 2:
Expand Down
33 changes: 19 additions & 14 deletions src/ert/gui/ertwidgets/analysismodulevariablespanel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial

from annotated_types import Ge, Gt, Le
from qtpy.QtCore import Qt
from qtpy.QtWidgets import (
QButtonGroup,
Expand Down Expand Up @@ -39,16 +40,16 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int):

if isinstance(analysis_module, IESSettings):
for variable_name in (
name for name in analysis_module.__fields__ if "steplength" in name
name for name in analysis_module.model_fields if "steplength" in name
):
metadata = analysis_module.__fields__[variable_name]
metadata = analysis_module.model_fields[variable_name]
layout.addRow(
metadata.field_info.title,
self.createDoubleSpinBox(
metadata.name,
analysis_module.__getattribute__(variable_name),
metadata.field_info.ge,
metadata.field_info.le,
[val for val in metadata.metadata if isinstance(val, Ge)][0].ge,
[val for val in metadata.metadata if isinstance(val, Le)][0].le,
0.1,
),
)
Expand All @@ -74,12 +75,12 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int):
b.setObjectName("IES_INVERSION_" + str(button_id))
bg.addButton(b, button_id)
layout.addRow(b)
metadata = analysis_module.__fields__["enkf_truncation"]
metadata = analysis_module.model_fields["enkf_truncation"]
self.truncation_spinner = self.createDoubleSpinBox(
metadata.name,
metadata.title,
analysis_module.enkf_truncation,
metadata.field_info.gt + 0.001,
metadata.field_info.le,
[val for val in metadata.metadata if isinstance(val, Gt)][0].gt + 0.001,
[val for val in metadata.metadata if isinstance(val, Le)][0].le,
0.01,
)
self.truncation_spinner.setEnabled(False)
Expand All @@ -96,8 +97,10 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int):
localization_frame.setLayout(QHBoxLayout())
localization_frame.layout().setContentsMargins(0, 0, 0, 0)

metadata = analysis_module.__fields__["localization_correlation_threshold"]
local_checkbox = QCheckBox(metadata.field_info.title)
metadata = analysis_module.model_fields[
"localization_correlation_threshold"
]
local_checkbox = QCheckBox(metadata.title)
local_checkbox.setObjectName("localization")
local_checkbox.clicked.connect(
partial(
Expand All @@ -108,12 +111,14 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int):
)
)

metadata = analysis_module.__fields__["localization_correlation_threshold"]
metadata = analysis_module.model_fields[
"localization_correlation_threshold"
]
self.local_spinner = self.createDoubleSpinBox(
metadata.name,
metadata.title,
analysis_module.correlation_threshold(ensemble_size),
metadata.field_info.ge,
metadata.field_info.le,
[val for val in metadata.metadata if isinstance(val, Ge)][0].ge,
[val for val in metadata.metadata if isinstance(val, Le)][0].le,
0.1,
)
self.local_spinner.setObjectName("localization_threshold")
Expand Down

0 comments on commit 72b52a9

Please sign in to comment.