Skip to content

Commit

Permalink
TST: Test _determine_adv_broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jan 4, 2024
1 parent efbfac2 commit c101a5a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
11 changes: 6 additions & 5 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
from typing import Union

import numpy as np
from numpy.typing import NDArray
from sklearn.base import TransformerMixin

HANDLED_FUNCTIONS = {}

AxesWarning = type("AxesWarning", (SyntaxWarning,), {})
BasicIndexer = Union[slice, int, type(Ellipsis), type(None)]
Indexer = BasicIndexer | np.ndarray
OldIndex = NewType("OldIndex", int)
Indexer = BasicIndexer | NDArray
StandardIndexer = Union[slice, int, type(None), NDArray]
OldIndex = NewType("OldIndex", int) # Before moving advanced axes adajent
KeyIndex = NewType("KeyIndex", int)
NewIndex = NewType("NewIndex", int)
# ListOrItem = list[T] | T
PartialReIndexer = tuple[KeyIndex, Optional[OldIndex], str]
CompleteReIndexer = tuple[
list[KeyIndex], Optional[list[OldIndex]], Optional[list[NewIndex]]
Expand Down Expand Up @@ -414,7 +415,7 @@ def concatenate(arrays, axis=0):

def standardize_indexer(
arr: np.ndarray, key: Indexer | Sequence[Indexer]
) -> tuple[tuple[Indexer], tuple[KeyIndex]]:
) -> tuple[tuple[StandardIndexer], tuple[KeyIndex]]:
"""Convert any legal numpy indexer to a "standard" form.
Standard form involves creating an equivalent indexer that is a tuple with
Expand Down Expand Up @@ -490,7 +491,7 @@ def _move_idxs_to_front(li: list, idxs: Sequence) -> None:


def _determine_adv_broadcasting(
key: Indexer | Sequence[Indexer], adv_inds: Sequence[OldIndex]
key: StandardIndexer | Sequence[StandardIndexer], adv_inds: Sequence[OldIndex]
) -> tuple:
"""Calculate the shape and location for the result of advanced indexing"""
adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:]))
Expand Down
14 changes: 14 additions & 0 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,17 @@ def test_squeeze_to_sublist():

with pytest.raises(ValueError, match="Indexes to squeeze"):
axes._squeeze_to_sublist(li, [0, 2])


def test_determine_adv_broadcasting():
indexers = (np.ones(1), np.ones((4, 1)), np.ones(3))
res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [0, 1, 2])
assert res_adj is True
assert res_nd == 2
assert res_start == 0

indexers = (None, np.ones(1), 2, np.ones(3))
res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 3])
assert res_adj is False
assert res_nd == 1
assert res_start == 0

0 comments on commit c101a5a

Please sign in to comment.