diff --git a/MDANSE/Src/MDANSE/Framework/OutputVariables/IOutputVariable.py b/MDANSE/Src/MDANSE/Framework/OutputVariables/IOutputVariable.py index 712556ba2..42b890576 100644 --- a/MDANSE/Src/MDANSE/Framework/OutputVariables/IOutputVariable.py +++ b/MDANSE/Src/MDANSE/Framework/OutputVariables/IOutputVariable.py @@ -15,13 +15,15 @@ # import collections +from collections.abc import Sequence +from typing import Union, Tuple import numpy as np +import numpy.typing as npt -from MDANSE.Framework.Formats.IFormat import IFormat from MDANSE.Core.Error import Error - from MDANSE.Core.SubclassFactory import SubclassFactory +from MDANSE.Framework.Formats.IFormat import IFormat class OutputVariableError(Error): @@ -50,39 +52,60 @@ class IOutputVariable(np.ndarray, metaclass=SubclassFactory): Those extra attributes will be contain information necessary for the the MDANSE plotter. """ + DEFAULT_AXES = ("x", "y", "z") + def __new__( cls, - value, - varname, - axis="index", - units="unitless", - main_result=False, - partial_result=False, + value: Union[Tuple[int, ...], npt.ArrayLike], + varname: str, + axis: Union[str, Sequence[str], None] = None, + units: str = "unitless", + *, + main_result: bool = False, + partial_result: bool = False, ): + """Instantiate a new MDANSE output variable. + + Parameters + ---------- + value : Union[Tuple[int, ...], npt.ArrayLike] + If value is tuple create empty array with those dimensions. + + If value is ArrayLike interpret as array data. + varname : str + Variable name for reference. + axis : Union[str, Sequence[str], None] + List of axis labels. If str split on ``|``. + units : str + Units of main data. + main_result : bool + Whether the data are the main result of a calculation. + partial_result : bool + Whether the data are a complete calculation. + + Raises + ------ + OutputVariableError + If dimensions of provided data do not align with those of object. """ - Instantiate a new MDANSE output variable. - - @param cls: the class to instantiate. - @type cls: an OutputVariable object - - @param varname: the name of the output variable. - @type varname: string - - @param value: the input numpy array. - @type value: numpy array - - @note: This is the standard implementation for subclassing a numpy array. - Please look at http://docs.scipy.org/doc/numpy/user/basics.subclassing.html for more information. - """ - if isinstance(value, tuple): - value = np.zeros(value, dtype=np.float64) + value = np.empty(value, dtype=np.float64) else: value = np.array(list(value), dtype=np.float64) + if isinstance(axis, str): + axis = tuple(axis.split("|")) + elif axis is None: + axis = cls.DEFAULT_AXES[: value.ndim] + if value.ndim != cls._nDimensions: raise OutputVariableError( - f"Invalid number of dimensions for an output variable of type {cls.name!r}" + f"Invalid number of dimensions ({value.ndim}) for an output variable of type {cls.__name__!r}" + ) + + if len(axis) != cls._nDimensions: + raise OutputVariableError( + f"Invalid number of dimensions ({len(axis)}) for an axis label of type {cls.__name__!r}" ) # Input array is an already formed ndarray instance @@ -94,7 +117,7 @@ def __new__( obj.units = units - obj.axis = axis + obj.axis = "|".join(axis) obj.scaling_factor = 1.0 diff --git a/MDANSE/Tests/UnitTests/test_output_variable.py b/MDANSE/Tests/UnitTests/test_output_variable.py new file mode 100644 index 000000000..47a9b9e5f --- /dev/null +++ b/MDANSE/Tests/UnitTests/test_output_variable.py @@ -0,0 +1,42 @@ +from contextlib import nullcontext as success + +import pytest +from MDANSE.Framework.OutputVariables.IOutputVariable import ( + IOutputVariable, OutputVariableError +) + +DIM_FAIL = pytest.raises(OutputVariableError, match="Invalid number of dimensions") + +@pytest.mark.parametrize("var_type, data_size, extras, expected", [ + ("LineOutputVariable", (), {}, DIM_FAIL), + ("LineOutputVariable", (1,), {}, success()), + ("LineOutputVariable", (1,2,), {}, DIM_FAIL), + ("LineOutputVariable", (1,2,3), {}, DIM_FAIL), + ("LineOutputVariable", (-1,), {}, pytest.raises(ValueError, match="negative dimension")), + + ("SurfaceOutputVariable", (), {}, DIM_FAIL), + ("SurfaceOutputVariable", (1,), {}, DIM_FAIL), + ("SurfaceOutputVariable", (1,2,), {}, success()), + ("SurfaceOutputVariable", (1,2,3), {}, DIM_FAIL), + ("SurfaceOutputVariable", (-1,2), {}, pytest.raises(ValueError, match="negative dimension")), + + ("VolumeOutputVariable", (), {}, DIM_FAIL), + ("VolumeOutputVariable", (1,), {}, DIM_FAIL), + ("VolumeOutputVariable", (1,2,), {}, DIM_FAIL), + ("VolumeOutputVariable", (1,2,3), {}, success()), + ("VolumeOutputVariable", (-1,2,3), {}, pytest.raises(ValueError, match="negative dimension")), + + ("LineOutputVariable", (1,), {"axis": "time"}, success()), + ("LineOutputVariable", (1,), {"axis": ("time",)}, success()), + ("LineOutputVariable", (1,), {"axis": "q|omega"}, DIM_FAIL), + ("LineOutputVariable", (1,), {"axis": ("q", "omega")}, DIM_FAIL), + + ("VolumeOutputVariable", (1,2,3), {"axis": "time"}, DIM_FAIL), + ("VolumeOutputVariable", (1,2,3), {"axis": ("time",)}, DIM_FAIL), + ("VolumeOutputVariable", (1,2,3), {"axis": "q|omega|x"}, success()), + ("VolumeOutputVariable", (1,2,3), {"axis": ("baa", "baa", "baa")}, success()), + +]) +def test_ioutput_variables(var_type, data_size, extras, expected): + with expected: + IOutputVariable.create(var_type, data_size, "test", **extras)