From a331bcc2d83946e86718581e71b171de32902848 Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Fri, 6 Sep 2024 11:20:00 +0200 Subject: [PATCH] some comments cleaning --- jaxparrow/tools/operators.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jaxparrow/tools/operators.py b/jaxparrow/tools/operators.py index b54b9e7..4b7ceb1 100644 --- a/jaxparrow/tools/operators.py +++ b/jaxparrow/tools/operators.py @@ -51,8 +51,8 @@ def axis0(pad_left): arr = lax.cond( pad_left, - lambda: field.at[1:, :].set(midpoint_values), - lambda: field.at[:-1, :].set(midpoint_values) + lambda: jnp.pad(midpoint_values, pad_width=((1, 0), (0, 0)), mode="edge"), + lambda: jnp.pad(midpoint_values, pad_width=((0, 1), (0, 0)), mode="edge") ) return arr @@ -64,8 +64,8 @@ def axis1(pad_left): arr = lax.cond( pad_left, - lambda: field.at[:, 1:].set(midpoint_values), - lambda: field.at[:, :-1].set(midpoint_values) + lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (1, 0)), mode="edge"), + lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (0, 1)), mode="edge") ) return arr