From 449b41dad492e436888e0b1c41dd9d92784ac54a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 10 Dec 2024 20:20:34 -0800 Subject: [PATCH] Remove references to jax.core.raise_to_shaped As of JAX v0.4.36, `core.raise_to_shaped` is deprecated, and simply returns the input unchanged. PiperOrigin-RevId: 704944384 --- oryx/core/interpreters/harvest.py | 8 ++++---- oryx/core/interpreters/inverse/core.py | 1 - oryx/core/primitive.py | 3 +-- oryx/core/trace_util.py | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/oryx/core/interpreters/harvest.py b/oryx/core/interpreters/harvest.py index 7d4d6d6..231e9af 100644 --- a/oryx/core/interpreters/harvest.py +++ b/oryx/core/interpreters/harvest.py @@ -510,7 +510,7 @@ def handle_sow(self, *values, name, tag, tree, mode): raise ValueError(f'Variable has already been reaped: {name}') avals = tree_util.tree_unflatten( tree, - [jax_core.raise_to_shaped(jax_core.get_aval(v)) for v in values]) + [jax_core.get_aval(v) for v in values]) vals = tree_util.tree_unflatten(tree, values) pred = None if mode == 'cond_clobber': @@ -792,7 +792,7 @@ def _get_harvest_metadata(closed_jaxpr, settings, *args): flat_args, in_tree = tree_util.tree_flatten(args) flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) in_avals = jax_util.safe_map( - lambda a: jax_core.raise_to_shaped(jax_core.get_aval(a)), + lambda a: jax_core.get_aval(a), flat_args) pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) metadata = aux() @@ -841,7 +841,7 @@ def _reap_scan_rule(trace: HarvestTrace, *vals, length, reverse, jaxpr, cond_carry_avals[name] = None if mode == 'cond_clobber': reap_carry_avals[name] = aval - cond_carry_avals[name] = jax_core.raise_to_shaped(jax_core.get_aval(True)) + cond_carry_avals[name] = jax_core.get_aval(True) body_fun = jax_core.jaxpr_as_fun(jaxpr) @@ -929,7 +929,7 @@ def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, ) reap_avals[k] = meta['aval'] if mode == 'cond_clobber': - cond_avals[k] = jax_core.raise_to_shaped(jax_core.get_aval(True)) + cond_avals[k] = jax_core.get_aval(True) cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr) body_fun = jax_core.jaxpr_as_fun(body_jaxpr) diff --git a/oryx/core/interpreters/inverse/core.py b/oryx/core/interpreters/inverse/core.py index eecfcbe..425e372 100644 --- a/oryx/core/interpreters/inverse/core.py +++ b/oryx/core/interpreters/inverse/core.py @@ -141,7 +141,6 @@ def unknown(cls, aval): def new(cls, val): val = np.array(val) aval = jax_core.get_aval(val) - aval = jax_core.raise_to_shaped(aval) ndslice = NDSlice.new(val, np.zeros_like(val)) return InverseAndILDJ(aval, frozenset([ndslice])) diff --git a/oryx/core/primitive.py b/oryx/core/primitive.py index 493dd18..cf1ec20 100644 --- a/oryx/core/primitive.py +++ b/oryx/core/primitive.py @@ -224,8 +224,7 @@ def subcall(self, name): tie_all_p = jax_core.Primitive('tie_all') tie_all_p.multiple_results = True tie_all_p.def_impl(lambda *args: args) -tie_all_p.def_abstract_eval(lambda *args: safe_map( # pylint: disable=g-long-lambda - jax_core.raise_to_shaped, args)) +tie_all_p.def_abstract_eval(lambda *args: args) mlir.register_lowering(tie_all_p, lambda c, *args: args) diff --git a/oryx/core/trace_util.py b/oryx/core/trace_util.py index 2ff61fe..c66db8a 100644 --- a/oryx/core/trace_util.py +++ b/oryx/core/trace_util.py @@ -42,7 +42,7 @@ def get_shaped_aval(x): if hasattr(x, 'dtype') and hasattr(x, 'shape'): return jax_core.ShapedArray( x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)) - return jax_core.raise_to_shaped(jax_core.get_aval(x)) + return jax_core.get_aval(x) def pv_like(x, abstract=True):