graph#

flax.nnx.split(node, *filters)[source]#

Split a graph node into a GraphDef and one or more States. State is a Mapping from strings or integers to Variables, Arrays or nested States. GraphDef contains all the static information needed to reconstruct a Module graph; it is analogous to JAX’s PyTreeDef. split() is used in conjunction with merge() to switch seamlessly between stateful and stateless representations of the graph.

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
>>> node = Foo(nnx.Rngs(0))
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
...
>>> jax.tree.map(jnp.shape, params)
State({
  'batch_norm': {
    'bias': Param(
      value=(2,)
    ),
    'scale': Param(
      value=(2,)
    )
  },
  'linear': {
    'bias': Param(
      value=(3,)
    ),
    'kernel': Param(
      value=(2, 3)
    )
  }
})
>>> jax.tree.map(jnp.shape, batch_stats)
State({
  'batch_norm': {
    'mean': BatchStat(
      value=(2,)
    ),
    'var': BatchStat(
      value=(2,)
    )
  }
})

split() and merge() are primarily used to interact directly with JAX transformations; see Functional API for more information.
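For instance, continuing the Foo example above, a minimal sketch of this pattern could look as follows (the function name update_stats is just for illustration):

>>> @jax.jit
... def update_stats(graphdef, params, batch_stats, x):
...   # rebuild the stateful module inside the transform
...   model = nnx.merge(graphdef, params, batch_stats)
...   y = model.linear(model.batch_norm(x))
...   # split again so the outputs are pytrees that JAX understands
...   _, params, batch_stats = nnx.split(model, nnx.Param, nnx.BatchStat)
...   return y, params, batch_stats
...
>>> y, params, batch_stats = update_stats(graphdef, params, batch_stats, jnp.ones((1, 2)))
>>> assert y.shape == (1, 3)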

Parameters
  • node – graph node to split.

  • *filters – some optional filters to group the state into mutually exclusive substates.

Returns

GraphDef and one or more States equal to the number of filters passed. If no filters are passed, a single State is returned.
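As a quick illustration of the no-filter case (continuing the example above), splitting without filters yields a single State:

>>> graphdef, state = nnx.split(node)  # no filters: a single State with everything
>>> node2 = nnx.merge(graphdef, state)
>>> assert isinstance(node2, Foo)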

flax.nnx.merge(graphdef, state, /, *states, copy=False)[source]#

The inverse of flax.nnx.split().

nnx.merge takes a flax.nnx.GraphDef and one or more flax.nnx.States and creates a new node with the same structure as the original node.

Recall: flax.nnx.split() is used to represent a flax.nnx.Module by: 1) a static nnx.GraphDef that captures its Pythonic static information; and 2) one or more flax.nnx.State objects that capture its flax.nnx.Variables and jax.Arrays in the form of JAX pytrees.

nnx.merge is used in conjunction with nnx.split to switch seamlessly between stateful and stateless representations of the graph.

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
>>> node = Foo(nnx.Rngs(0))
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
...
>>> new_node = nnx.merge(graphdef, params, batch_stats)
>>> assert isinstance(new_node, Foo)
>>> assert isinstance(new_node.batch_norm, nnx.BatchNorm)
>>> assert isinstance(new_node.linear, nnx.Linear)

nnx.split and nnx.merge are primarily used to interact directly with JAX transformations (refer to Functional API for more information).

Parameters
  • graphdef – A GraphDef object.

  • state – A State object.

  • *states – Additional State objects.

  • copy – If True, the Variables in the states are copied before being merged into the new node; otherwise they are shared. Default is False.

Returns

The merged flax.nnx.Module.

flax.nnx.update(node, state, /, *states)[source]#

Update the given graph node with one or more new States in-place.

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

>>> def loss_fn(model, x, y):
...   return jnp.mean((y - model(x))**2)
>>> prev_loss = loss_fn(model, x, y)

>>> grads = nnx.grad(loss_fn)(model, x, y)
>>> new_state = jax.tree.map(lambda p, g: p - 0.1*g, nnx.state(model), grads)
>>> nnx.update(model, new_state)
>>> assert loss_fn(model, x, y) < prev_loss
Parameters
  • node – A graph node to update.

  • state – A State object.

  • *states – Additional State objects.
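When the state has been split into several filtered substates, they can all be passed back in a single call; a small sketch (the variable names are just for illustration):

>>> bn = nnx.BatchNorm(2, rngs=nnx.Rngs(0))
>>> _, params, batch_stats = nnx.split(bn, nnx.Param, nnx.BatchStat)
>>> nnx.update(bn, params, batch_stats)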

flax.nnx.pop(node, *filters)[source]#

Pop one or more Variable types from the graph node.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')
>>> y = model(x)
>>> assert hasattr(model, 'i')

>>> intermediates = nnx.pop(model, nnx.Intermediate)
>>> assert intermediates['i'].value[0].shape == (1, 3)
>>> assert not hasattr(model, 'i')
Parameters
  • node – A graph node object.

  • *filters – One or more Variable objects to filter by.

Returns

The popped State containing the Variable objects that were filtered for.

flax.nnx.state(node, *filters)[source]#

Similar to split() but only returns the States indicated by the filters.

Example usage:

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batch_norm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> # get the learnable parameters from the batch norm and linear layer
>>> params = nnx.state(model, nnx.Param)
>>> # get the batch statistics from the batch norm layer
>>> batch_stats = nnx.state(model, nnx.BatchStat)
>>> # get them separately
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> # get them together
>>> state = nnx.state(model)
Parameters
  • node – A graph node object.

  • *filters – One or more Variable objects to filter by.

Returns

One or more State mappings.

flax.nnx.variables(node, *filters)#

Similar to split() but only returns the States indicated by the filters.

Example usage:

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batch_norm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> # get the learnable parameters from the batch norm and linear layer
>>> params = nnx.state(model, nnx.Param)
>>> # get the batch statistics from the batch norm layer
>>> batch_stats = nnx.state(model, nnx.BatchStat)
>>> # get them separately
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> # get them together
>>> state = nnx.state(model)
Parameters
  • node – A graph node object.

  • *filters – One or more Variable objects to filter by.

Returns

One or more State mappings.

flax.nnx.graphdef(node, /)[source]#

Get the GraphDef of the given graph node.

Example usage:

>>> from flax import nnx

>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> graphdef, _ = nnx.split(model)
>>> assert graphdef == nnx.graphdef(model)
Parameters

node – A graph node object.

Returns

The GraphDef of the Module object.

flax.nnx.iter_graph(node, /)[source]#

Iterates over all nested nodes and leaves of the given graph node, including the current node.

iter_graph creates a generator that yields path and value pairs, where the path is a tuple of strings or integers representing the path to the value from the root. Repeated nodes are visited only once. Leaves include static values.

Example:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Linear(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.din, self.dout = din, dout
...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...
>>> module = Linear(3, 4, rngs=nnx.Rngs(0))
>>> graph = [module, module]
...
>>> for path, value in nnx.iter_graph(graph):
...   print(path, type(value).__name__)
...
(0, '_pytree__nodes') HashableMapping
(0, '_pytree__state') PytreeState
(0, 'b') Param
(0, 'din') int
(0, 'dout') int
(0, 'w') Param
(0,) Linear
() list
flax.nnx.clone(node)[source]#

Create a deep copy of the given graph node.

Example usage:

>>> from flax import nnx

>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> cloned_model = nnx.clone(model)
>>> model.bias.value += 1
>>> assert (model.bias.value != cloned_model.bias.value).all()
Parameters

node – A graph node object.

Returns

A deep copy of the Module object.

flax.nnx.call(graphdef_state, /)[source]#

Calls a method of the underlying graph node defined by a (GraphDef, State) pair.

call takes a (GraphDef, State) pair and creates a proxy object that can be used to call methods on the underlying graph node. When a method is called, the output is returned along with a new (GraphDef, State) pair that represents the updated state of the graph node. call is equivalent to merge() > method > split() but is more convenient to use in pure JAX functions.

Example:

>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> class StatefulLinear(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
...
...   def increment(self):
...     self.count += 1
...
...   def __call__(self, x):
...     self.increment()
...     return x @ self.w + self.b
...
>>> linear = StatefulLinear(3, 2, nnx.Rngs(0))
>>> linear_state = nnx.split(linear)
...
>>> @jax.jit
... def forward(x, linear_state):
...   y, linear_state = nnx.call(linear_state)(x)
...   return y, linear_state
...
>>> x = jnp.ones((1, 3))
>>> y, linear_state = forward(x, linear_state)
>>> y, linear_state = forward(x, linear_state)
...
>>> linear = nnx.merge(*linear_state)
>>> linear.count.value
Array(2, dtype=uint32)

The proxy object returned by call supports indexing and attribute access to reach nested methods. In the example below, indexing is used to call the increment method of the StatefulLinear module stored at the b key of a nodes dictionary.

>>> class StatefulLinear(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
...
...   def increment(self):
...     self.count += 1
...
...   def __call__(self, x):
...     self.increment()
...     return x @ self.w + self.b
...
>>> rngs = nnx.Rngs(0)
>>> nodes = dict(
...   a=StatefulLinear(3, 2, rngs),
...   b=StatefulLinear(2, 1, rngs),
... )
...
>>> node_state = nnx.split(nodes)
>>> # use indexing to select the 'b' module, then attribute access to call increment
>>> _, node_state = nnx.call(node_state)['b'].increment()
...
>>> nodes = nnx.merge(*node_state)
>>> nodes['a'].count.value
Array(0, dtype=uint32)
>>> nodes['b'].count.value
Array(1, dtype=uint32)
flax.nnx.cached_partial(f, *cached_args)#

Create a partial from an NNX transformed function along with some cached input arguments, reducing Python overhead by caching the traversal of NNX graph nodes. This is useful for speeding up functions that are called repeatedly with the same subset of inputs, e.g. a train_step with a model and optimizer:

>>> from flax import nnx
>>> import jax.numpy as jnp
>>> import optax
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param)
...
>>> @nnx.jit
... def train_step(model, optimizer, x, y):
...   def loss_fn(model):
...     return jnp.mean((model(x) - y) ** 2)
...
...   loss, grads = nnx.value_and_grad(loss_fn)(model)
...   optimizer.update(model, grads)
...   return loss
...
>>> cached_train_step = nnx.cached_partial(train_step, model, optimizer)
...
>>> for step in range(total_steps:=2):
...   x, y = jnp.ones((10, 2)), jnp.ones((10, 3))
...   # loss = train_step(model, optimizer, x, y)
...   loss = cached_train_step(x, y)
...   print(f'Step {step}: loss={loss:.3f}')
Step 0: loss=2.669
Step 1: loss=2.660

Note that cached_partial will clone all cached graph nodes to guarantee the validity of the cache; these clones contain references to the same Variable objects, which guarantees that state is propagated correctly back to the original graph nodes. Because of this, the final structure of all graph nodes must be the same after each call to the cached function, otherwise an error will be raised. Temporary mutations are allowed (e.g. the use of Module.sow) as long as they are cleaned up before the function returns (e.g. via nnx.pop).
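As a quick sanity check (a sketch continuing the example above), the shared Variable references mean that updates made through the cached partial are visible on the original model:

>>> kernel_before = model.kernel.value
>>> loss = cached_train_step(jnp.ones((10, 2)), jnp.ones((10, 3)))
>>> assert not jnp.array_equal(model.kernel.value, kernel_before)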

Parameters
  • f – A function to cache.

  • *cached_args – A subset of the input arguments containing the graph nodes to cache.

Returns

A partial function expecting the remaining arguments to the original function.

class flax.nnx.GraphDef(nodes: 'list[NodeDefType[tp.Any]]', attributes: 'list[tuple[Key, AttrType]]', num_leaves: 'int')[source]#
class flax.nnx.UpdateContext(tag, outer_ref_outer_index, outer_index_inner_ref, outer_index_outer_ref, inner_ref_outer_index, static_cache)[source]#

A context manager for handling complex state updates.

flax.nnx.update_context(tag)[source]#

Creates an UpdateContext context manager which can be used to handle more complex state updates beyond what nnx.update can handle, including updates to static properties and graph structure.

UpdateContext exposes a split and merge API with the same signature as nnx.split / nnx.merge but performs some bookkeeping to have the necessary information in order to perfectly update the input objects based on the changes made inside the transform. The UpdateContext must call split and merge a total of 4 times, the first and last calls happen outside the transform and the second and third calls happen inside the transform as shown in the diagram below:

                      idxmap
(2) merge ─────────────────────────────► split (3)
      ▲                                    │
      │               inside               │
      │. . . . . . . . . . . . . . . . . . │ index_mapping
      │               outside              │
      │                                    ▼
(1) split──────────────────────────────► merge (4)
                      refmap

The first call to split (1) creates a refmap which keeps track of the outer references, and the first call to merge (2) creates an idxmap which keeps track of the inner references. The second call to split (3) combines the refmap and idxmap to produce the index_mapping which indicates how the outer references map to the inner references. Finally, the last call to merge (4) uses the index_mapping and the refmap to reconstruct the output of the transform while reusing/updating the inner references. To avoid memory leaks, the idxmap is cleared after (3) and the refmap is cleared after (4), and both are cleared after the context manager exits.

Here is a simple example showing the use of update_context:

>>> from flax import nnx
>>> import jax
...
>>> class Foo(nnx.Module): pass
...
>>> m1 = Foo()
>>> with nnx.update_context('example'):
...   with nnx.split_context('example') as ctx:
...     graphdef, state = ctx.split(m1)
...   @jax.jit
...   def f(graphdef, state):
...     with nnx.merge_context('example', inner=True) as ctx:
...       m2 = ctx.merge(graphdef, state)
...     m2.a = 1
...     m2.ref = m2  # create a reference cycle
...     with nnx.split_context('example') as ctx:
...       return ctx.split(m2)
...   graphdef_out, state_out = f(graphdef, state)
...   with nnx.merge_context('example', inner=False) as ctx:
...     m3 = ctx.merge(graphdef_out, state_out)
...
>>> assert m1 is m3
>>> assert m1.a == 1
>>> assert m1.ref is m1

Note that update_context takes in a tag argument which is used primarily as a safety mechanism to reduce the risk of accidentally using the wrong UpdateContext when using current_update_context() to access the current active context. update_context can also be used as a decorator that creates/activates an UpdateContext for the duration of the function:

>>> from flax import nnx
>>> import jax
...
>>> class Foo(nnx.Module): pass
...
>>> m1 = Foo()
>>> @jax.jit
... def f(graphdef, state):
...   with nnx.merge_context('example', inner=True) as ctx:
...     m2 = ctx.merge(graphdef, state)
...   m2.a = 1     # insert static attribute
...   m2.ref = m2  # create a reference cycle
...   with nnx.split_context('example') as ctx:
...     return ctx.split(m2)
...
>>> @nnx.update_context('example')
... def g(m1):
...   with nnx.split_context('example') as ctx:
...     graphdef, state = ctx.split(m1)
...   graphdef_out, state_out = f(graphdef, state)
...   with nnx.merge_context('example', inner=False) as ctx:
...     return ctx.merge(graphdef_out, state_out)
...
>>> m3 = g(m1)
>>> assert m1 is m3
>>> assert m1.a == 1
>>> assert m1.ref is m1

The context can be accessed using current_update_context().

Parameters

tag – A string tag to identify the context.

flax.nnx.current_update_context(tag)[source]#

Returns the current active UpdateContext for the given tag.
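A minimal sketch (not from the original docstring) of retrieving the active context inside a function decorated with update_context; the 'example' tag is arbitrary:

>>> from flax import nnx
...
>>> @nnx.update_context('example')
... def f():
...   ctx = nnx.current_update_context('example')
...   assert isinstance(ctx, nnx.UpdateContext)
...
>>> f()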

flax.nnx.find_duplicates(node, /, *, only=Ellipsis)[source]#

Finds duplicate nodes or node leaves in the given node.

This function traverses the graph node and collects paths to nodes and leaves that have the same identity. It returns a list of lists, where each inner list contains paths to nodes or leaves that are duplicates.

Example:

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class SharedVariables(nnx.Module):
...   def __init__(self):
...     self.a = nnx.Param(jnp.array(1.0))
...     self.b = nnx.Param(jnp.array(2.0))
...     self.c = self.b  # shared Variable
...
>>> model = SharedVariables()
>>> duplicates = nnx.find_duplicates(model)
>>> len(duplicates)
1
>>> for path in duplicates[0]:
...   print(path)
('b',)
('c',)

find_duplicates will also find duplicate nodes, such as Modules that are referenced multiple times in the graph:

>>> class SharedModules(nnx.Module):
...   def __init__(self, rngs: nnx.Rngs):
...     self.a = nnx.Linear(1, 1, rngs=rngs)
...     self.b = nnx.Linear(1, 1, rngs=rngs)
...     self.c = self.a  # shared Module
...
>>> model = SharedModules(nnx.Rngs(0))
>>> for duplicate_paths in nnx.find_duplicates(model):
...   print(duplicate_paths)
[('a',), ('c',)]
Parameters
  • node – A graph node object.

  • only – A Filter to specify which nodes or leaves to consider for duplicates.

Returns

A list of lists, where each inner list contains the different paths for a duplicate node or leaf.

flax.nnx.pure(tree)[source]#

Returns a new tree with all Variable objects replaced with inner values.

This can be used to remove Variable metadata when it is not needed for tasks like serialization or exporting.

Example:

>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> graphdef, state = nnx.split(model)
>>> jax.tree.map(jnp.shape, state)
State({
  'bias': Param(
    value=(3,)
  ),
  'kernel': Param(
    value=(2, 3)
  )
})
>>> pure_state = nnx.pure(state)
>>> jax.tree.map(jnp.shape, pure_state)
State({
  'bias': (3,),
  'kernel': (2, 3)
})
Parameters

tree – A pytree potentially containing Variable objects.

Returns

A new pytree with all Variable objects replaced with their inner values.

flax.nnx.to_refs(node, /, only=<function _array_like>)[source]#

Converts a structure of arrays to array refs.

Example:

>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> node = [jnp.array(1.0), nnx.array_ref(jnp.array(2.0))]
>>> mutable_node = nnx.to_refs(node)
>>> assert nnx.is_array_ref(mutable_node[0])
>>> assert nnx.is_array_ref(mutable_node[1])

If the structure contains duplicate arrays a ValueError is raised:

>>> shared_array = jnp.array(1.0)
>>> node = [shared_array, shared_array]
>>> try:
...   nnx.to_refs(node)
... except ValueError as e:
...   print(e)
Found duplicate at paths:
  ---
  0
  1
  ---

only is a Filter that can be used to specify which arrays to convert to array refs.

>>> node = [jnp.array(1.0), jnp.array(2.0)]
>>> mutable_node = nnx.to_refs(node, only=lambda path, x: path[0] == 0)
...
>>> assert isinstance(mutable_node[0], nnx.ArrayRef)
>>> assert isinstance(mutable_node[1], jax.Array)
Parameters
  • node – A structure potentially containing arrays.

  • only – A Filter to specify which arrays to convert to array refs.

Returns

A structure with the array refs.

flax.nnx.to_arrays(node, /, *, only=<function _mutable_like>, allow_duplicates=False)[source]#

Converts a structure of array refs to regular arrays.

Example:

>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> node = [nnx.array_ref(jnp.array(1.0)), jnp.array(2.0)]
>>> assert nnx.is_array_ref(node[0])
...
>>> frozen_node = nnx.to_arrays(node)
>>> assert isinstance(frozen_node[0], jax.Array)

If the structure contains duplicate array refs, a ValueError is raised:

>>> shared_array = nnx.array_ref(jnp.array(1.0))
>>> node = [shared_array, shared_array]
>>> try:
...   nnx.to_arrays(node)
... except ValueError as e:
...   print(e)
Found duplicate at paths:
  ---
  0
  1
  ---

only is a Filter that can be used to specify which array refs to freeze:

>>> node = [nnx.array_ref(jnp.array(1.0)), nnx.array_ref(jnp.array(2.0))]
>>> frozen_node = nnx.to_arrays(node, only=lambda path, x: path[0] == 0)
...
>>> assert isinstance(frozen_node[0], jax.Array)
>>> assert isinstance(frozen_node[1], nnx.ArrayRef)
Parameters
  • node – A structure potentially containing array refs.

  • only – A Filter to specify which array refs to freeze.

Returns

A structure with the frozen arrays.

flax.nnx.flatten(node, /, *, with_paths=True, ref_index=None, ref_outer_index=None)[source]#

Flattens a graph node into a (graphdef, state) pair.

Parameters
  • node – A graph node.

  • ref_index – A mapping from nodes to indexes, defaults to None. If not provided, a new empty dictionary is created. This argument can be used to flatten a sequence of graph nodes that share references.

  • with_paths – A boolean that indicates whether to return a FlatState object that includes the paths, or just a list of the Variables’ inner values.
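A short sketch of flatten (variable names chosen for illustration); with with_paths=False only the leaf values are returned:

>>> from flax import nnx
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> graphdef, flat_state = nnx.flatten(model)
>>> graphdef_no_paths, leaves = nnx.flatten(model, with_paths=False)
>>> assert len(leaves) == 2  # kernel and bias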

flax.nnx.unflatten(graphdef, state, /, *, index_ref=None, outer_index_outer_ref=None, copy_variables=False)[source]#

Unflattens a graphdef into a node with the given state.

Parameters
  • graphdef – A GraphDef instance.

  • state – A State instance.

  • index_ref – A mapping from indexes to node references found during the graph traversal, defaults to None. If not provided, a new empty dictionary is created. This argument can be used to unflatten a sequence of (graphdef, state) pairs that share the same index space.

  • index_ref_cache – A mapping from indexes to existing nodes that can be reused. When a reference is reused, GraphNodeImpl.clear is called to leave the object in an empty state, which is then filled by the unflatten process; as a result, existing graph nodes are mutated to have the new content/topology specified by the graphdef.

  • copy_variables – If True, Variables in the state will be copied onto the new structure; otherwise they will be shared. Default is False.
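A round-trip sketch (reusing a small Linear, names chosen for illustration) showing unflatten as the counterpart of flatten; with copy_variables=False (the default) the Variables of the new node are shared with the originals:

>>> from flax import nnx
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> graphdef, state = nnx.flatten(model)
>>> new_model = nnx.unflatten(graphdef, state)
>>> assert isinstance(new_model, nnx.Linear)
>>> # with copy_variables=False the Variable objects are shared, not copied
>>> assert new_model.kernel is model.kernel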