diff --git a/jaxparrow/tools/operators.py b/jaxparrow/tools/operators.py index b54b9e7..4b7ceb1 100644 --- a/jaxparrow/tools/operators.py +++ b/jaxparrow/tools/operators.py @@ -51,8 +51,8 @@ def axis0(pad_left): arr = lax.cond( pad_left, - lambda: field.at[1:, :].set(midpoint_values), - lambda: field.at[:-1, :].set(midpoint_values) + lambda: jnp.pad(midpoint_values, pad_width=((1, 0), (0, 0)), mode="edge"), + lambda: jnp.pad(midpoint_values, pad_width=((0, 1), (0, 0)), mode="edge") ) return arr @@ -64,8 +64,8 @@ def axis1(pad_left): arr = lax.cond( pad_left, - lambda: field.at[:, 1:].set(midpoint_values), - lambda: field.at[:, :-1].set(midpoint_values) + lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (1, 0)), mode="edge"), + lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (0, 1)), mode="edge") ) return arr