[nnx] flatten returns FlatState #4458

Closed
57 changes: 37 additions & 20 deletions benchmarks/nnx_simple_training.py
@@ -25,7 +25,9 @@
from absl import app

FLAGS = flags.FLAGS
flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in')
flags.DEFINE_enum(
'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
@@ -46,6 +48,13 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
def __call__(self, x):
return x @ self.w + self.b

class Block(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.linear = Linear(din, dout, rngs=rngs)
self.bn = nnx.BatchNorm(dout, rngs=rngs)

def __call__(self, x):
return nnx.relu(self.bn(self.linear(x)))

class Count(nnx.Variable):
pass
@@ -54,11 +63,11 @@ class Count(nnx.Variable):
class MLP(nnx.Module):
def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
self.count = Count(jnp.array(0))
self.linear_in = Linear(din, dhidden, rngs=rngs)
self.linear_in = Block(din, dhidden, rngs=rngs)
self.intermediates = [
Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
]
self.linear_out = Linear(dhidden, dout, rngs=rngs)
self.linear_out = Block(dhidden, dout, rngs=rngs)

def __call__(self, x):
self.count.value += 1
@@ -79,18 +88,14 @@ def main(argv):

print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')

if mode not in ['nnx', 'jax']:
raise ValueError(f'Invalid mode: {mode}')

X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)

model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

if mode == 'nnx':
if mode == 'nnx' or mode == 'all':
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

@nnx.jit
def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch):
@@ -115,11 +120,22 @@ def test_step_nnx(model: MLP, batch):

if step % 1000 == 0:
logs = test_step_nnx(model, (X, Y))
print(f"step: {step}, loss: {logs['loss']}")

if step >= total_steps - 1:
break
else:

print('### NNX ###')
print(f"final loss: {logs['loss']}")
total_time = time() - t0
print('total time:', total_time)
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
print('times called:', model.count.value)

if mode == 'jax' or mode == 'all':
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

@jax.jit
def train_step_jax(graphdef, state, batch):
@@ -151,17 +167,18 @@ def test_step_jax(graphdef, state, batch):

if step % 1000 == 0:
state, logs = test_step_jax(graphdef, state, (X, Y))
print(f"step: {step}, loss: {logs['loss']}")

if step >= total_steps - 1:
break

model, optimizer = nnx.merge(graphdef, state)

total_time = time() - t0
print('total time:', total_time)
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
print('times called:', model.count.value)
print('### JAX ###')
print(f"final loss: {logs['loss']}")
total_time = time() - t0
print('total time:', total_time)
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
print('times called:', model.count.value)


if __name__ == '__main__':
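
With mode now defaulting to all, a single run executes both the NNX and pure-JAX training loops back to back and prints a ### NNX ### / ### JAX ### summary for each. A hedged example invocation, using only the flags defined in this diff (values are illustrative):

python benchmarks/nnx_simple_training.py --mode=all --total_steps=10000
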
14 changes: 9 additions & 5 deletions flax/nnx/extract.py
@@ -254,7 +254,7 @@ class GraphDefState(struct.PyTreeNode):

class NodeStates(struct.PyTreeNode):
_graphdef: graph.GraphDef[tp.Any] | None
states: tuple[graph.GraphState, ...]
states: tuple[graph.GraphState | graph.GraphFlatState, ...]
metadata: tp.Any = struct.field(pytree_node=False)

@property
@@ -264,7 +264,7 @@ def graphdef(self) -> graph.GraphDef[tp.Any]:
return self._graphdef

@property
def state(self) -> graph.GraphState:
def state(self) -> graph.GraphState | graph.GraphFlatState:
if len(self.states) != 1:
raise ValueError(
f'Expected exactly one GraphDefState, got {len(self.states)}'
@@ -275,15 +275,19 @@ def state(self) -> graph.GraphState:
def from_split(
cls,
graphdef: graph.GraphDef[tp.Any],
state: graph.GraphState,
state: graph.GraphState | graph.GraphFlatState,
/,
*states: graph.GraphState,
*states: graph.GraphState | graph.GraphFlatState,
metadata: tp.Any = None,
):
return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata)

@classmethod
def from_states(cls, state: graph.GraphState, *states: graph.GraphState):
def from_states(
cls,
state: graph.GraphState | graph.GraphFlatState,
*states: graph.GraphState | graph.GraphFlatState,
):
return cls(_graphdef=None, states=(state, *states), metadata=None)

@classmethod
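
The widened annotations above let NodeStates carry either nested State objects or the new FlatState. A minimal sketch of the flat path, assuming the patched modules are importable as flax.nnx.extract and flax.nnx.graph:

from flax import nnx
from flax.nnx import extract, graph

node = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, flat_state = graph.flatten(node)  # FlatState of (path, leaf) pairs
node_states = extract.NodeStates.from_split(graphdef, flat_state)
assert node_states.state is flat_state  # the single-state accessor still applies
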
115 changes: 100 additions & 15 deletions flax/nnx/graph.py
@@ -30,7 +30,7 @@
CallableProxy,
DelayedAccessor,
)
from flax.nnx.statelib import State
from flax.nnx.statelib import FlatState, State
from flax.nnx import variablelib
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key, PathParts, is_key_like
@@ -53,6 +53,7 @@
StateLeaf = VariableState[tp.Any]
NodeLeaf = Variable[tp.Any]
GraphState = State[Key, StateLeaf]
GraphFlatState = FlatState[StateLeaf]


def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
@@ -377,7 +378,9 @@ def _apply(
module = merge(self, state, *states)
fn = accessor(module)
out = fn(*args, **kwargs)
return out, flatten(module)
graphdef, flat_state = flatten(module)
state_ = State.from_flat_path(flat_state)
return out, (graphdef, state_)

return CallableProxy(_apply, accessor) # type: ignore

@@ -389,7 +392,7 @@ def _apply(

def flatten(
node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None
) -> tuple[GraphDef[Node], GraphState]:
) -> tuple[GraphDef[Node], FlatState[tp.Any]]:
"""Flattens a graph node into a (graphdef, state) pair.
Args:
@@ -402,7 +405,7 @@ def flatten(
ref_index = RefMap()
flat_state: list[tuple[PathParts, StateLeaf]] = []
graphdef = _graph_flatten((), ref_index, flat_state, node)
return graphdef, GraphState.from_flat_path(flat_state)
return graphdef, FlatState(flat_state)


def _graph_flatten(
@@ -811,8 +814,11 @@ def split(
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
graphdef, state = flatten(node, self.ref_index)
states = _split_state(state, filters)
graphdef, flat_state = flatten(node, self.ref_index)
flat_states = _split_state(flat_state, filters)
states = tuple(
State.from_flat_path(flat_state) for flat_state in flat_states
)
if ctx is not None:
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
@@ -822,6 +828,47 @@ def split(

return graphdef, *states

@tp.overload
def flatten(
self, graph_node: A, /
) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
@tp.overload
def flatten(
self, graph_node: A, first: filterlib.Filter, /
) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
@tp.overload
def flatten(
self,
graph_node: A,
first: filterlib.Filter,
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[
GraphDef[A],
FlatState[VariableState[tp.Any]],
tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]],
]: ...
def flatten(
self, node: A, *filters: filterlib.Filter
) -> tuple[
GraphDef[A], tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]]
]:
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
graphdef, flat_state = flatten(node, self.ref_index)
flat_states = _split_state(flat_state, filters)

if ctx is not None:
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
)

return graphdef, *flat_states


@contextlib.contextmanager
def split_context(ctxtag: str | None = None):
@@ -874,6 +921,39 @@ def merge(
)
return node

def unflatten(
self,
graphdef: GraphDef[A],
flat_state: GraphFlatState,
/,
*flat_states: GraphFlatState,
) -> A:
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
if (
ctx is not None
and isinstance(graphdef, NodeDef)
and graphdef.index_mapping is not None
):
# outer merge (4), create index_ref_cache
assert ctx.ref_index is not None
index_ref_cache = compose_mapping_reversed(
ctx.ref_index, graphdef.index_mapping
)
else:
# inner merge (2)
index_ref_cache = None

state = FlatState.merge(flat_state, *flat_states).to_nested_state()
node = unflatten(
graphdef,
state,
index_ref=self.index_ref,
index_ref_cache=index_ref_cache,
)
return node


@contextlib.contextmanager
def merge_context(ctxtag: str | None = None):
@@ -1001,9 +1081,11 @@ def split(
filters are passed, a single :class:`State` is returned.
"""
ref_index: RefMap[tp.Any, Index] = RefMap()
graphdef, state = flatten(node, ref_index)
states = _split_state(state, filters)

graphdef, flat_state = flatten(node, ref_index)
states = tuple(
State.from_flat_path(flat_state)
for flat_state in _split_state(flat_state, filters)
)
if self.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(self.index_ref, ref_index)
graphdef = dataclasses.replace(
@@ -1195,13 +1277,13 @@ def current_update_context(tag: str) -> UpdateContext:
# --------------------------------------------------------

def _split_state(
state: GraphState,
state: FlatState[tp.Any],
filters: tuple[filterlib.Filter, ...],
) -> tuple[GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
) -> tuple[FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]]]:
if not filters:
return (state,)
states = state.split(*filters)
if isinstance(states, State):
if not isinstance(states, tuple):
return (states,)
assert len(states) > 0
return states # type: ignore[return-value]
@@ -1292,9 +1374,11 @@ def split(
``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
filters are passed, a single ``State`` is returned.
"""
graphdef, state = flatten(node)
states = _split_state(state, filters)
return graphdef, *states
graphdef, flat_state = flatten(node)
flat_states = _split_state(flat_state, filters)
states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
return graphdef, *states # type: ignore[return-value]


def merge(
graphdef: GraphDef[A],
@@ -1486,6 +1570,7 @@ def state(
One or more :class:`State` mappings.
"""
_, state = flatten(node)
state = state.to_nested_state()

states: GraphState | tuple[GraphState, ...]
if len(filters) == 0:
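
Taken together, these graph.py changes make flatten return a FlatState and convert back to a nested State only where callers expect one. A hedged round-trip sketch, using the public names exercised by this PR's tests and benchmark:

from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, flat_state = nnx.graph.flatten(model)  # now a FlatState, not a nested State
state = flat_state.to_nested_state()  # recover the nested State when needed
model = nnx.merge(graphdef, state)  # merge still consumes a nested State
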
8 changes: 8 additions & 0 deletions flax/nnx/reprlib.py
@@ -111,6 +111,14 @@ def __nnx_repr__(self):
for key, value in self.items():
yield Attr(repr(key), value)

class SequenceReprMixin(tp.Sequence[A], Representable):
def __nnx_repr__(self):
yield Object(type='', value_sep='', start='[', end=']')

for value in self:
yield Attr('', value)


@dataclasses.dataclass(repr=False)
class PrettyMapping(Representable):
mapping: tp.Mapping
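
SequenceReprMixin extends the same __nnx_repr__ protocol to sequence types, rendering them with bracket delimiters the way FlatState now does. A hedged sketch of reusing it for a custom sequence (MyList is illustrative, not part of the PR):

from flax.nnx import reprlib

class MyList(reprlib.SequenceReprMixin[int]):
  def __init__(self, items: list[int]):
    self._items = items

  def __getitem__(self, i):
    return self._items[i]

  def __len__(self):
    return len(self._items)
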
91 changes: 86 additions & 5 deletions flax/nnx/statelib.py
@@ -54,7 +54,7 @@ def __treescope_repr__(self, path, subtree_renderer):
# Render as the dictionary itself at the same path.
return subtree_renderer(children, path=path)

class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.PrettySequence):
class FlatState(reprlib.SequenceReprMixin[tuple[PathParts, V]]):
_keys: tuple[PathParts, ...]
_values: list[V]

@@ -83,6 +83,85 @@ def __len__(self) -> int:
def __iter__(self) -> tp.Iterator[tuple[PathParts, V]]:
return iter(zip(self._keys, self._values))

def to_nested_state(self) -> State[PathParts, V]:
return State.from_flat_path(self)

@tp.overload
def split(self, first: filterlib.Filter, /) -> FlatState[V]: ...

@tp.overload
def split(
self,
first: filterlib.Filter,
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[FlatState[V], ...]: ...

@tp.overload
def split(
self, /, *filters: filterlib.Filter
) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: ...

def split( # type: ignore[misc]
self, first: filterlib.Filter, /, *filters: filterlib.Filter
) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]:
filters = (first, *filters)
*flat_states_, rest = _split_state(self, *filters)

if rest:
raise ValueError(
'Non-exhaustive filters, got a non-empty remainder: '
f'{rest}.\nUse `...` to match all remaining elements.'
)

flat_states: FlatState[V] | tuple[FlatState[V], ...]
if len(flat_states_) == 1:
flat_states = flat_states_[0]
else:
flat_states = tuple(flat_states_)
return flat_states # type: ignore

@tp.overload
def filter(self, first: filterlib.Filter, /) -> FlatState[V]: ...

@tp.overload
def filter(
self,
first: filterlib.Filter,
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[FlatState[V], ...]: ...

def filter(
self,
first: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]:
*flat_states_, _rest = _split_state(self, first, *filters)

assert len(flat_states_) == len(filters) + 1

flat_states: FlatState[V] | tuple[FlatState[V], ...]
if len(flat_states_) == 1:
flat_states = flat_states_[0]
else:
flat_states = tuple(flat_states_)

return flat_states # type: ignore

@staticmethod
def merge(
flat_state: tp.Iterable[tuple[PathParts, V]],
/,
*flat_states: tp.Iterable[tuple[PathParts, V]],
) -> FlatState[V]:
flat_states = (flat_state, *flat_states)

return FlatState(elem for flat_state in flat_states for elem in flat_state)


def _flat_state_pytree_flatten(x: FlatState[V]):
return x._values, x._keys
@@ -291,7 +370,8 @@ def split( # type: ignore[misc]
One or more ``States`` equal to the number of filters passed.
"""
filters = (first, *filters)
*states_, rest = _split_state(self.flat_state(), *filters)
flat_states = _split_state(self.flat_state(), *filters)
*states_, rest = (state.to_nested_state() for state in flat_states)

if rest:
raise ValueError(
@@ -356,7 +436,8 @@ def filter(
Returns:
One or more ``States`` equal to the number of filters passed.
"""
*states_, _rest = _split_state(self.flat_state(), first, *filters)
flat_states = _split_state(self.flat_state(), first, *filters)
*states_, _rest = (state.to_nested_state() for state in flat_states)

assert len(states_) == len(filters) + 1

@@ -456,7 +537,7 @@ def _state_unflatten(
def _split_state(
flat_state: FlatState[V],
*filters: filterlib.Filter,
) -> tuple[State[PathParts, V], ...]:
) -> tuple[FlatState[V], ...]:
for i, filter_ in enumerate(filters):
if filter_ in (..., True) and i != len(filters) - 1:
remaining_filters = filters[i + 1 :]
@@ -482,7 +563,7 @@ def _split_state(
# if we didn't break, set leaf to last state
flat_states[-1].append((path, value)) # type: ignore[index] # mypy is wrong here?

return tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
return tuple(FlatState(flat_state) for flat_state in flat_states)


def create_path_filters(state: State):
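
FlatState now mirrors the split/filter/merge surface of State while operating directly on the flat (path, value) pairs, skipping the nested-dict round trip. A hedged sketch, assuming FlatState is imported from flax.nnx.statelib:

from flax import nnx
from flax.nnx.statelib import FlatState

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, flat_state = nnx.graph.flatten(model)

params, rest = flat_state.split(nnx.Param, ...)  # split is exhaustive; ... matches the remainder
params_only = flat_state.filter(nnx.Param)  # filter silently drops the remainder
merged = FlatState.merge(params, rest)  # recombine into a single FlatState
state = merged.to_nested_state()  # and back to a nested State if needed
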
18 changes: 14 additions & 4 deletions flax/nnx/transforms/compilation.py
@@ -91,9 +91,15 @@ def __hash__(self):
def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x):
if isinstance(prefix, StateSharding):
return extract.NodeStates.from_split(
*ctx.split(x, *prefix.filters), metadata=prefix
*ctx.flatten(x, *prefix.filters), metadata=prefix
)
return extract.NodeStates.from_split(*ctx.split(x))
return extract.NodeStates.from_split(*ctx.flatten(x))


def _jit_merge_fn(ctx: graph.MergeContext, path, prefix, leaf) -> tp.Any:
if not isinstance(leaf, extract.NodeStates):
raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}')
return ctx.unflatten(leaf.graphdef, *leaf.states) # type: ignore


@dataclasses.dataclass(eq=False)
@@ -107,7 +113,9 @@ def __post_init__(self):
functools.update_wrapper(self, self.f)

def __call__(self, *pure_args, **pure_kwargs):
args, kwargs = extract.from_tree((pure_args, pure_kwargs), ctxtag='jit')
args, kwargs = extract.from_tree(
(pure_args, pure_kwargs), merge_fn=_jit_merge_fn, ctxtag='jit'
)

out = self.f(*args, **kwargs)

@@ -346,7 +354,9 @@ def jit_wrapper(*args, **kwargs):
*pure_args, **pure_kwargs
)
_args_out, _kwargs_out, out = extract.from_tree(
(pure_args_out, pure_kwargs_out, pure_out), ctxtag='jit'
(pure_args_out, pure_kwargs_out, pure_out),
merge_fn=_jit_merge_fn,
ctxtag='jit',
)
return out
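
For nnx.jit users nothing changes at the call site; the _jit_split_fn/_jit_merge_fn pair above simply moves state across the JIT boundary as FlatState via ctx.flatten and ctx.unflatten. A hedged sketch of the unchanged caller side:

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

@nnx.jit
def forward(model: nnx.Linear, x):
  return model(x)

y = forward(model, jnp.ones((1, 2)))  # flatten/unflatten run inside the transform
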

1 change: 1 addition & 0 deletions flax/nnx/variablelib.py
@@ -808,6 +808,7 @@ def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
if 'on_remove_axis' in self._var_metadata:
self._var_metadata['on_remove_axis'](self, axis_index, axis_name)

GraphVariableState = VariableState[VariableState[tp.Any]]

def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool):
metadata = tuple(x.get_metadata().items())
9 changes: 8 additions & 1 deletion tests/nnx/graph_utils_test.py
@@ -64,7 +64,8 @@ def test_flatten(self):
g = [a, 3, a, nnx.Param(4)]

refmap = nnx.graph.RefMap()
graphdef, state = nnx.graph.flatten(g, ref_index=refmap)
graphdef, flat_state = nnx.graph.flatten(g, ref_index=refmap)
state = flat_state.to_nested_state()

state[0]['b'].raw_value = 2
state[3].raw_value = 4
@@ -329,6 +330,7 @@ def f(m: Foo):
ref_out_idx_out = nnx.graph.RefMap()
graphdef: nnx.graph.GraphDef[Foo]
graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out)
state = state.to_nested_state()

@partial(jax.jit, static_argnums=(0,))
def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
@@ -337,6 +339,7 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
f(m)
ref_in_idx_in = nnx.graph.RefMap[Any, int]()
graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in)
state = state.to_nested_state()
idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
static_out = nnx.graph.Static((graphdef, idx_out_idx_in))
return state, static_out
@@ -369,6 +372,7 @@ def f(m: Foo):
ref_out_idx_out = nnx.graph.RefMap[Any, int]()
graphdef: nnx.graph.GraphDef[Foo]
graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out)
state = state.to_nested_state()

@partial(jax.jit, static_argnums=(0,))
def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
@@ -377,6 +381,7 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
f(m)
ref_in_idx_in = nnx.graph.RefMap[Any, int]()
graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in)
state = state.to_nested_state()
idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
static_out = nnx.graph.Static((graphdef, idx_out_idx_in))
return state, static_out
@@ -406,6 +411,7 @@ def f(m: Foo):
ref_out_idx_out = nnx.graph.RefMap()
graphdef: nnx.graph.GraphDef[Foo]
graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out)
state = state.to_nested_state()

@partial(jax.jit, static_argnums=(0,))
def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
@@ -414,6 +420,7 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
f(m)
ref_in_idx_in = nnx.graph.RefMap[Any, int]()
graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in)
state = state.to_nested_state()
idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
static_out = nnx.graph.Static((graphdef, idx_out_idx_in))
return state, static_out
92 changes: 46 additions & 46 deletions uv.lock

Large diffs are not rendered by default.