Optimizer#
- class flax.nnx.optimizer.Optimizer(self, model, tx, *, wrt)#
Simple train state for the common case with a single Optax optimizer.
Example usage:
>>> import jax, jax.numpy as jnp
>>> from flax import nnx
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     return self.linear2(self.linear1(x))
...
>>> x = jax.random.normal(jax.random.key(0), (1, 2))
>>> y = jnp.ones((1, 4))
...
>>> model = Model(nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
...
>>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
>>> loss_fn(model)
Array(2.3359997, dtype=float32)
>>> grads = nnx.grad(loss_fn)(model)
>>> optimizer.update(model, grads)
>>> loss_fn(model)
Array(2.310461, dtype=float32)
- step#
An OptState Variable that tracks the step count.
- tx#
An Optax gradient transformation.
- opt_state#
The Optax optimizer state.
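For example, the step count can be inspected directly (a short sketch, not from the original docs; it assumes the step Variable exposes its array via .value, as other nnx Variables do, and that the counter increments by one per update() call):
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> import optax
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> optimizer = nnx.Optimizer(model, optax.sgd(1e-2), wrt=nnx.Param)
>>> int(optimizer.step.value)  # starts at 0
0
>>> grads = nnx.grad(lambda m: m(jnp.ones((1, 2))).sum())(model)
>>> optimizer.update(model, grads)
>>> int(optimizer.step.value)  # incremented by update()
1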
- __init__(model, tx, *, wrt)#
Instantiate the class and wrap the Module and Optax gradient transformation. Instantiate the optimizer state to keep track of the Variable types specified in wrt. Set the step count to 0.
- Parameters
model – An NNX Module.
tx – An Optax gradient transformation.
wrt – optional argument to filter which Variable types to keep track of in the optimizer state. These should be the Variable types that you plan on updating; i.e. this argument value should match the wrt argument passed to the nnx.grad call that will generate the gradients passed into the grads argument of the update() method (see the sketch after this parameter list).
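As a hedged sketch of matching wrt between the optimizer and nnx.grad, the example below tracks and updates only a custom HeadParam variable type; HeadParam and the model are invented here purely for illustration:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> import optax
...
>>> class HeadParam(nnx.Param):  # hypothetical Variable subtype, for illustration only
...   pass
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.backbone = nnx.Linear(2, 3, rngs=rngs)
...     # stored as HeadParam so it can be filtered separately from regular Params
...     self.head = HeadParam(jax.random.normal(rngs.params(), (3, 4)))
...   def __call__(self, x):
...     return self.backbone(x) @ self.head[...]
...
>>> model = Model(nnx.Rngs(0))
>>> # the optimizer state only tracks HeadParam ...
>>> optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=HeadParam)
>>> loss_fn = lambda model: (model(jnp.ones((1, 2))) ** 2).mean()
>>> # ... and the gradients are taken w.r.t. the same filter, so the two line up
>>> grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, HeadParam))(model)
>>> optimizer.update(model, grads)
The backbone's regular Param variables are excluded by the filter and therefore left untouched by update().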
- update(model, grads, /, **kwargs)#
Updates the optimizer state and model parameters given the gradients.
Example:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.count = nnx.Variable(jnp.array(0))
...
...   def __call__(self, x):
...     self.count[...] += 1
...     return self.linear(x)
...
>>> model = Model(rngs=nnx.Rngs(0))
...
>>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
>>> optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
>>> grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, nnx.Param))(
...   model, jnp.ones((1, 2)), jnp.ones((1, 3))
... )
>>> optimizer.update(model, grads)
Note that internally this function calls .tx.update() followed by a call to optax.apply_updates() to update params and opt_state.
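Conceptually, a single update is roughly equivalent to the following manual sequence over the filtered parameter state. This is a hedged sketch rather than the actual implementation, and it assumes the usual pattern of manipulating nnx.State objects as pytrees:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> import optax
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> grads = nnx.grad(lambda m: m(jnp.ones((1, 2))).sum())(model)
...
>>> # roughly what update() does for wrt=nnx.Param:
>>> params = nnx.state(model, nnx.Param)      # filtered parameter state
>>> opt_state = tx.init(params)               # normally created in __init__
>>> updates, opt_state = tx.update(grads, opt_state, params)
>>> new_params = optax.apply_updates(params, updates)
>>> nnx.update(model, new_params)             # write the new values back into the model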
- Parameters
grads – the gradients derived from nnx.grad.
**kwargs – additional keyword arguments passed to tx.update, to support GradientTransformationExtraArgs (such as optax.scale_by_backtracking_linesearch).
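For instance, a transformation that needs the current loss at update time can receive it through these keyword arguments. The snippet below is a hedged sketch using optax.contrib.reduce_on_plateau, which consumes the loss as a value= keyword; the choice of transformation and hyperparameters here is illustrative only:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> import optax
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> # chain adam with a loss-aware scaler that expects an extra `value` argument
>>> tx = optax.chain(optax.adam(1e-3), optax.contrib.reduce_on_plateau(patience=2, factor=0.5))
>>> optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
...
>>> loss_fn = lambda model: (model(jnp.ones((1, 2))) ** 2).mean()
>>> loss, grads = nnx.value_and_grad(loss_fn)(model)
>>> optimizer.update(model, grads, value=loss)  # `value` is forwarded to tx.update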