helpers
- class flax.nnx.Sequential(self, *fns)
A Module that applies a sequence of callables.
This class provides a way to store and manipulate a sequence of callables (e.g. layers, activation functions) and apply them in order, passing the output of each callable as the input to the next.
Example:
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> rngs = nnx.Rngs(0)
>>> model = nnx.Sequential(
...   nnx.Linear(1, 4, rngs=rngs),  # data
...   nnx.relu,                     # static
...   nnx.Linear(4, 2, rngs=rngs),  # data
... )
>>> x = jnp.ones((1, 1))
>>> y = model(x)
>>> y.shape
(1, 2)
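As a rough, framework-free illustration of "apply them in order", the sketch below chains plain Python callables so that each output feeds the next call. It is a simplified model under stated assumptions, not Flax's implementation: `apply_in_order`, `relu`, and `make_affine` are hypothetical scalar stand-ins for `nnx.Sequential`, `nnx.relu`, and `nnx.Linear`, and the real class may handle details such as tuple outputs or keyword arguments differently.

```python
from typing import Any, Callable, Sequence


def apply_in_order(fns: Sequence[Callable[[Any], Any]], x: Any) -> Any:
    """Hypothetical toy model of sequential application (not Flax's code)."""
    out = x
    for i, fn in enumerate(fns):
        if not callable(fn):
            raise TypeError(f"element {i} is not callable: {fn!r}")
        out = fn(out)  # the previous output becomes the next input
    return out


def relu(v: float) -> float:
    # Hypothetical scalar stand-in for nnx.relu
    return max(v, 0.0)


def make_affine(w: float, b: float) -> Callable[[float], float]:
    # Hypothetical scalar stand-in for an nnx.Linear layer: v -> w * v + b
    return lambda v: w * v + b


fns = [make_affine(2.0, -3.0), relu, make_affine(0.5, 1.0)]
print(apply_in_order(fns, 1.0))  # prints 1.0
```

With an input of `1.0`, the first affine map gives `-1.0`, `relu` clamps it to `0.0`, and the second map returns `1.0`.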
- class flax.nnx.List(self, it=None, /)
A Module that implements a mutable sequence.
This class provides a way to store and manipulate a sequence of values containing a mix of data (e.g. Arrays, Variables, Modules) and static (e.g. functions, strings) types.
Example:
>>> from flax import nnx
...
>>> rngs = nnx.Rngs(0)
>>> layers = nnx.List([
...   nnx.Linear(1, 32, rngs=rngs),  # data
...   nnx.relu,                      # static
...   nnx.Linear(32, 1, rngs=rngs),  # data
... ])
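The `# data` / `# static` comments refer to how NNX treats each entry: data entries belong to the module's state (the part that functions such as `nnx.split` and `nnx.state` extract), while static entries stay with the module's structure. The pure-Python sketch below is only a toy model of that distinction, not Flax's implementation; `Param`, `split_data_static`, and `merge_data_static` are hypothetical names, with `Param` standing in for array-like data.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass
class Param:
    """Hypothetical stand-in for a data entry (Array, Variable, or Module state)."""
    value: Any


def split_data_static(items: List[Any]) -> Tuple[Dict[int, Param], Dict[int, Any]]:
    """Toy partition of a mixed list into data and static entries, keyed by index."""
    data: Dict[int, Param] = {}
    static: Dict[int, Any] = {}
    for i, item in enumerate(items):
        if isinstance(item, Param):
            data[i] = item    # treated as state in this toy model
        else:
            static[i] = item  # functions, strings, etc. kept as structure
    return data, static


def merge_data_static(data: Dict[int, Param], static: Dict[int, Any]) -> List[Any]:
    """Toy inverse of split_data_static: rebuild the list in index order."""
    n = len(data) + len(static)
    return [data[i] if i in data else static[i] for i in range(n)]


def relu(v: float) -> float:
    # Hypothetical scalar stand-in for nnx.relu
    return max(v, 0.0)


layers = [Param([0.0] * 32), relu, Param([0.0] * 32)]
data, static = split_data_static(layers)
print(sorted(data), sorted(static))  # prints [0, 2] [1]
```

Keying both parts by position is what lets the toy `merge_data_static` rebuild the original list exactly.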
- class flax.nnx.Dict(self, *args, **kwargs)
A Module that implements a mutable mapping.
This class provides a way to store and manipulate a mapping of keys to values containing a mix of data (e.g. Arrays, Variables, Modules) and static (e.g. functions, strings) types.
Example:
>>> from flax import nnx
...
>>> rngs = nnx.Rngs(0)
>>> layers = nnx.Dict({
...   'kernel1': nnx.Linear(1, 32, rngs=rngs),  # data
...   'activation1': nnx.relu,                  # static
...   'kernel2': nnx.Linear(32, 1, rngs=rngs),  # data
... })
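Because `nnx.Dict` is described as a mutable mapping, its entries are looked up, and can be replaced, by key. The sketch below uses a plain Python `dict` with hypothetical scalar stand-ins for the layers in the example, so it runs without Flax; it only illustrates name-based access and is not Flax's API.

```python
from typing import Callable, Dict


def relu(v: float) -> float:
    # Hypothetical scalar stand-in for nnx.relu
    return max(v, 0.0)


# Plain dict standing in for nnx.Dict; keys mirror the example above.
# The lambdas are hypothetical scalar stand-ins for nnx.Linear layers.
layers: Dict[str, Callable[[float], float]] = {
    "kernel1": lambda v: 2.0 * v - 3.0,
    "activation1": relu,
    "kernel2": lambda v: 0.5 * v + 1.0,
}


def forward(layers: Dict[str, Callable[[float], float]], x: float) -> float:
    # Entries are looked up by name, so the call order is explicit here
    h = layers["kernel1"](x)
    h = layers["activation1"](h)
    return layers["kernel2"](h)


print(forward(layers, 4.0))  # prints 3.5
```

Assigning a new value to an existing key, for example replacing `"activation1"` with an identity function, changes the forward pass without touching the other entries.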