Optimizer#

class flax.nnx.optimizer.Optimizer(model, tx, *, wrt)#

Simple train state for the common case with a single Optax optimizer.

Example usage:

>>> import jax, jax.numpy as jnp
>>> from flax import nnx
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     return self.linear2(self.linear1(x))
...
>>> x = jax.random.normal(jax.random.key(0), (1, 2))
>>> y = jnp.ones((1, 4))
...
>>> model = Model(nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
...
>>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
>>> loss_fn(model)
Array(2.3359997, dtype=float32)
>>> grads = nnx.grad(loss_fn)(model)
>>> optimizer.update(model, grads)
>>> loss_fn(model)
Array(2.310461, dtype=float32)
step#

An OptState Variable that tracks the step count.

tx#

An Optax gradient transformation.

opt_state#

The Optax optimizer state.
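
Continuing the class-level example above, the step count can be read from the step attribute; a small sketch, assuming standard .value access on the underlying Variable:

>>> step_count = int(optimizer.step.value)  # 1, since update() has been called once above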

__init__(model, tx, *, wrt)#

Instantiate the class, wrapping the Module and the Optax gradient transformation. The optimizer state is initialized to track the Variable types specified in wrt, and the step count is set to 0.

Parameters
  • model – An NNX Module.

  • tx – An Optax gradient transformation.

  • wrt – optional argument to filter which Variables to track in the optimizer state. These should be the Variables that you plan to update; i.e. this argument should match the wrt argument passed to the nnx.grad call that generates the gradients passed into the grads argument of the update() method (see the sketch below).
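
For example, here is a minimal sketch of optimizing only a custom subset of variables; AdapterParam and the Adapted module below are illustrative names, not part of the API:

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> import optax
...
>>> class AdapterParam(nnx.Param):  # custom Variable type marking the trainable weights
...   pass
...
>>> class Adapted(nnx.Module):
...   def __init__(self, rngs):
...     self.frozen = nnx.Linear(2, 3, rngs=rngs)  # regular nnx.Param, left untouched
...     self.scale = AdapterParam(jnp.ones((3,)))  # the only variable being optimized
...   def __call__(self, x):
...     return self.frozen(x) * self.scale[...]
...
>>> model = Adapted(nnx.Rngs(0))
>>> optimizer = nnx.Optimizer(model, optax.sgd(1e-2), wrt=AdapterParam)
>>> loss_fn = lambda model, x: (model(x) ** 2).mean()
>>> # The filter passed to nnx.grad matches the optimizer's wrt argument.
>>> grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, AdapterParam))(
...   model, jnp.ones((1, 2))
... )
>>> optimizer.update(model, grads)  # only AdapterParam variables change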

update(model, grads, /, **kwargs)#

Updates the optimizer state and model parameters given the gradients.

Example:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.count = nnx.Variable(jnp.array(0))
...
...   def __call__(self, x):
...     self.count[...] += 1
...     return self.linear(x)
...
>>> model = Model(rngs=nnx.Rngs(0))
...
>>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
>>> optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
>>> grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, nnx.Param))(
...   model, jnp.ones((1, 2)), jnp.ones((1, 3))
... )
>>> optimizer.update(model, grads)

Note that internally this method calls self.tx.update() followed by optax.apply_updates() to update the params and opt_state.
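
In a training loop, update() is typically called inside a jitted step function. A minimal sketch, reusing the model, optimizer and loss_fn defined in the example above:

>>> @nnx.jit  # model and optimizer are graph nodes; their state is updated in place
... def train_step(model, optimizer, x, y):
...   loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
...   optimizer.update(model, grads)
...   return loss
...
>>> loss = train_step(model, optimizer, jnp.ones((1, 2)), jnp.ones((1, 3)))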

Parameters
  • model – the NNX Module whose parameters will be updated.

  • grads – the gradients derived from nnx.grad.

  • **kwargs – additional keyword arguments passed to tx.update, to support GradientTransformationExtraArgs such as optax.scale_by_backtracking_linesearch (see the sketch below).
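
For illustration, here is a minimal sketch of forwarding such an extra argument. It swaps in optax.polyak_sgd, whose update rule consumes the current loss through a single value keyword (an assumption about its extra-args interface), and reuses the model and loss_fn from the example above:

>>> optimizer = nnx.Optimizer(model, optax.polyak_sgd(max_learning_rate=0.1), wrt=nnx.Param)
>>> value, grads = nnx.value_and_grad(loss_fn)(model, jnp.ones((1, 2)), jnp.ones((1, 3)))
>>> optimizer.update(model, grads, value=value)  # `value` is forwarded to tx.update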