Skip to content

Commit

Permalink
Rework IOutput and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
oerc0122 committed Mar 4, 2025
1 parent 7e5816b commit bb816e6
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 26 deletions.
75 changes: 49 additions & 26 deletions MDANSE/Src/MDANSE/Framework/OutputVariables/IOutputVariable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
#

import collections
from collections.abc import Sequence
from typing import Union, Tuple

import numpy as np
import numpy.typing as npt

from MDANSE.Framework.Formats.IFormat import IFormat
from MDANSE.Core.Error import Error

from MDANSE.Core.SubclassFactory import SubclassFactory
from MDANSE.Framework.Formats.IFormat import IFormat


class OutputVariableError(Error):
Expand Down Expand Up @@ -50,39 +52,60 @@ class IOutputVariable(np.ndarray, metaclass=SubclassFactory):
Those extra attributes will be contain information necessary for the the MDANSE plotter.
"""

DEFAULT_AXES = ("x", "y", "z")

def __new__(
cls,
value,
varname,
axis="index",
units="unitless",
main_result=False,
partial_result=False,
value: Union[Tuple[int, ...], npt.ArrayLike],
varname: str,
axis: Union[str, Sequence[str], None] = None,
units: str = "unitless",
*,
main_result: bool = False,
partial_result: bool = False,
):
"""Instantiate a new MDANSE output variable.
Parameters
----------
value : Union[Tuple[int, ...], npt.ArrayLike]
If value is tuple create empty array with those dimensions.
If value is ArrayLike interpret as array data.
varname : str
Variable name for reference.
axis : Union[str, Sequence[str], None]
List of axis labels. If str split on ``|``.
units : str
Units of main data.
main_result : bool
Whether the data are the main result of a calculation.
partial_result : bool
Whether the data are a complete calculation.
Raises
------
OutputVariableError
If dimensions of provided data do not align with those of object.
"""
Instantiate a new MDANSE output variable.
@param cls: the class to instantiate.
@type cls: an OutputVariable object
@param varname: the name of the output variable.
@type varname: string
@param value: the input numpy array.
@type value: numpy array
@note: This is the standard implementation for subclassing a numpy array.
Please look at http://docs.scipy.org/doc/numpy/user/basics.subclassing.html for more information.
"""

if isinstance(value, tuple):
value = np.zeros(value, dtype=np.float64)
value = np.empty(value, dtype=np.float64)
else:
value = np.array(list(value), dtype=np.float64)

if isinstance(axis, str):
axis = tuple(axis.split("|"))
elif axis is None:
axis = cls.DEFAULT_AXES[: value.ndim]

if value.ndim != cls._nDimensions:
raise OutputVariableError(
f"Invalid number of dimensions for an output variable of type {cls.name!r}"
f"Invalid number of dimensions ({value.ndim}) for an output variable of type {cls.__name__!r}"
)

if len(axis) != cls._nDimensions:
raise OutputVariableError(
f"Invalid number of dimensions ({len(axis)}) for an axis label of type {cls.__name__!r}"
)

# Input array is an already formed ndarray instance
Expand All @@ -94,7 +117,7 @@ def __new__(

obj.units = units

obj.axis = axis
obj.axis = "|".join(axis)

obj.scaling_factor = 1.0

Expand Down
42 changes: 42 additions & 0 deletions MDANSE/Tests/UnitTests/test_output_variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from contextlib import nullcontext as success

import pytest
from MDANSE.Framework.OutputVariables.IOutputVariable import (
IOutputVariable, OutputVariableError
)

DIM_FAIL = pytest.raises(OutputVariableError, match="Invalid number of dimensions")

@pytest.mark.parametrize("var_type, data_size, extras, expected", [
("LineOutputVariable", (), {}, DIM_FAIL),
("LineOutputVariable", (1,), {}, success()),
("LineOutputVariable", (1,2,), {}, DIM_FAIL),
("LineOutputVariable", (1,2,3), {}, DIM_FAIL),
("LineOutputVariable", (-1,), {}, pytest.raises(ValueError, match="negative dimension")),
("SurfaceOutputVariable", (), {}, DIM_FAIL),
("SurfaceOutputVariable", (1,), {}, DIM_FAIL),
("SurfaceOutputVariable", (1,2,), {}, success()),
("SurfaceOutputVariable", (1,2,3), {}, DIM_FAIL),
("SurfaceOutputVariable", (-1,2), {}, pytest.raises(ValueError, match="negative dimension")),
("VolumeOutputVariable", (), {}, DIM_FAIL),
("VolumeOutputVariable", (1,), {}, DIM_FAIL),
("VolumeOutputVariable", (1,2,), {}, DIM_FAIL),
("VolumeOutputVariable", (1,2,3), {}, success()),
("VolumeOutputVariable", (-1,2,3), {}, pytest.raises(ValueError, match="negative dimension")),
("LineOutputVariable", (1,), {"axis": "time"}, success()),
("LineOutputVariable", (1,), {"axis": ("time",)}, success()),
("LineOutputVariable", (1,), {"axis": "q|omega"}, DIM_FAIL),
("LineOutputVariable", (1,), {"axis": ("q", "omega")}, DIM_FAIL),
("VolumeOutputVariable", (1,2,3), {"axis": "time"}, DIM_FAIL),
("VolumeOutputVariable", (1,2,3), {"axis": ("time",)}, DIM_FAIL),
("VolumeOutputVariable", (1,2,3), {"axis": "q|omega|x"}, success()),
("VolumeOutputVariable", (1,2,3), {"axis": ("baa", "baa", "baa")}, success()),
])
def test_ioutput_variables(var_type, data_size, extras, expected):
with expected:
IOutputVariable.create(var_type, data_size, "test", **extras)

0 comments on commit bb816e6

Please sign in to comment.