rnglib#
- class flax.nnx.Rngs(self, default=None, **rngs)[source]#
A small abstraction to manage RNG state.
Rngs allows the creation of RngStreams, which are used to easily generate new unique random keys on demand. An RngStream is a wrapper around a JAX random key and a counter. Every time a key is requested, the counter is incremented and the key is generated from the seed key and the counter by using jax.random.fold_in.

To create an Rngs, pass in an integer or jax.random.key to the constructor as a keyword argument with the name of the stream. The key will be used as the starting seed for the stream, and the counter will be initialized to zero. Then call the stream to get a key:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> param_key1 = rngs.params()
>>> param_key2 = rngs.params()
>>> dropout_key1 = rngs.dropout()
>>> dropout_key2 = rngs.dropout()
...
>>> assert param_key1 != dropout_key1
Trying to generate a key for a stream that was not specified during construction will result in an error being raised:
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> try:
...   key = rngs.unknown_stream()
... except AttributeError as e:
...   print(e)
No RngStream named 'unknown_stream' found in Rngs.
The default stream can be created by passing in a key to the constructor without specifying a stream name. When the default stream is set, the rngs object can be called directly to get a key, and calling streams that were not specified during construction will fall back to default:

>>> rngs = nnx.Rngs(0, params=1)
...
>>> key1 = rngs.default()  # uses 'default'
>>> key2 = rngs()  # uses 'default'
>>> key3 = rngs.params()  # uses 'params'
>>> key4 = rngs.dropout()  # uses 'default'
>>> key5 = rngs.unknown_stream()  # uses 'default'
- flax.nnx.split_rngs(node=<flax.typing.Missing object>, /, *, splits, only=Ellipsis, squeeze=False)[source]#
Splits the (nested) Rng states of the given node.
- Parameters
node – the base node containing the rng states to split.
splits – an integer or tuple of integers specifying the shape of the split rng keys.
only – a Filter selecting which rng states to split.
- Returns
A SplitBackups iterable if
node
is provided, otherwise a decorator that splits the rng states of the inputs to the decorated function.
Example:
>>> from flax import nnx
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=5)
>>> rngs.params.key.shape, rngs.dropout.key.shape
((5,), (5,))

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=(2, 5))
>>> rngs.params.key.shape, rngs.dropout.key.shape
((2, 5), (2, 5))

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=5, only='params')
>>> rngs.params.key.shape, rngs.dropout.key.shape
((5,), ())
Once split, random state can be used with transforms like nnx.vmap():

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=5, only='params')
...
>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
...
>>> @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
... def create_model(rngs):
...   return Model(rngs)
...
>>> model = create_model(rngs)
>>> model.dropout.rngs.key.shape
()
split_rngs returns a SplitBackups object that can be used to restore the original unsplit rng states using nnx.restore_rngs(). This is useful when you only want to split the rng states temporarily:

>>> rngs = nnx.Rngs(params=0, dropout=1)
...
>>> backups = nnx.split_rngs(rngs, splits=5, only='params')
>>> model = create_model(rngs)
>>> nnx.restore_rngs(backups)
...
>>> model.dropout.rngs.key.shape
()
SplitBackups can also be used as a context manager to automatically restore the rng states when exiting the context:
>>> rngs = nnx.Rngs(params=0, dropout=1)
...
>>> with nnx.split_rngs(rngs, splits=5, only='params'):
...   model = create_model(rngs)
...
>>> model.dropout.rngs.key.shape
()

When no node is passed, split_rngs returns a decorator that splits the rng states of the inputs to the decorated function:

>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
...
>>> @nnx.split_rngs(splits=5, only='params')
... @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
... def create_model(rngs):
...   return Model(rngs)
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> model = create_model(rngs)
>>> model.dropout.rngs.key.shape
()
- flax.nnx.fork_rngs(node=<flax.typing.Missing object>, /, *, split=None)[source]#
Forks the (nested) Rng states of the given node.
- Parameters
node – the base node containing the rng states to fork.
split – an integer, tuple of integers, or mapping specifying the shape of the forked rng keys. If a mapping, keys are filters selecting which rng states to fork with the corresponding split shape.
- Returns
A SplitBackups iterable if
node
is provided, otherwise a decorator that forks the rng states of the inputs to the decorated function.
Example:
>>> from flax import nnx
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split=5)
>>> rngs.params.key.shape, rngs.dropout.key.shape
((5,), (5,))

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split=(2, 5))
>>> rngs.params.key.shape, rngs.dropout.key.shape
((2, 5), (2, 5))

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split={'params': 5})
>>> rngs.params.key.shape, rngs.dropout.key.shape
((5,), ())
Once forked, random state can be used with transforms like nnx.vmap():

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split={'params': 5})
...
>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
...
>>> @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
... def create_model(rngs):
...   return Model(rngs)
...
>>> model = create_model(rngs)
>>> model.dropout.rngs.key.shape
()
fork_rngs returns a SplitBackups object that can be used to restore the original unforked rng states using nnx.restore_rngs(). This is useful when you only want to fork the rng states temporarily:

>>> rngs = nnx.Rngs(params=0, dropout=1)
...
>>> backups = nnx.fork_rngs(rngs, split={'params': 5})
>>> model = create_model(rngs)
>>> nnx.restore_rngs(backups)
...
>>> model.dropout.rngs.key.shape
()
SplitBackups can also be used as a context manager to automatically restore the rng states when exiting the context:
>>> rngs = nnx.Rngs(params=0, dropout=1)
...
>>> with nnx.fork_rngs(rngs, split={'params': 5}):
...   model = create_model(rngs)
...
>>> model.dropout.rngs.key.shape
()

When no node is passed, fork_rngs returns a decorator that forks the rng states of the inputs to the decorated function:

>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
...
>>> @nnx.fork_rngs(split={'params': 5})
... @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
... def create_model(rngs):
...   return Model(rngs)
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> model = create_model(rngs)
>>> model.dropout.rngs.key.shape
()
- flax.nnx.reseed(node, /, *, policy='scalars_only', **stream_keys)[source]#
Update the keys of the specified RNG streams with new keys.
- Parameters
node – the node to reseed the RNG streams in.
policy – defines how the new scalar key for each RngStream is used to reseed the stream. If 'scalars_only' is given (the default), an error is raised if the target stream key is not a scalar. If 'match_shape' is given, the new scalar key is split to match the shape of the target stream key. A callable of the form (path, scalar_key, target_shape) -> new_key can be passed to define a custom reseeding policy.
**stream_keys – a mapping of stream names to new keys. The keys can be either integers or jax.random.key.
Example:
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))
...
>>> model = Model(nnx.Rngs(params=0, dropout=42))
>>> x = jnp.ones((1, 2))
...
>>> y1 = model(x)
...
>>> # reset the ``dropout`` stream key to 42
>>> nnx.reseed(model, dropout=42)
>>> y2 = model(x)
...
>>> jnp.allclose(y1, y2)
Array(True, dtype=bool)