rnglib

class flax.nnx.Rngs(self, default=None, **rngs)[source]#

A small abstraction to manage RNG state.

Rngs allows the creation of RngStream objects, which are used to easily generate new unique random keys on demand. An RngStream is a wrapper around a JAX random key and a counter. Every time a key is requested, the counter is incremented and the key is generated from the seed key and the counter by using jax.random.fold_in.

To create an Rngs object, pass in an integer or a jax.random.key to the constructor as a keyword argument with the name of the stream. The key will be used as the starting seed for the stream, and the counter will be initialized to zero. Then call the stream to get a key:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> rngs = nnx.Rngs(params=0, dropout=1)

>>> param_key1 = rngs.params()
>>> param_key2 = rngs.params()
>>> dropout_key1 = rngs.dropout()
>>> dropout_key2 = rngs.dropout()
...
>>> assert param_key1 != dropout_key1

Trying to generate a key for a stream that was not specified during construction will result in an error being raised:

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> try:
...   key = rngs.unknown_stream()
... except AttributeError as e:
...   print(e)
No RngStream named 'unknown_stream' found in Rngs.

The default stream can be created by passing in a key to the constructor without specifying a stream name. When the default stream is set, the rngs object can be called directly to get a key, and calling streams that were not specified during construction will fall back to the default stream:

>>> rngs = nnx.Rngs(0, params=1)
...
>>> key1 = rngs.default()       # uses 'default'
>>> key2 = rngs()               # uses 'default'
>>> key3 = rngs.params()        # uses 'params'
>>> key4 = rngs.dropout()       # uses 'default'
>>> key5 = rngs.unknown_stream() # uses 'default'
__init__(default=None, **rngs)[source]#
Parameters
  • default – the starting seed for the default stream, defaults to None.

  • **rngs – keyword arguments specifying the starting seed for each stream. The key can be an integer or a jax.random.key.

class flax.nnx.RngStream(self, key, *, tag)[source]#
flax.nnx.split_rngs(node=<flax.typing.Missing object>, /, *, splits, only=Ellipsis, squeeze=False)[source]#

Splits the (nested) Rng states of the given node.

Parameters
  • node – the base node containing the rng states to split.

  • splits – an integer or tuple of integers specifying the shape of the split rng keys.

  • only – a Filter selecting which rng states to split.

Returns

A SplitBackups iterable if node is provided, otherwise a decorator that splits the rng states of the inputs to the decorated function.

Example:

>>> from flax import nnx
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=5)
>>> rngs.params.key.shape, rngs.dropout.key.shape
((5,), (5,))

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=(2, 5))
>>> rngs.params.key.shape, rngs.dropout.key.shape
((2, 5), (2, 5))


>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=5, only='params')
>>> rngs.params.key.shape, rngs.dropout.key.shape
((5,), ())

Once split, random state can be used with transforms like nnx.vmap():

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=5, only='params')
...
>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
...
>>> @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
... def create_model(rngs):
...   return Model(rngs)
...
>>> model = create_model(rngs)
>>> model.dropout.rngs.key.shape
()

split_rngs returns a SplitBackups object that can be used to restore the original unsplit rng states using nnx.restore_rngs(). This is useful when you only want to split the rng states temporarily:

>>> rngs = nnx.Rngs(params=0, dropout=1)
...
>>> backups = nnx.split_rngs(rngs, splits=5, only='params')
>>> model = create_model(rngs)
>>> nnx.restore_rngs(backups)
...
>>> model.dropout.rngs.key.shape
()

SplitBackups can also be used as a context manager to automatically restore the rng states when exiting the context:

>>> rngs = nnx.Rngs(params=0, dropout=1)
...
>>> with nnx.split_rngs(rngs, splits=5, only='params'):
...   model = create_model(rngs)
...
>>> model.dropout.rngs.key.shape
()

>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
...
>>> @nnx.split_rngs(splits=5, only='params')
... @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
... def create_model(rngs):
...   return Model(rngs)
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> model = create_model(rngs)
>>> model.dropout.rngs.key.shape
()
flax.nnx.fork_rngs(node=<flax.typing.Missing object>, /, *, split=None)[source]#

Forks the (nested) Rng states of the given node.

Parameters
  • node – the base node containing the rng states to fork.

  • split – an integer, tuple of integers, or mapping specifying the shape of the forked rng keys. If a mapping, keys are filters selecting which rng states to fork with the corresponding split shape.

Returns

A SplitBackups iterable if node is provided, otherwise a decorator that forks the rng states of the inputs to the decorated function.

Example:

>>> from flax import nnx
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split=5)
>>> rngs.params.key.shape, rngs.dropout.key.shape
((5,), (5,))

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split=(2, 5))
>>> rngs.params.key.shape, rngs.dropout.key.shape
((2, 5), (2, 5))


>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split={'params': 5})
>>> rngs.params.key.shape, rngs.dropout.key.shape
((5,), ())

Once forked, random state can be used with transforms like nnx.vmap():

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split={'params': 5})
...
>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
...
>>> @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
... def create_model(rngs):
...   return Model(rngs)
...
>>> model = create_model(rngs)
>>> model.dropout.rngs.key.shape
()

fork_rngs returns a SplitBackups object that can be used to restore the original unforked rng states using nnx.restore_rngs(). This is useful when you only want to fork the rng states temporarily:

>>> rngs = nnx.Rngs(params=0, dropout=1)
...
>>> backups = nnx.fork_rngs(rngs, split={'params': 5})
>>> model = create_model(rngs)
>>> nnx.restore_rngs(backups)
...
>>> model.dropout.rngs.key.shape
()

SplitBackups can also be used as a context manager to automatically restore the rng states when exiting the context:

>>> rngs = nnx.Rngs(params=0, dropout=1)
...
>>> with nnx.fork_rngs(rngs, split={'params': 5}):
...   model = create_model(rngs)
...
>>> model.dropout.rngs.key.shape
()

>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
...
>>> @nnx.fork_rngs(split={'params': 5})
... @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
... def create_model(rngs):
...   return Model(rngs)
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> model = create_model(rngs)
>>> model.dropout.rngs.key.shape
()
flax.nnx.reseed(node, /, *, policy='scalars_only', **stream_keys)[source]#

Update the keys of the specified RNG streams with new keys.

Parameters
  • node – the node to reseed the RNG streams in.

  • policy – defines how the new scalar key for each RngStream is used to reseed the stream. If 'scalars_only' is given (the default), an error is raised if the target stream key is not a scalar. If 'match_shape' is given, the new scalar key is split to match the shape of the target stream key. A callable of the form (path, scalar_key, target_shape) -> new_key can be passed to define a custom reseeding policy.

  • **stream_keys – keyword arguments mapping stream names to new keys. The keys can be either integers or a jax.random.key.

Example:

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))
...
>>> model = Model(nnx.Rngs(params=0, dropout=42))
>>> x = jnp.ones((1, 2))
...
>>> y1 = model(x)
...
>>> # reset the ``dropout`` stream key to 42
>>> nnx.reseed(model, dropout=42)
>>> y2 = model(x)
...
>>> jnp.allclose(y1, y2)
Array(True, dtype=bool)