diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index dcfd6d8a1..e5e8f9d67 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -7,6 +7,7 @@ from typing import NewType from typing import Optional from typing import Sequence +from typing import Tuple from typing import Union import numpy as np @@ -167,6 +168,11 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): While the functions in the numpy namespace will work on ``AxesArray`` objects, the documentation must be found in their equivalent names here. + Current array function implementations: + * ``np.concatenate`` + * ``np.reshape`` + * ``np.transpose`` + Parameters: input_array: the data to create the array. axes: A dictionary of axis labels to shape indices. Axes labels must @@ -421,6 +427,25 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): return AxesArray(out, axes=new_axes) +@implements(np.transpose) +def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None): + """Returns an array with axes transposed. + + Args: + a: input array + axes: As the numpy function + """ + out = np.transpose(np.asarray(a), axes) + if axes is None: + axes = range(a.ndim)[::-1] + new_axes = {} + old_reverse = a._ax_map.reverse_map + for new_ind, old_ind in enumerate(axes): + _compat_axes_append(new_axes, old_reverse[old_ind], new_ind) + + return AxesArray(out, new_axes) + + def standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] ) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]: diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 7f19596c2..2e6b127de 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -491,3 +491,17 @@ def test_strip_ellipsis(): result = axes._expand_indexer_ellipsis(key, 1) expected = [1] assert result == expected + + +def test_transpose(): + axes = {"ax_a": 0, "ax_b": [1, 2]} + arr = AxesArray(np.arange(8).reshape(2, 2, 2), axes) + tp = np.transpose(arr, [2, 0, 1]) + result = tp.axes + expected = {"ax_a": 1, "ax_b": [0, 2]} + assert result == expected + assert_array_equal(tp, np.transpose(np.asarray(arr), [2, 0, 1])) + arr = arr[..., 0] + tp = arr.T + expected = {"ax_a": 1, "ax_b": 0} + assert_array_equal(tp, np.asarray(arr).T)