From a54897ea2f338aae8ca3c9555c9b721c29b90107 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5vard=20Berland?= Date: Fri, 5 Jan 2024 15:53:47 +0100 Subject: [PATCH] Replace configsuite with pydantic in interp_relperm Remove configsuite from subscripts dependencies as interp_relperm was the last user of it. --- pyproject.toml | 1 - .../interp_relperm/interp_relperm.py | 228 +++++------------- tests/test_interp_relperm.py | 126 ++++------ 3 files changed, 99 insertions(+), 256 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4b906c81f..2e5c7087a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ classifiers = [ ] dynamic = ["version"] dependencies = [ - "configsuite", "resdata", "res2df", "ert>=2.38.0b7", diff --git a/src/subscript/interp_relperm/interp_relperm.py b/src/subscript/interp_relperm/interp_relperm.py index 1f87b6c46..f95f10bfe 100755 --- a/src/subscript/interp_relperm/interp_relperm.py +++ b/src/subscript/interp_relperm/interp_relperm.py @@ -1,17 +1,18 @@ +from __future__ import annotations + import argparse import logging import os import sys from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional -import configsuite # lgtm [py/import-and-import-from] import pandas as pd import pyscal import yaml -from configsuite import MetaKeys as MK # lgtm [py/import-and-import-from] -from configsuite import types # lgtm [py/import-and-import-from] +from pydantic import BaseModel, Field, FilePath, model_validator from res2df import satfunc +from typing_extensions import Annotated import subscript @@ -101,146 +102,43 @@ """ # noqa -@configsuite.validator_msg("Valid file name") -def _is_filename(filename: str): - return Path(filename).exists() - - -@configsuite.validator_msg("Valid interpolator list") -def _is_valid_interpolator_list(interpolators: list): - if len(interpolators) > 0: - return True - return False - - -@configsuite.validator_msg("Valid interpolator") -def _is_valid_interpolator(interp: dict): - valid = False - - try: - if interp["param_w"]: - valid = True - elif interp["param_w"] == 0: - valid = True - - except (KeyError, ValueError, TypeError): - pass - - try: - if interp["param_w"] > 1.0 or interp["param_w"] < -1.0: - valid = False - except (KeyError, ValueError, TypeError): - pass - - try: - if interp["param_g"]: - valid = True - elif interp["param_g"] == 0: - valid = True - except (KeyError, ValueError, TypeError): - pass - - try: - if interp["param_g"] > 1.0 or interp["param_g"] < -1.0: - valid = False - except (KeyError, ValueError, TypeError): - pass - - return valid - - -@configsuite.validator_msg("Low, base and high are provided") -def _is_valid_table_entries(schema: dict): - if "base" in schema and "low" in schema and "high" in schema: - if schema["low"] and schema["base"] and schema["high"]: - return ( - isinstance(schema["low"], tuple) - and isinstance(schema["base"], tuple) - and isinstance(schema["high"], tuple) - ) - if "pyscalfile" in schema: - # If pyscalfile is given, we don't need low/base/high - return True - return False - - -@configsuite.validator_msg("Valid Eclipse keyword family") -def _is_valid_eclipse_keyword_family(familychoice: int): - return familychoice in [1, 2] - - -def get_cfg_schema() -> dict: - """ - Defines the yml config schema - """ - schema = { - MK.Type: types.NamedDict, - MK.ElementValidators: (_is_valid_table_entries,), - MK.Content: { - "base": { - MK.Type: types.List, - MK.Content: { - MK.Item: { - MK.Type: types.String, - MK.ElementValidators: (_is_filename,), - } - }, - }, - "low": { - MK.Type: types.List, - MK.Content: { - MK.Item: { - MK.Type: types.String, - MK.ElementValidators: (_is_filename,), - } - }, - }, - "high": { - MK.Type: types.List, - MK.Content: { - MK.Item: { - MK.Type: types.String, - MK.ElementValidators: (_is_filename,), - } - }, - }, - "pyscalfile": { - MK.Type: types.String, - MK.ElementValidators: (_is_filename,), - MK.AllowNone: True, - }, - "result_file": {MK.Type: types.String}, - "family": { - MK.Type: types.Number, - MK.Default: 1, - MK.ElementValidators: (_is_valid_eclipse_keyword_family,), - }, - "delta_s": {MK.Type: types.Number, MK.Default: 0.01}, - "interpolations": { - MK.Type: types.List, - MK.ElementValidators: (_is_valid_interpolator_list,), - MK.Content: { - MK.Item: { - MK.Type: types.NamedDict, - MK.ElementValidators: (_is_valid_interpolator,), - MK.Content: { - "tables": { - MK.Type: types.List, - MK.Content: {MK.Item: {MK.Type: types.Integer}}, - }, - "param_w": {MK.Type: types.Number, MK.AllowNone: True}, - "param_g": {MK.Type: types.Number, MK.AllowNone: True}, - }, - } - }, - }, - }, - } - - return schema - - -def parse_satfunc_files(filenames: List[str]) -> pd.DataFrame: +class Interpolator(BaseModel): + tables: Optional[List[int]] = [] + param_w: Optional[Annotated[float, Field(strict=True, ge=-1, le=1)]] = None + param_g: Optional[Annotated[float, Field(strict=True, ge=-1, le=1)]] = None + + @model_validator(mode="after") + def check_param_w_or_param_g(self) -> Interpolator: + assert ( + self.param_w is not None or self.param_g is not None + ), "Provide either param_w or param_g" + return self + + +class InterpRelpermConfig(BaseModel): + low: Optional[List[FilePath]] = None + base: Optional[List[FilePath]] = None + high: Optional[List[FilePath]] = None + pyscalfile: Optional[FilePath] = None + result_file: str + family: Literal[1, 2] = 1 + delta_s: float = 0.01 + interpolations: Annotated[List[Interpolator], Field(..., min_length=1)] + + @model_validator(mode="after") + def check_lowbasehigh_or_pyscalfile(self) -> InterpRelpermConfig: + if self.pyscalfile is None: + assert self.base is not None, "base is not provided" + assert self.high is not None, "high is not provided" + assert self.low is not None, "low is not provided" + else: + assert self.base is None, "do not specify base when pyscalfile is set" + assert self.high is None, "do not specify high when pyscalfile is set" + assert self.low is None, "do not specify low when pyscalfile is set" + return self + + +def parse_satfunc_files(filenames: List[Path]) -> pd.DataFrame: """ Routine to gather scal tables (SWOF and SGOF) from ecl include files. @@ -251,7 +149,6 @@ def parse_satfunc_files(filenames: List[str]) -> pd.DataFrame: Returns: dataframe with the tables """ - return pd.concat( [ satfunc.df(Path(filename).read_text(encoding="utf8")) @@ -457,35 +354,23 @@ def process_config(cfg: Dict[str, Any], root_path: Optional[Path] = None) -> Non if root_path is not None: cfg = prepend_root_path_to_relative_files(cfg, root_path) - cfg_schema = get_cfg_schema() - cfg_suite = configsuite.ConfigSuite(cfg, cfg_schema, deduce_required=True) - - if not cfg_suite.valid: - logger.error("Sorry, the configuration is invalid.") - sys.exit(cfg_suite.errors) - - # set default values - relperm_delta_s = False - if cfg_suite.snapshot.delta_s: - relperm_delta_s = cfg_suite.snapshot.delta_s + config = InterpRelpermConfig(**cfg) base_df: pd.DataFrame = pd.DataFrame() low_df: pd.DataFrame = pd.DataFrame() high_df: pd.DataFrame = pd.DataFrame() - if cfg_suite.snapshot.pyscalfile is not None: - if cfg_suite.snapshot.base or cfg_suite.snapshot.low or cfg_suite.snapshot.high: + if config.pyscalfile is not None: + if config.base or config.low or config.high: logger.error( "Inconsistent configuration. " "You cannot define both pyscalfile and base/low/high" ) sys.exit(1) - logger.info( - "Loading relperm parametrization from %s", cfg_suite.snapshot.pyscalfile - ) + logger.info("Loading relperm parametrization from %s", config.pyscalfile) param_dframe = pyscal.PyscalFactory.load_relperm_df( - cfg_suite.snapshot.pyscalfile + config.pyscalfile ).set_index("CASE") base_df = ( pyscal.PyscalFactory.create_pyscal_list(param_dframe.loc["base"]) @@ -504,9 +389,12 @@ def process_config(cfg: Dict[str, Any], root_path: Optional[Path] = None) -> Non ) else: # Parse tables from files - base_df = parse_satfunc_files(cfg_suite.snapshot.base) - low_df = parse_satfunc_files(cfg_suite.snapshot.low) - high_df = parse_satfunc_files(cfg_suite.snapshot.high) + assert config.base is not None + base_df = parse_satfunc_files(config.base) + assert config.low is not None + low_df = parse_satfunc_files(config.low) + assert config.high is not None + high_df = parse_satfunc_files(config.high) if not ( set(base_df.columns) == set(low_df.columns) @@ -523,7 +411,7 @@ def process_config(cfg: Dict[str, Any], root_path: Optional[Path] = None) -> Non satnums = range(1, base_df.reset_index("SATNUM")["SATNUM"].unique().max() + 1) for satnum in satnums: interp_values = {"param_w": 0.0, "param_g": 0.0} - for interp in cfg_suite.snapshot.interpolations: + for interp in config.interpolations: if not interp.tables or satnum in interp.tables: if interp.param_w: interp_values["param_w"] = interp.param_w @@ -537,17 +425,17 @@ def process_config(cfg: Dict[str, Any], root_path: Optional[Path] = None) -> Non high_df.loc[satnum], interp_values, satnum, - relperm_delta_s, + config.delta_s, ) ) - Path(cfg_suite.snapshot.result_file).write_text( - interpolants.build_eclipse_data(cfg_suite.snapshot.family), encoding="utf-8" + Path(config.result_file).write_text( + interpolants.build_eclipse_data(config.family), encoding="utf-8" ) logger.info( "Done; interpolated relperm curves written to file: %s", - str(cfg_suite.snapshot.result_file), + str(config.result_file), ) diff --git a/tests/test_interp_relperm.py b/tests/test_interp_relperm.py index 8e87a63f9..775a6b088 100644 --- a/tests/test_interp_relperm.py +++ b/tests/test_interp_relperm.py @@ -2,15 +2,16 @@ import subprocess from pathlib import Path -import configsuite import pandas as pd import pytest import yaml +from pydantic import ValidationError from pyscal import PyscalFactory from pyscal.utils.testing import sat_table_str_ok from res2df import satfunc from subscript.interp_relperm import interp_relperm +from subscript.interp_relperm.interp_relperm import InterpRelpermConfig TESTDATA = Path(__file__).absolute().parent / "testdata_interp_relperm" @@ -40,76 +41,49 @@ ).set_index("CASE") -def test_get_cfg_schema(): - """Test the configsuite schema""" - cfg_filen = TESTDATA / "cfg.yml" - - cfg = yaml.safe_load(cfg_filen.read_text(encoding="utf8")) - - # add root-path to all include files - if "base" in cfg.keys(): - for idx in range(len(cfg["base"])): - cfg["base"][idx] = str(TESTDATA / cfg["base"][idx]) - if "high" in cfg.keys(): - for idx in range(len(cfg["high"])): - cfg["high"][idx] = str(TESTDATA / cfg["high"][idx]) - if "low" in cfg.keys(): - for idx in range(len(cfg["low"])): - cfg["low"][idx] = str(TESTDATA / cfg["low"][idx]) - - schema = interp_relperm.get_cfg_schema() - suite = configsuite.ConfigSuite(cfg, schema, deduce_required=True) - - assert suite.valid - - def test_prepend_root_path(): """Test that we need to prepend with root-path""" cfg_filen = TESTDATA / "cfg.yml" cfg = yaml.safe_load(cfg_filen.read_text(encoding="utf8")) - schema = interp_relperm.get_cfg_schema() - suite_no_rootpath = configsuite.ConfigSuite(cfg, schema, deduce_required=True) - assert not suite_no_rootpath.valid + with pytest.raises(ValidationError): + InterpRelpermConfig(**cfg) cfg_with_rootpath = interp_relperm.prepend_root_path_to_relative_files( cfg, TESTDATA ) - suite = configsuite.ConfigSuite(cfg_with_rootpath, schema, deduce_required=True) - assert suite.valid + InterpRelpermConfig(**cfg_with_rootpath) # When root-path is prepended (with an absolute part) it should not # matter if we reapply: cfg_with_double_rootpath = interp_relperm.prepend_root_path_to_relative_files( cfg_with_rootpath, TESTDATA ) - suite_double = configsuite.ConfigSuite( - cfg_with_double_rootpath, schema, deduce_required=True - ) - assert suite_double.valid + InterpRelpermConfig(**cfg_with_double_rootpath) @pytest.mark.parametrize( "dictupdates, expected_error", [ - ({"base": ["foo.inc"]}, "Valid file name"), - ({}, "Valid interpolator list"), - ({"interpolations": [{"tables": []}]}, "Valid interpolator"), - ({"interpolations": [{"param_w": 1.5}]}, "Valid interpolator"), - ({"interpolations": [{"param_w": -1.1}]}, "Valid interpolator"), + ({"base": ["foo.inc"]}, "Path does not point to a file"), + ({}, "Field required"), + ({"interpolations": [{"tables": []}]}, "Provide either param_w or param_g"), + ({"interpolations": [{"param_w": 1.5}]}, "Input should be less than"), + ({"interpolations": [{"param_w": -1.1}]}, "Input should be greater than"), ({"interpolations": [{"param_w": 0}]}, None), ({"interpolations": [{"param_g": 0}]}, None), ({"interpolations": [{"param_w": 0, "param_g": 0}]}, None), ({"interpolations": [{"param_w": 0.1, "param_g": -0.1}]}, None), - ({"interpolations": [{"param_g": -1.5}]}, "Valid interpolator"), - ({"interpolations": [{"param_g": 1.5}]}, "Valid interpolator"), - ({"interpolations": [{"param_w": "weird"}]}, "Is x a number"), - ({"interpolations": [{"param_g": "Null"}]}, "Is x a number"), + ({"interpolations": [{"param_g": -1.5}]}, "Input should be greater than"), + ({"interpolations": [{"param_g": 1.5}]}, "Input should be less than"), + ({"interpolations": [{"param_w": "weird"}]}, "Input should be a valid number"), + ({"interpolations": [{"param_g": "Null"}]}, "Input should be a valid number"), ], ) -def test_schema_errors(dictupdates, expected_error): - """Test that configsuite errors correctly with some hint to the resolution""" +def test_config_errors(dictupdates, expected_error): + """Test that the pydantic model errors correctly with some hint to the + resolution""" os.chdir(TESTDATA) cfg = { "base": ["swof_base.inc", "sgof_base.inc"], @@ -118,19 +92,17 @@ def test_schema_errors(dictupdates, expected_error): "result_file": "foo.inc", } cfg.update(dictupdates) - parsed_cfg = configsuite.ConfigSuite( - cfg, interp_relperm.get_cfg_schema(), deduce_required=True - ) + if expected_error is not None: - assert not parsed_cfg.valid - assert expected_error in str(parsed_cfg.errors) + with pytest.raises(ValidationError) as validation_error: + InterpRelpermConfig(**cfg) + assert expected_error in str(validation_error) else: - print(parsed_cfg.errors) - assert parsed_cfg.valid + InterpRelpermConfig(**cfg) def test_schema_errors_low_base_high(): - """Test for detection of schema errors related to low/base/high""" + """Test for detection of config errors related to low/base/high""" os.chdir(TESTDATA) cfg = { "base": ["swof_base.inc", "sgof_base.inc"], @@ -139,42 +111,32 @@ def test_schema_errors_low_base_high(): "result_file": "foo.inc", "interpolations": [{"param_w": 0.1, "param_g": -0.1}], } - parsed_cfg = configsuite.ConfigSuite( - cfg, interp_relperm.get_cfg_schema(), deduce_required=True - ) - assert parsed_cfg.valid + InterpRelpermConfig(**cfg) cfg_no_low = cfg.copy() del cfg_no_low["low"] - parsed_cfg = configsuite.ConfigSuite( - cfg_no_low, interp_relperm.get_cfg_schema(), deduce_required=True - ) - assert not parsed_cfg.valid - assert "Low, base and high are provided is false" in str(parsed_cfg.errors) + with pytest.raises(ValidationError) as validation_error: + InterpRelpermConfig(**cfg_no_low) + assert "low is not provided" in str(validation_error) cfg_no_high = cfg.copy() del cfg_no_high["high"] - parsed_cfg = configsuite.ConfigSuite( - cfg_no_high, interp_relperm.get_cfg_schema(), deduce_required=True - ) - assert not parsed_cfg.valid - assert "Low, base and high are provided is false" in str(parsed_cfg.errors) + with pytest.raises(ValidationError) as validation_error: + InterpRelpermConfig(**cfg_no_high) + assert "high is not provided" in str(validation_error) cfg_no_base = cfg.copy() del cfg_no_base["base"] - parsed_cfg = configsuite.ConfigSuite( - cfg_no_base, interp_relperm.get_cfg_schema(), deduce_required=True - ) - assert not parsed_cfg.valid - assert "Low, base and high are provided is false" in str(parsed_cfg.errors) + with pytest.raises(ValidationError) as validation_error: + InterpRelpermConfig(**cfg_no_base) + assert "base is not provided" in str(validation_error) cfg_string_for_high = cfg.copy() cfg_string_for_high["high"] = "sgof_opt.inc" - parsed_cfg = configsuite.ConfigSuite( - cfg_string_for_high, interp_relperm.get_cfg_schema(), deduce_required=True - ) - assert not parsed_cfg.valid - assert "Is x a list is false on input 'sgof_opt.inc'" in str(parsed_cfg.errors) + with pytest.raises(ValidationError) as validation_error: + InterpRelpermConfig(**cfg_string_for_high) + assert "Input should be a valid list" in str(validation_error) + assert "sgof_opt.inc" in str(validation_error) def test_garbled_base_input(tmp_path): @@ -190,11 +152,6 @@ def test_garbled_base_input(tmp_path): "result_file": str(tmp_path / "foo.inc"), "interpolations": [{"param_w": 0.1, "param_g": -0.1}], } - parsed_cfg = configsuite.ConfigSuite( - cfg, interp_relperm.get_cfg_schema(), deduce_required=True - ) - assert parsed_cfg.valid # Error can't be captured by schema - with pytest.raises(SystemExit): interp_relperm.process_config(cfg) @@ -375,12 +332,11 @@ def test_wrong_family(tmp_path): "family": "Rockefeller", "delta_s": 0.1, } - with pytest.raises( - SystemExit, match="Is x a number is false on input 'Rockefeller'" - ): + with pytest.raises(ValidationError, match="Input should be 1 or 2"): interp_relperm.process_config(config) + config["family"] = 3 - with pytest.raises(SystemExit): + with pytest.raises(ValidationError): interp_relperm.process_config(config)