diff --git a/benchmarks/nnx_simple_training.py b/benchmarks/nnx_simple_training.py
index 0cb08066f..6c040dee5 100644
--- a/benchmarks/nnx_simple_training.py
+++ b/benchmarks/nnx_simple_training.py
@@ -25,7 +25,9 @@
 from absl import app
 
 FLAGS = flags.FLAGS
-flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in')
+flags.DEFINE_enum(
+  'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
+)
 flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
 flags.DEFINE_integer('batch_size', 32, 'Batch size')
 flags.DEFINE_integer('width', 32, 'Hidden layer size')
@@ -46,6 +48,15 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
   def __call__(self, x):
     return x @ self.w + self.b
 
+class Block(nnx.Module):
+  """``Linear`` layer followed by ``nnx.BatchNorm`` and a ReLU activation."""
+
+  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
+    self.linear = Linear(din, dout, rngs=rngs)
+    self.bn = nnx.BatchNorm(dout, rngs=rngs)
+
+  def __call__(self, x):
+    return nnx.relu(self.bn(self.linear(x)))
 
 class Count(nnx.Variable):
   pass
@@ -54,11 +63,11 @@ class Count(nnx.Variable):
 class MLP(nnx.Module):
   def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
     self.count = Count(jnp.array(0))
-    self.linear_in = Linear(din, dhidden, rngs=rngs)
+    self.linear_in = Block(din, dhidden, rngs=rngs)
     self.intermediates = [
-      Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
+      Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
     ]
-    self.linear_out = Linear(dhidden, dout, rngs=rngs)
+    self.linear_out = Block(dhidden, dout, rngs=rngs)
 
   def __call__(self, x):
     self.count.value += 1
@@ -79,18 +88,14 @@ def main(argv):
 
   print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')
 
-  if mode not in ['nnx', 'jax']:
-    raise ValueError(f'Invalid mode: {mode}')
-
   X = np.linspace(0, 1, 100)[:, None]
   Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)
 
-  model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
-  tx = optax.sgd(1e-3)
-  optimizer = nnx.Optimizer(model, tx)
-  t0 = time()
-
-  if mode == 'nnx':
+  if mode == 'nnx' or mode == 'all':
+    model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
+    tx = optax.sgd(1e-3)
+    optimizer = nnx.Optimizer(model, tx)
+    t0 = time()
 
     @nnx.jit
     def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch):
@@ -115,11 +120,22 @@ def test_step_nnx(model: MLP, batch):
 
       if step % 1000 == 0:
         logs = test_step_nnx(model, (X, Y))
-        print(f"step: {step}, loss: {logs['loss']}")
 
       if step >= total_steps - 1:
         break
-  else:
+
+    print('### NNX ###')
+    print(f"final loss: {logs['loss']}")
+    total_time = time() - t0
+    print('total time:', total_time)
+    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
+    print('times called:', model.count.value)
+
+  if mode == 'jax' or mode == 'all':
+    model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
+    tx = optax.sgd(1e-3)
+    optimizer = nnx.Optimizer(model, tx)
+    t0 = time()
 
     @jax.jit
     def train_step_jax(graphdef, state, batch):
@@ -151,17 +167,18 @@ def test_step_jax(graphdef, state, batch):
 
       if step % 1000 == 0:
         state, logs = test_step_jax(graphdef, state, (X, Y))
-        print(f"step: {step}, loss: {logs['loss']}")
 
       if step >= total_steps - 1:
         break
 
     model, optimizer = nnx.merge(graphdef, state)
 
-  total_time = time() - t0
-  print('total time:', total_time)
-  print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
-  print('times called:', model.count.value)
+    print('### JAX ###')
+    print(f"final loss: {logs['loss']}")
+    total_time = time() - t0
+    print('total time:', total_time)
+    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
+    print('times called:', model.count.value)
 
 
 if __name__ == '__main__':
diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py
index 191a0c195..e5662e104 100644
--- a/flax/nnx/extract.py
+++ b/flax/nnx/extract.py
@@ -254,7 +254,7 @@ class GraphDefState(struct.PyTreeNode):
 
 class NodeStates(struct.PyTreeNode):
   _graphdef: graph.GraphDef[tp.Any] | None
-  states: tuple[graph.GraphState, ...]
+  states: tuple[graph.GraphState | graph.GraphFlatState, ...]
   metadata: tp.Any = struct.field(pytree_node=False)
 
   @property
@@ -264,7 +264,7 @@ def graphdef(self) -> graph.GraphDef[tp.Any]:
     return self._graphdef
 
   @property
-  def state(self) -> graph.GraphState:
+  def state(self) -> graph.GraphState | graph.GraphFlatState:
     if len(self.states) != 1:
       raise ValueError(
         f'Expected exactly one GraphDefState, got {len(self.states)}'
@@ -275,15 +275,19 @@ def state(self) -> graph.GraphState:
   def from_split(
     cls,
     graphdef: graph.GraphDef[tp.Any],
-    state: graph.GraphState,
+    state: graph.GraphState | graph.GraphFlatState,
     /,
-    *states: graph.GraphState,
+    *states: graph.GraphState | graph.GraphFlatState,
     metadata: tp.Any = None,
   ):
     return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata)
 
   @classmethod
-  def from_states(cls, state: graph.GraphState, *states: graph.GraphState):
+  def from_states(
+    cls,
+    state: graph.GraphState | graph.GraphFlatState,
+    *states: graph.GraphState | graph.GraphFlatState,
+  ):
     return cls(_graphdef=None, states=(state, *states), metadata=None)
 
   @classmethod
diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py
index a29999d34..c18a710b3 100644
--- a/flax/nnx/graph.py
+++ b/flax/nnx/graph.py
@@ -30,7 +30,7 @@
   CallableProxy,
   DelayedAccessor,
 )
-from flax.nnx.statelib import State
+from flax.nnx.statelib import FlatState, State
 from flax.nnx import variablelib
 from flax.nnx.variablelib import Variable, VariableState
 from flax.typing import Key, PathParts, is_key_like
@@ -53,6 +53,7 @@
 StateLeaf = VariableState[tp.Any]
 NodeLeaf = Variable[tp.Any]
 GraphState = State[Key, StateLeaf]
+GraphFlatState = FlatState[StateLeaf]
 
 
 def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
@@ -377,7 +378,9 @@ def _apply(
       module = merge(self, state, *states)
       fn = accessor(module)
       out = fn(*args, **kwargs)
-      return out, flatten(module)
+      graphdef, flat_state = flatten(module)
+      state_ = State.from_flat_path(flat_state)
+      return out, (graphdef, state_)
 
     return CallableProxy(_apply, accessor)  # type: ignore
 
@@ -389,7 +392,7 @@ def _apply(
 
 def flatten(
   node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None
-) -> tuple[GraphDef[Node], GraphState]:
+) -> tuple[GraphDef[Node], FlatState[tp.Any]]:
   """Flattens a graph node into a (graphdef, state) pair.
 
   Args:
@@ -402,7 +405,7 @@ def flatten(
     ref_index = RefMap()
   flat_state: list[tuple[PathParts, StateLeaf]] = []
   graphdef = _graph_flatten((), ref_index, flat_state, node)
-  return graphdef, GraphState.from_flat_path(flat_state)
+  return graphdef, FlatState(flat_state)
 
 
 def _graph_flatten(
@@ -811,8 +814,11 @@ def split(
     ctx = (
       current_update_context(self.ctxtag) if self.ctxtag is not None else None
     )
-    graphdef, state = flatten(node, self.ref_index)
-    states = _split_state(state, filters)
+    graphdef, flat_state = flatten(node, self.ref_index)
+    flat_states = _split_state(flat_state, filters)
+    states = tuple(
+      State.from_flat_path(flat_state) for flat_state in flat_states
+    )
     if ctx is not None:
       if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
         index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
@@ -822,6 +828,60 @@ def split(
 
     return graphdef, *states
 
+  @tp.overload
+  def flatten(
+    self, graph_node: A, /
+  ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
+  @tp.overload
+  def flatten(
+    self, graph_node: A, first: filterlib.Filter, /
+  ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
+  @tp.overload
+  def flatten(
+    self,
+    graph_node: A,
+    first: filterlib.Filter,
+    second: filterlib.Filter,
+    /,
+    *filters: filterlib.Filter,
+  ) -> tuple[
+    GraphDef[A],
+    FlatState[VariableState[tp.Any]],
+    tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]],
+  ]: ...
+  def flatten(
+    self, node: A, *filters: filterlib.Filter
+  ) -> tuple[
+    GraphDef[A], tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]]
+  ]:
+    """Flatten ``node`` and partition its state without nesting it.
+
+    Follows the same steps as ``split`` but returns the ``FlatState``
+    partitions directly instead of converting them to nested ``State``s.
+
+    Args:
+      node: the graph node to flatten.
+      *filters: filters used to partition the flat state; when empty a single
+        ``FlatState`` is returned.
+
+    Returns:
+      The ``GraphDef`` followed by one or more ``FlatState`` objects.
+    """
+    ctx = (
+      current_update_context(self.ctxtag) if self.ctxtag is not None else None
+    )
+    graphdef, flat_state = flatten(node, self.ref_index)
+    flat_states = _split_state(flat_state, filters)
+
+    if ctx is not None:
+      if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
+        index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
+        graphdef = dataclasses.replace(
+          graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
+        )
+
+    return graphdef, *flat_states
+
 
 @contextlib.contextmanager
 def split_context(ctxtag: str | None = None):
@@ -874,6 +921,52 @@ def merge(
     )
     return node
 
+  def unflatten(
+    self,
+    graphdef: GraphDef[A],
+    flat_state: GraphFlatState,
+    /,
+    *flat_states: GraphFlatState,
+  ) -> A:
+    """Rebuild a node from a graphdef and one or more flat states.
+
+    The flat states are combined with ``FlatState.merge`` and converted to a
+    nested state before being handed to the module-level ``unflatten``.
+
+    Args:
+      graphdef: the ``GraphDef`` describing the node.
+      flat_state: the first flat state.
+      *flat_states: additional flat states to merge with ``flat_state``.
+
+    Returns:
+      The node produced by ``unflatten``.
+    """
+    ctx = (
+      current_update_context(self.ctxtag) if self.ctxtag is not None else None
+    )
+    if (
+      ctx is not None
+      and isinstance(graphdef, NodeDef)
+      and graphdef.index_mapping is not None
+    ):
+      # outer merge (4), create index_ref_cache
+      assert ctx.ref_index is not None
+      index_ref_cache = compose_mapping_reversed(
+        ctx.ref_index, graphdef.index_mapping
+      )
+    else:
+      # inner merge (2)
+      index_ref_cache = None
+
+    state = FlatState.merge(flat_state, *flat_states).to_nested_state()
+    node = unflatten(
+      graphdef,
+      state,
+      index_ref=self.index_ref,
+      index_ref_cache=index_ref_cache,
+    )
+    return node
+
 
 @contextlib.contextmanager
 def merge_context(ctxtag: str | None = None):
@@ -1001,9 +1081,11 @@ def split(
       filters are passed, a single :class:`State` is returned.
     """
     ref_index: RefMap[tp.Any, Index] = RefMap()
-    graphdef, state = flatten(node, ref_index)
-    states = _split_state(state, filters)
-
+    graphdef, flat_state = flatten(node, ref_index)
+    states = tuple(
+      State.from_flat_path(flat_state)
+      for flat_state in _split_state(flat_state, filters)
+    )
     if self.index_ref is not None and isinstance(graphdef, NodeDef):
       index_to_index = compose_mapping(self.index_ref, ref_index)
       graphdef = dataclasses.replace(
@@ -1195,13 +1277,13 @@ def current_update_context(tag: str) -> UpdateContext:
 # --------------------------------------------------------
 
 def _split_state(
-  state: GraphState,
+  state: FlatState[tp.Any],
   filters: tuple[filterlib.Filter, ...],
-) -> tuple[GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
+) -> tuple[FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]]]:
   if not filters:
     return (state,)
   states = state.split(*filters)
-  if isinstance(states, State):
+  if not isinstance(states, tuple):
     return (states,)
   assert len(states) > 0
   return states  # type: ignore[return-value]
@@ -1292,9 +1374,11 @@ def split(
     ``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
     filters are passed, a single ``State`` is returned.
   """
-  graphdef, state = flatten(node)
-  states = _split_state(state, filters)
-  return graphdef, *states
+  graphdef, flat_state = flatten(node)
+  flat_states = _split_state(flat_state, filters)
+  states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
+  return graphdef, *states  # type: ignore[return-value]
+
 
 def merge(
   graphdef: GraphDef[A],
@@ -1486,6 +1570,7 @@ def state(
     One or more :class:`State` mappings.
   """
   _, state = flatten(node)
+  state = state.to_nested_state()
 
   states: GraphState | tuple[GraphState, ...]
   if len(filters) == 0:
diff --git a/flax/nnx/reprlib.py b/flax/nnx/reprlib.py
index 6ed7660cd..9a36c3865 100644
--- a/flax/nnx/reprlib.py
+++ b/flax/nnx/reprlib.py
@@ -111,6 +111,21 @@ def __nnx_repr__(self):
     for key, value in self.items():
       yield Attr(repr(key), value)
 
+class SequenceReprMixin(tp.Sequence[A], Representable):
+  """Mixin that implements ``__nnx_repr__`` for sequence types.
+
+  Yields an ``Object`` header delimited by ``[`` and ``]``, then one ``Attr``
+  with an empty key for every element of ``self``, in iteration order.
+  """
+
+  def __nnx_repr__(self):
+    # Header: empty type name and value separator, '[' / ']' as delimiters.
+    yield Object(type='', value_sep='', start='[', end=']')
+
+    for value in self:
+      yield Attr('', value)
+
+
 @dataclasses.dataclass(repr=False)
 class PrettyMapping(Representable):
   mapping: tp.Mapping
diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py
index 42a260404..1c1b1b512 100644
--- a/flax/nnx/statelib.py
+++ b/flax/nnx/statelib.py
@@ -54,7 +54,7 @@ def __treescope_repr__(self, path, subtree_renderer):
     # Render as the dictionary itself at the same path.
     return subtree_renderer(children, path=path)
 
-class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.PrettySequence):
+class FlatState(reprlib.SequenceReprMixin[tuple[PathParts, V]]):
   _keys: tuple[PathParts, ...]
   _values: list[V]
 
@@ -83,6 +83,105 @@ def __len__(self) -> int:
   def __iter__(self) -> tp.Iterator[tuple[PathParts, V]]:
     return iter(zip(self._keys, self._values))
 
+  def to_nested_state(self) -> State[PathParts, V]:
+    """Return the nested ``State`` built by ``State.from_flat_path(self)``."""
+    return State.from_flat_path(self)
+
+  @tp.overload
+  def split(self, first: filterlib.Filter, /) -> FlatState[V]: ...
+
+  @tp.overload
+  def split(
+    self,
+    first: filterlib.Filter,
+    second: filterlib.Filter,
+    /,
+    *filters: filterlib.Filter,
+  ) -> tuple[FlatState[V], ...]: ...
+
+  @tp.overload
+  def split(
+    self, /, *filters: filterlib.Filter
+  ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: ...
+
+  def split(  # type: ignore[misc]
+    self, first: filterlib.Filter, /, *filters: filterlib.Filter
+  ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]:
+    """Partition the entries into one ``FlatState`` per filter.
+
+    The filters must be exhaustive: every entry has to be matched by one of
+    them.
+
+    Returns:
+      A single ``FlatState`` when one filter is given, otherwise a tuple with
+      one ``FlatState`` per filter.
+
+    Raises:
+      ValueError: if some entries are not matched by any filter.
+    """
+    filters = (first, *filters)
+    *flat_states_, rest = _split_state(self, *filters)
+
+    if rest:
+      raise ValueError(
+        'Non-exhaustive filters, got a non-empty remainder: '
+        f'{rest}.\nUse `...` to match all remaining elements.'
+      )
+
+    flat_states: FlatState[V] | tuple[FlatState[V], ...]
+    if len(flat_states_) == 1:
+      flat_states = flat_states_[0]
+    else:
+      flat_states = tuple(flat_states_)
+    return flat_states  # type: ignore
+
+  @tp.overload
+  def filter(self, first: filterlib.Filter, /) -> FlatState[V]: ...
+
+  @tp.overload
+  def filter(
+    self,
+    first: filterlib.Filter,
+    second: filterlib.Filter,
+    /,
+    *filters: filterlib.Filter,
+  ) -> tuple[FlatState[V], ...]: ...
+
+  def filter(
+    self,
+    first: filterlib.Filter,
+    /,
+    *filters: filterlib.Filter,
+  ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]:
+    """Select the entries matching each filter; the remainder is discarded.
+
+    Returns:
+      A single ``FlatState`` when one filter is given, otherwise a tuple with
+      one ``FlatState`` per filter.
+    """
+    *flat_states_, _rest = _split_state(self, first, *filters)
+
+    assert len(flat_states_) == len(filters) + 1
+
+    flat_states: FlatState[V] | tuple[FlatState[V], ...]
+    if len(flat_states_) == 1:
+      flat_states = flat_states_[0]
+    else:
+      flat_states = tuple(flat_states_)
+
+    return flat_states  # type: ignore
+
+  @staticmethod
+  def merge(
+    flat_state: tp.Iterable[tuple[PathParts, V]],
+    /,
+    *flat_states: tp.Iterable[tuple[PathParts, V]],
+  ) -> FlatState[V]:
+    """Chain the entries of all inputs, in order, into one ``FlatState``."""
+    flat_states = (flat_state, *flat_states)
+
+    return FlatState(elem for flat_state in flat_states for elem in flat_state)
+
 
 def _flat_state_pytree_flatten(x: FlatState[V]):
   return x._values, x._keys
@@ -291,7 +370,8 @@ def split(  # type: ignore[misc]
       One or more ``States`` equal to the number of filters passed.
     """
     filters = (first, *filters)
-    *states_, rest = _split_state(self.flat_state(), *filters)
+    flat_states = _split_state(self.flat_state(), *filters)
+    *states_, rest = (state.to_nested_state() for state in flat_states)
 
     if rest:
       raise ValueError(
@@ -356,7 +436,8 @@ def filter(
     Returns:
       One or more ``States`` equal to the number of filters passed.
     """
-    *states_, _rest = _split_state(self.flat_state(), first, *filters)
+    flat_states = _split_state(self.flat_state(), first, *filters)
+    *states_, _rest = (state.to_nested_state() for state in flat_states)
 
     assert len(states_) == len(filters) + 1
 
@@ -456,7 +537,7 @@ def _state_unflatten(
 def _split_state(
   flat_state: FlatState[V],
   *filters: filterlib.Filter,
-) -> tuple[State[PathParts, V], ...]:
+) -> tuple[FlatState[V], ...]:
   for i, filter_ in enumerate(filters):
     if filter_ in (..., True) and i != len(filters) - 1:
       remaining_filters = filters[i + 1 :]
@@ -482,7 +563,7 @@ def _split_state(
       # if we didn't break, set leaf to last state
       flat_states[-1].append((path, value))  # type: ignore[index] # mypy is wrong here?
 
-  return tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
+  return tuple(FlatState(flat_state) for flat_state in flat_states)
 
 
 def create_path_filters(state: State):
diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py
index e5ce20f8e..d3420dd43 100644
--- a/flax/nnx/transforms/compilation.py
+++ b/flax/nnx/transforms/compilation.py
@@ -91,9 +91,22 @@ def __hash__(self):
 def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x):
   if isinstance(prefix, StateSharding):
     return extract.NodeStates.from_split(
-      *ctx.split(x, *prefix.filters), metadata=prefix
+      *ctx.flatten(x, *prefix.filters), metadata=prefix
     )
-  return extract.NodeStates.from_split(*ctx.split(x))
+  return extract.NodeStates.from_split(*ctx.flatten(x))
+
+
+def _jit_merge_fn(ctx: graph.MergeContext, path, prefix, leaf) -> tp.Any:
+  """Rebuild a graph node from an ``extract.NodeStates`` leaf.
+
+  The leaf's graphdef and all of its states are passed to ``ctx.unflatten``.
+
+  Raises:
+    ValueError: if ``leaf`` is not an ``extract.NodeStates``.
+  """
+  if not isinstance(leaf, extract.NodeStates):
+    raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}')
+  return ctx.unflatten(leaf.graphdef, *leaf.states)  # type: ignore
 
 
 @dataclasses.dataclass(eq=False)
@@ -107,7 +113,9 @@ def __post_init__(self):
     functools.update_wrapper(self, self.f)
 
   def __call__(self, *pure_args, **pure_kwargs):
-    args, kwargs = extract.from_tree((pure_args, pure_kwargs), ctxtag='jit')
+    args, kwargs = extract.from_tree(
+      (pure_args, pure_kwargs), merge_fn=_jit_merge_fn, ctxtag='jit'
+    )
 
     out = self.f(*args, **kwargs)
 
@@ -346,7 +354,9 @@ def jit_wrapper(*args, **kwargs):
       *pure_args, **pure_kwargs
     )
     _args_out, _kwargs_out, out = extract.from_tree(
-      (pure_args_out, pure_kwargs_out, pure_out), ctxtag='jit'
+      (pure_args_out, pure_kwargs_out, pure_out),
+      merge_fn=_jit_merge_fn,
+      ctxtag='jit',
     )
     return out
 
diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py
index 4752a9b7b..fb3a276e9 100644
--- a/flax/nnx/variablelib.py
+++ b/flax/nnx/variablelib.py
@@ -808,6 +808,7 @@ def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
     if 'on_remove_axis' in self._var_metadata:
       self._var_metadata['on_remove_axis'](self, axis_index, axis_name)
 
+GraphVariableState = VariableState[VariableState[tp.Any]]
 
 def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool):
   metadata = tuple(x.get_metadata().items())
diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py
index a7bbf178c..85c4f2a4c 100644
--- a/tests/nnx/graph_utils_test.py
+++ b/tests/nnx/graph_utils_test.py
@@ -64,7 +64,8 @@ def test_flatten(self):
     g = [a, 3, a, nnx.Param(4)]
 
     refmap = nnx.graph.RefMap()
-    graphdef, state = nnx.graph.flatten(g, ref_index=refmap)
+    graphdef, flat_state = nnx.graph.flatten(g, ref_index=refmap)
+    state = flat_state.to_nested_state()
 
     state[0]['b'].raw_value = 2
     state[3].raw_value = 4
@@ -329,6 +330,7 @@ def f(m: Foo):
     ref_out_idx_out = nnx.graph.RefMap()
     graphdef: nnx.graph.GraphDef[Foo]
     graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out)
+    state = state.to_nested_state()
 
     @partial(jax.jit, static_argnums=(0,))
     def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
@@ -337,6 +339,7 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
       f(m)
       ref_in_idx_in = nnx.graph.RefMap[Any, int]()
       graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in)
+      state = state.to_nested_state()
       idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
       static_out = nnx.graph.Static((graphdef, idx_out_idx_in))
       return state, static_out
@@ -369,6 +372,7 @@ def f(m: Foo):
     ref_out_idx_out = nnx.graph.RefMap[Any, int]()
     graphdef: nnx.graph.GraphDef[Foo]
     graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out)
+    state = state.to_nested_state()
 
     @partial(jax.jit, static_argnums=(0,))
     def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
@@ -377,6 +381,7 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
       f(m)
       ref_in_idx_in = nnx.graph.RefMap[Any, int]()
       graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in)
+      state = state.to_nested_state()
       idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
       static_out = nnx.graph.Static((graphdef, idx_out_idx_in))
       return state, static_out
@@ -406,6 +411,7 @@ def f(m: Foo):
     ref_out_idx_out = nnx.graph.RefMap()
     graphdef: nnx.graph.GraphDef[Foo]
     graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out)
+    state = state.to_nested_state()
 
     @partial(jax.jit, static_argnums=(0,))
     def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
@@ -414,6 +420,7 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
       f(m)
       ref_in_idx_in = nnx.graph.RefMap[Any, int]()
       graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in)
+      state = state.to_nested_state()
       idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
       static_out = nnx.graph.Static((graphdef, idx_out_idx_in))
       return state, static_out
diff --git a/uv.lock b/uv.lock
index e08e2dbf5..fb61c0e0e 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3,13 +3,13 @@ requires-python = ">=3.10"
 resolution-markers = [
     "python_full_version < '3.11' and platform_system == 'Darwin'",
     "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
     "python_full_version == '3.11.*' and platform_system == 'Darwin'",
     "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
     "python_full_version >= '3.12' and platform_system == 'Darwin'",
     "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
 ]
 
 [[package]]
@@ -641,7 +641,7 @@ source = { registry = "https://pypi.org/simple" }
 resolution-markers = [
     "python_full_version < '3.11' and platform_system == 'Darwin'",
     "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/99/bc/cfb52b9e8531526604afe8666185d207e4f0cb9c6d90bc76f62fb8746804/etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350", size = 95695 }
 wheels = [
@@ -676,10 +676,10 @@ source = { registry = "https://pypi.org/simple" }
 resolution-markers = [
     "python_full_version == '3.11.*' and platform_system == 'Darwin'",
     "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
     "python_full_version >= '3.12' and platform_system == 'Darwin'",
     "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/ba/49/d480aeb4fc441d933acce97261bea002234a45fb847599c9a93c31e51b2e/etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379", size = 101506 }
 wheels = [
@@ -1202,7 +1202,7 @@ name = "ipython"
 version = "8.26.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "sys_platform == 'win32'" },
+    { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
     { name = "decorator" },
     { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
     { name = "jedi" },
@@ -1246,7 +1246,7 @@ wheels = [
 
 [[package]]
 name = "jax"
-version = "0.4.37"
+version = "0.4.38"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "jaxlib" },
@@ -1255,14 +1255,14 @@ dependencies = [
     { name = "opt-einsum" },
     { name = "scipy" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/50/30/ad7617a960c86782587540a179cef676962322d1e5411415b1aa24f02ce0/jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b", size = 1915966 }
+sdist = { url = "https://files.pythonhosted.org/packages/fb/e5/c4aa9644bb96b7f6747bd7c9f8cda7665ca5e194fa2542b2dea3ff730701/jax-0.4.38.tar.gz", hash = "sha256:43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8", size = 1930034 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/5f/3f/6c5553baaa7faa3fa8bae8279b1e46cb54c7ce52360139eae53498786ea5/jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e", size = 2221192 },
+    { url = "https://files.pythonhosted.org/packages/22/49/b4418a7a892c0dd64442bbbeef54e1cdfe722dfc5a7bf0d611d3f5f90e99/jax-0.4.38-py3-none-any.whl", hash = "sha256:78987306f7041ea8500d99df1a17c33ed92620c2268c4c3677fb24e06712be64", size = 2236864 },
 ]
 
 [[package]]
 name = "jaxlib"
-version = "0.4.36"
+version = "0.4.38"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "ml-dtypes" },
@@ -1270,26 +1270,26 @@ dependencies = [
     { name = "scipy" },
 ]
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/23/8d/8a44618f3493f29d769b2b40778d24075689cc8697b98e2c43bafbe50edf/jaxlib-0.4.36-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:d69f991833b6dca794767049843462805936c89553b136a8ebb8485334204457", size = 98648230 },
-    { url = "https://files.pythonhosted.org/packages/78/b8/207485eab566dcfbc29bb833714ac1ca47a1665ca605b1ff7d3d5dd2afbe/jaxlib-0.4.36-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:807814c1ba3ec69cffaa93d3f90651c694a9b8a750b43832cc167ed590c821dd", size = 78553787 },
-    { url = "https://files.pythonhosted.org/packages/26/42/3c2b0dc86a17aafd8f46ba0e4388f39f55706ee25f6c463c3dadea7a71e2/jaxlib-0.4.36-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:1bc27d9ae09549d7652eafe1fdb10c21546cd2fd02bb24a49a7e6208b69163b0", size = 84008742 },
-    { url = "https://files.pythonhosted.org/packages/b9/b2/29be712098342df10075fe085c0b39d783a579bd3325fb0d69c22712cf27/jaxlib-0.4.36-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:3379f03a794d6a30b75765d2786f6e31052f364196fcd49aaae292a3c16f12ec", size = 100263041 },
-    { url = "https://files.pythonhosted.org/packages/63/a9/93404a2f1d59647749d4d6dbab7bee9f5a7bfaeb9ade25b7e66c0ca0949a/jaxlib-0.4.36-cp310-cp310-win_amd64.whl", hash = "sha256:63e575ac8a515dee8171dd4a88c460d538bbcc9d959cabc9781e961763678f84", size = 63270658 },
-    { url = "https://files.pythonhosted.org/packages/e4/7d/9394ff39af5c23bb98a241c33742a328df5a43c21d569855ea7e096aaf5e/jaxlib-0.4.36-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:213792db3b876206b45f6a9fbea15e4dd22a9e80be25b03136f20c94784fecfa", size = 98669744 },
-    { url = "https://files.pythonhosted.org/packages/34/5a/9f3c9e5cec23e60f78bb3c3da108a5ef664601862dbc4e84fc4be3654f5d/jaxlib-0.4.36-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6d7a89adf4c9d3cddd20482931dedc7a9e2669e904196a9599d9a605b3d9e552", size = 78574312 },
-    { url = "https://files.pythonhosted.org/packages/ff/5c/bf78ed9b8d0f174a562f6496049a4872e14a3bb3a80de09c4292d04be5f0/jaxlib-0.4.36-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c395fe8cc5bd6558dd2fbce78e24172b6f27762e17628720ae03d693001283f3", size = 84038323 },
-    { url = "https://files.pythonhosted.org/packages/67/af/6a9dd26e8a6bedd4c9fe702059767256b0d9ed18c29a180a4598d5795bb4/jaxlib-0.4.36-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc324c6b1c64fe68400934c653e4e622f12576120dcdb451c3b4ea4dcaba2ae9", size = 100285487 },
-    { url = "https://files.pythonhosted.org/packages/b7/46/31c3a519a94e84c672ca264c4151998e3e3fd11c481d8fa5af5885b91a1e/jaxlib-0.4.36-cp311-cp311-win_amd64.whl", hash = "sha256:c9e0c45a79e63aea65447f82bd0fa21c17b9afe884aa18dd5362b9965abe9d72", size = 63308064 },
-    { url = "https://files.pythonhosted.org/packages/e3/0e/3b4a99c09431ee5820624d4dcf4efa7becd3c83b56ff0f09a078f4c421a2/jaxlib-0.4.36-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:5972aa85f6d771ecc8cc72148c1fa64250ca33cbdf2bf24407cdee8a5299d25d", size = 98718357 },
-    { url = "https://files.pythonhosted.org/packages/d3/46/05e70a1236ec3782333b3e9469f971c9d45af2aa0aebf602acd9d76292eb/jaxlib-0.4.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5597908cd10418c0b42e9af807fc8112036703533cf501a5255a8fbf4011867e", size = 78596060 },
-    { url = "https://files.pythonhosted.org/packages/8e/76/6b969cbf197b8c53c84c2642069722e84a3a260af084a8acbbf90ca444ea/jaxlib-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:fbbabaa287378a78a3cf9cbe4de30a1f6f19a99116feb4bd687ff256415cd442", size = 84053202 },
-    { url = "https://files.pythonhosted.org/packages/fe/f2/7624a304426daa7b135b85caf1b8eccf879e7cb10bc074656ce628309cb0/jaxlib-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:be295abc209c980817db0488f21f1fbc0644f87326522895e2b9b64729106357", size = 100325610 },
-    { url = "https://files.pythonhosted.org/packages/bb/8b/ded8420cd9198eb677869ffd557d9880af5833c7bf39e604e80b56550e09/jaxlib-0.4.36-cp312-cp312-win_amd64.whl", hash = "sha256:d4bbb5d2970628dcd3dabc28a5b97a1125ad3e06a1be822d340fd9f06f7449b3", size = 63338518 },
-    { url = "https://files.pythonhosted.org/packages/5d/22/b72811c61e8b594951d3ee03245cb0932c723ac35e75569005c3c976eec2/jaxlib-0.4.36-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:02df9c0e1323dde01e966c22eb12432905d2d4de8aac7b603cad2083101b0e6b", size = 98719384 },
-    { url = "https://files.pythonhosted.org/packages/f1/66/3f4a97097983914899100db9e5312493fe1d6adc924e47a0e47e15c553f5/jaxlib-0.4.36-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ec980e85983f41999c4dc84137dec70507d958e23d7eefa104da93053d135f", size = 78596150 },
-    { url = "https://files.pythonhosted.org/packages/3a/6f/cf02f56d1532962d8ca77a6548acab8204294b96b5a153ca4a2caf4971fc/jaxlib-0.4.36-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7ce9368515348d869d6c59d9904c3cb3c81f22ff3e9e969eae0e3563fe472080", size = 84055851 },
-    { url = "https://files.pythonhosted.org/packages/28/10/4fc4e9719c065c6455491730011e87fe4b5120a9a008161cc32663feb9ce/jaxlib-0.4.36-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:93f1c502d08e517f842fe7b18428bb086cfd077db0ea9a2418fb21e5b4e06d3d", size = 100325986 },
-    { url = "https://files.pythonhosted.org/packages/ba/28/fece5385e736ef2f1b5bed133f8001f0fc66dd0104707381343e047b341a/jaxlib-0.4.36-cp313-cp313-win_amd64.whl", hash = "sha256:bddf436a243e83ec6bc16bcbb74d15b1960a69318c9ea796fb2109492bc52575", size = 63338694 },
+    { url = "https://files.pythonhosted.org/packages/ee/d4/e6a0881a88b8f17491c2ee271fd77c348b0221d9e2ec92dad23a2c9e41bc/jaxlib-0.4.38-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:55c19b9d3f33a6fc59f644aa5a21fba02639ccdd776cb4a9b5526625f57839ff", size = 99663603 },
+    { url = "https://files.pythonhosted.org/packages/b6/6d/11569ce873f04c82ec22e58d822f4187dccae1d400c0d6dd05ed314d5328/jaxlib-0.4.38-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:30b2f52cb50d74734af2f477c2533a7a583e3bb7b2c8acdeb361ee77d940577a", size = 79475708 },
+    { url = "https://files.pythonhosted.org/packages/72/61/1de2405d13089c83b1ad87ec0266479c9d00080659dae2474892ae356306/jaxlib-0.4.38-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ee19c163a8fdf0839d4c18b88a5fbfb4e731ba7c437416d3e5483e570bb764e4", size = 93219045 },
+    { url = "https://files.pythonhosted.org/packages/9c/24/0829decf233c6af9efe7c53888ae8ac72395e0979869cd9cee487e35dac3/jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:61aeccb9a27c67fdb8450f6357240019cd4511cb9d62a44e4764756d384853ad", size = 101732107 },
+    { url = "https://files.pythonhosted.org/packages/0d/04/120c4caac6151f7297fedf9dd776362aa2d417d3f87bda826050b4da45e8/jaxlib-0.4.38-cp310-cp310-win_amd64.whl", hash = "sha256:d6ab745a89d0fb737a36fe1d8b86659e3fffe6ee8303b20651b26193d5edc0ef", size = 64223924 },
+    { url = "https://files.pythonhosted.org/packages/b0/6a/b9fba73eb5e758e40a514919e096a039d27dc0ab4776a6cc977f5153a55f/jaxlib-0.4.38-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:b67fdeabd6dfed08b7768f3bdffb521160085f8305669bd197beef61d08de08b", size = 99679916 },
+    { url = "https://files.pythonhosted.org/packages/44/2a/3458130d44d44038fd6974e7c43948f68408f685063203b82229b9b72c1a/jaxlib-0.4.38-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb0eaae7369157afecbead50aaf29e73ffddfa77a2335d721bd9794f3c510e4", size = 79488377 },
+    { url = "https://files.pythonhosted.org/packages/94/96/7d9a0b9f35af4727df44b68ade4c6f15163840727d1cb47251b1ea515e30/jaxlib-0.4.38-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:43db58c4c427627296366a56c10318e1f00f503690e17f94bb4344293e1995e0", size = 93241543 },
+    { url = "https://files.pythonhosted.org/packages/a3/2d/68f85037e60c981b37b18b23ace458c677199dea4722ddce541b48ddfc63/jaxlib-0.4.38-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:2751ff7037d6a997d0be0e77cc4be381c5a9f9bb8b314edb755c13a6fd969f45", size = 101751923 },
+    { url = "https://files.pythonhosted.org/packages/cc/24/a9c571c8a189f58e0b54b14d53fc7f5a0a06e4f1d7ab9edcf8d1d91d07e7/jaxlib-0.4.38-cp311-cp311-win_amd64.whl", hash = "sha256:35226968fc9de6873d1571670eac4117f5ed80e955f7a1775204d1044abe16c6", size = 64255189 },
+    { url = "https://files.pythonhosted.org/packages/49/df/08b94c593c0867c7eaa334592807ba74495de4be90580f360db8b96221dc/jaxlib-0.4.38-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:3fefea985f0415816f3bbafd3f03a437050275ef9bac9a72c1314e1644ac57c1", size = 99737849 },
+    { url = "https://files.pythonhosted.org/packages/ab/b1/c9d2a7ba9ebeabb7ac37082f4c466364f475dc7550a79358c0f0aa89fdf2/jaxlib-0.4.38-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f33bcafe32c97a562ecf6894d7c41674c80c0acdedfa5423d49af51147149874", size = 79509242 },
+    { url = "https://files.pythonhosted.org/packages/53/25/dd670d8bdf3799ece76d12cfe6a6a250ea256057aa4b0fcace4753a99d2d/jaxlib-0.4.38-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:496f45b0e001a2341309cd0c74af0b670537dced79c168cb230cfcc773f0aa86", size = 93251503 },
+    { url = "https://files.pythonhosted.org/packages/f9/cc/37fce5162f6b9070203fd76cc0f298d9b3bfdf01939a78935a6078d63621/jaxlib-0.4.38-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:dad6c0a96567c06d083c0469fec40f201210b099365bd698be31a6d2ec88fd59", size = 101792792 },
+    { url = "https://files.pythonhosted.org/packages/6f/7a/8515950a60a4ea5b13cc98fc0a42e36553b2db5a6eedc00d3bd7836f77b5/jaxlib-0.4.38-cp312-cp312-win_amd64.whl", hash = "sha256:966cdec36cfa978f5b4582bcb4147fe511725b94c1a752dac3a5f52ce46b6fa3", size = 64288223 },
+    { url = "https://files.pythonhosted.org/packages/91/03/aee503c7077c6dbbd568842303426c6ec1cef9bff330c418c9e71906cccd/jaxlib-0.4.38-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:41e55ae5818a882e5789e848f6f16687ac132bcfbb5a5fa114a5d18b78d05f2d", size = 99739026 },
+    { url = "https://files.pythonhosted.org/packages/cb/bf/fbbf61da319611d88e11c691d5a2077039208ded05e1731dea940f824a59/jaxlib-0.4.38-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6fe326b8af366387dd47ccf312583b2b17fed12712c9b74a648b18a13cbdbabf", size = 79508735 },
+    { url = "https://files.pythonhosted.org/packages/e4/0b/8cbff0b6d62a4694351c49baf53b7ed8deb8a6854d129408c38158e11676/jaxlib-0.4.38-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:248cca3771ebf24b070f49701364ceada33e6139445b06c782cca5ac5ad92bf4", size = 93251882 },
+    { url = "https://files.pythonhosted.org/packages/15/57/7f0283273b69c417071bcd2f4c2ed076479ec5ffc22a647f13c21da8d071/jaxlib-0.4.38-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:2ce77ba8cda9259a4bca97afc1c722e4291a6c463a63f8d372c6edc85117d625", size = 101791137 },
+    { url = "https://files.pythonhosted.org/packages/de/de/d6c4d234cd426b97459cb070af90792b48643967a0d28641379ee9e10fc9/jaxlib-0.4.38-cp313-cp313-win_amd64.whl", hash = "sha256:4103db0b3a38a5dc132741237453c24d8547290a22079ba1b577d6c88c95300a", size = 64288459 },
 ]
 
 [[package]]
@@ -1431,7 +1431,7 @@ version = "5.7.2"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "platformdirs" },
-    { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" },
+    { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_python_implementation != 'PyPy' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_python_implementation != 'PyPy' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
     { name = "traitlets" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 }
@@ -2095,7 +2095,7 @@ name = "nvidia-cudnn-cu12"
 version = "9.1.0.70"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+    { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
@@ -2122,9 +2122,9 @@ name = "nvidia-cusolver-cu12"
 version = "11.4.5.107"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
-    { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
-    { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+    { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+    { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+    { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
@@ -2135,7 +2135,7 @@ name = "nvidia-cusparse-cu12"
 version = "12.1.0.106"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+    { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
@@ -2262,7 +2262,7 @@ wheels = [
 
 [[package]]
 name = "orbax-checkpoint"
-version = "0.10.2"
+version = "0.10.3"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "absl-py" },
@@ -2280,9 +2280,9 @@ dependencies = [
     { name = "tensorstore" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/d1/06/c42e2f1563dbaaf5ed1464d7b634324fb9a2da04021073c45777e61af78d/orbax_checkpoint-0.10.2.tar.gz", hash = "sha256:e575ebe1f94e5cb6353ab8c9df81de0ca7cddc118645c3bfc17b8344f19d42f1", size = 248170 }
+sdist = { url = "https://files.pythonhosted.org/packages/87/fd/36b22046aecf155e50494fd7901ecd3e97e0db3ac103d3a0ffd0cafd2d9e/orbax_checkpoint-0.10.3.tar.gz", hash = "sha256:71e3ea47e38d571f27146ee55c8727d7e7c242cf3df31dc499f9b2cb1d67ac8a", size = 252556 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/61/19/ed366f8894923f3c8db0370e4bdd57ef843d68011dafa00d8175f4a66e1a/orbax_checkpoint-0.10.2-py3-none-any.whl", hash = "sha256:dcfc425674bd8d4934986143bd22a37cd634d034652c5d30d83c539ef8587941", size = 354306 },
+    { url = "https://files.pythonhosted.org/packages/6d/45/12a80b3704ec7d46fb0f79d193f4a089aa4a8297a61e6db183d97d108a4b/orbax_checkpoint-0.10.3-py3-none-any.whl", hash = "sha256:df7fd5f327dfe9c477533f33c20076ae11ba6a15767c5117881b328dece14c7d", size = 359825 },
 ]
 
 [[package]]
@@ -2436,7 +2436,7 @@ source = { registry = "https://pypi.org/simple" }
 resolution-markers = [
     "python_full_version < '3.11' and platform_system == 'Darwin'",
     "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/55/5b/e3d951e34f8356e5feecacd12a8e3b258a1da6d9a03ad1770f28925f29bc/protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2", size = 216768 }
 wheels = [
@@ -2454,10 +2454,10 @@ source = { registry = "https://pypi.org/simple" }
 resolution-markers = [
     "python_full_version == '3.11.*' and platform_system == 'Darwin'",
     "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
     "python_full_version >= '3.12' and platform_system == 'Darwin'",
     "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/e8/ab/cb61a4b87b2e7e6c312dce33602bd5884797fd054e0e53205f1c27cf0f66/protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d", size = 380283 }
 wheels = [
@@ -2606,7 +2606,7 @@ name = "pytest"
 version = "8.3.2"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "sys_platform == 'win32'" },
+    { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
     { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
     { name = "iniconfig" },
     { name = "packaging" },
@@ -3195,7 +3195,7 @@ source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "alabaster" },
     { name = "babel" },
-    { name = "colorama", marker = "sys_platform == 'win32'" },
+    { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
     { name = "docutils" },
     { name = "imagesize" },
     { name = "jinja2" },
@@ -3684,7 +3684,7 @@ name = "triton"
 version = "3.0.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+    { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },