helpers

class flax.nnx.Sequential(self, *fns)

A Module that applies a sequence of callables.

This class provides a way to store and manipulate a sequence of callables (e.g. layers, activation functions) and apply them in order.

Example:

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> rngs = nnx.Rngs(0)
>>> model = nnx.Sequential(
...   nnx.Linear(1, 4, rngs=rngs),  # data
...   nnx.relu,                     # static
...   nnx.Linear(4, 2, rngs=rngs),  # data
... )
>>> x = jnp.ones((1, 1))
>>> y = model(x)
>>> y.shape
(1, 2)
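Conceptually, calling the model feeds the input through each callable in order, so `model(x)` above behaves like `layer2(relu(layer1(x)))`. The stdlib-only sketch below illustrates that composition for the single-argument case. `apply_in_order` and the toy "layers" are hypothetical. The sketch does not model how `nnx.Sequential` itself forwards extra arguments or handles multiple return values.

```python
from functools import reduce
from typing import Any, Callable, Sequence


def apply_in_order(fns: Sequence[Callable[[Any], Any]], x: Any) -> Any:
    """Hypothetical helper: returns fns[-1](...fns[1](fns[0](x))...)."""
    # reduce threads the running value through each callable, left to right.
    return reduce(lambda acc, fn: fn(acc), fns, x)


def relu(v: float) -> float:
    # Stateless activation, analogous to the "static" entry in the doctest.
    return max(v, 0.0)


# Two affine "layers" written as plain functions, with an activation between.
pipeline = [lambda v: 2.0 * v - 3.0, relu, lambda v: v + 1.0]

print(apply_in_order(pipeline, 1.0))  # 2*1-3 = -1 -> relu -> 0 -> +1 = 1.0
print(apply_in_order(pipeline, 4.0))  # 2*4-3 = 5 -> relu -> 5 -> +1 = 6.0
```

Using `reduce` keeps the order explicit: the first element of the sequence is applied first, matching the order in which the layers are listed.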

class flax.nnx.List(self, it=None, /)

A Module that implements a mutable sequence.

This class provides a way to store and manipulate a sequence of values containing a mixed set of data (e.g. Arrays, Variables, Modules) and static (e.g. functions, strings) types.

Example

>>> from flax import nnx
...
>>> rngs = nnx.Rngs(0)
>>> layers = nnx.List([
...     nnx.Linear(1, 32, rngs=rngs),  # data
...     nnx.relu,                      # static
...     nnx.Linear(32, 1, rngs=rngs),  # data
... ])
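The data/static distinction can be pictured as splitting the container's entries into two groups. A transformation then updates the data entries, while the static ones pass through unchanged. The stdlib sketch below illustrates only that idea. `is_data`, `split`, and `merge` are hypothetical, plain floats stand in for arrays, and NNX's actual rules for classifying values are not covered by this section.

```python
from typing import Any, Dict, List, Tuple


def is_data(value: Any) -> bool:
    # Hypothetical rule for this sketch only: numbers play the role of arrays;
    # everything else (functions, strings, ...) is treated as static.
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def split(items: List[Any]) -> Tuple[Dict[int, Any], Dict[int, Any]]:
    """Partition entries, keyed by position, into data and static groups."""
    data = {i: v for i, v in enumerate(items) if is_data(v)}
    static = {i: v for i, v in enumerate(items) if not is_data(v)}
    return data, static


def merge(data: Dict[int, Any], static: Dict[int, Any]) -> List[Any]:
    """Rebuild the original ordering from the two groups."""
    combined = {**data, **static}
    return [combined[i] for i in range(len(combined))]


items = [0.5, abs, 2.0, "relu"]
data, static = split(items)

# A "transformation" touches only the data entries; static ones pass through.
scaled = {i: 10 * v for i, v in data.items()}
print(merge(scaled, static))  # [5.0, <built-in function abs>, 20.0, 'relu']
```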
append(value)

S.append(value) – append value to the end of the sequence

insert(index, value)

S.insert(index, value) – insert value before index
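The `S.append` and `S.insert` descriptions above match the docstrings of Python's `collections.abc.MutableSequence`. In that mixin, `append(value)` is implemented as `insert(len(self), value)`. The stdlib sketch below shows those semantics. It is not `nnx.List` itself: the hypothetical `TinyList` defines only the abstract methods (including `insert`) and inherits `append` from the mixin.

```python
from collections.abc import MutableSequence
from typing import Any, Iterable, List, Optional


class TinyList(MutableSequence):
    """Hypothetical list-backed sequence used only to illustrate the API."""

    def __init__(self, it: Optional[Iterable[Any]] = None):
        self._items: List[Any] = list(it) if it is not None else []

    # The five abstract methods MutableSequence requires.
    def __getitem__(self, index):
        return self._items[index]

    def __setitem__(self, index, value):
        self._items[index] = value

    def __delitem__(self, index):
        del self._items[index]

    def __len__(self):
        return len(self._items)

    def insert(self, index, value):
        # Insert value *before* index, like list.insert.
        self._items.insert(index, value)


s = TinyList(["a", "c"])
s.insert(1, "b")  # insert before index 1
s.append("d")     # mixin method: same as s.insert(len(s), "d")
print(list(s))    # ['a', 'b', 'c', 'd']
```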

class flax.nnx.Dict(self, *args, **kwargs)

A Module that implements a mutable mapping.

This class provides a way to store and manipulate a mapping of keys to values, where the values may be a mixed set of data (e.g. Arrays, Variables, Modules) and static (e.g. functions, strings) types.

Example

>>> from flax import nnx
...
>>> rngs = nnx.Rngs(0)
>>> layers = nnx.Dict({
...     'kernel1': nnx.Linear(1, 32, rngs=rngs),   # data
...     'activation1': nnx.relu,                   # static
...     'kernel2': nnx.Linear(32, 1, rngs=rngs),   # data
... })

class flax.nnx.TrainState(graphdef: 'graph.GraphDef[M]', params: 'State', opt_state: 'optax.OptState', step: 'jax.Array', tx: 'optax.GradientTransformation')

replace(**updates)

Returns a new object replacing the specified fields with new values.
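Because `replace` returns a new object instead of mutating the existing one, an update step produces a fresh state and leaves the old one intact. As a stdlib analogue of that "copy with some fields changed" behaviour, the sketch below uses a frozen dataclass with `dataclasses.replace`. `ToyState` and `sgd_step` are hypothetical: a real `TrainState` holds `State` and optax objects rather than tuples of floats, and this sketch does not use optax.

```python
import dataclasses
from typing import Tuple


@dataclasses.dataclass(frozen=True)
class ToyState:
    # Hypothetical stand-in for a train state: plain floats instead of arrays.
    params: Tuple[float, ...]
    step: int


def sgd_step(state: ToyState, grads: Tuple[float, ...], lr: float = 0.1) -> ToyState:
    """Plain gradient descent; returns a new state via dataclasses.replace."""
    new_params = tuple(p - lr * g for p, g in zip(state.params, grads))
    return dataclasses.replace(state, params=new_params, step=state.step + 1)


s0 = ToyState(params=(1.0, 2.0), step=0)
s1 = sgd_step(s0, grads=(10.0, -10.0))
print(s0.step, s1.step)  # 0 1  (s0 is unchanged)
print(s1.params)         # (0.0, 3.0)
```

Freezing the dataclass makes in-place assignment raise an error, so every change has to go through `replace`.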