Flax basics#
Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first-class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. This API design should make PyTorch or Keras users feel at home.
To begin, install Flax with pip and import necessary dependencies:
# ! pip install -U flax
from flax import nnx
import jax
import jax.numpy as jnp
The Flax NNX Module system#
The main difference between the Flax NNX Module and other Module systems in Flax Linen or Haiku is that in NNX everything is explicit. This means, among other things, that the NNX Module itself holds the state (such as parameters) directly, the PRNG state is threaded by the user, and all shape information must be provided on initialization (no shape inference).
Let's begin by creating a Linear Module. As shown next, dynamic state is usually stored in Params, and static state (all types not handled by NNX), such as integers or strings, is stored directly. Attributes of type jax.Array and numpy.ndarray are also treated as dynamic state, although storing them inside Variables, such as Param, is preferred. The Rngs object can be used to get new unique keys based on a root PRNG key passed to the constructor.
class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = nnx.Param(rngs.params.uniform((din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return x @ self.w + self.b
Also note that the inner values of Variables can be accessed using the value property, but for convenience they implement all numeric operators and can be used directly in arithmetic expressions (as shown in the code above).
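For instance, here is a minimal, hedged sketch (instantiating the Linear defined above; the variable names are illustrative) of the two equivalent access patterns:

linear = Linear(2, 5, rngs=nnx.Rngs(params=0))
x = jnp.ones((1, 2))
# Access the raw jax.Array stored inside the Param via `.value` ...
y1 = x @ linear.w.value + linear.b.value
# ... or use the Variables directly; numeric operators are forwarded to `.value`.
y2 = x @ linear.w + linear.b
assert jnp.allclose(y1, y2)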
To initialize a Flax Module, you just call the constructor, and all the parameters of a Module are usually created eagerly. Since a Module holds its own state, you can call its methods directly without the need for a separate apply method. This can be very convenient for debugging, allowing you to directly inspect the entire structure of the model.
model = Linear(2, 5, rngs=nnx.Rngs(params=0))
y = model(x=jnp.ones((1, 2)))
print(y)
nnx.display(model)
[[1.5643291 0.94782424 0.37971854 1.0724319 0.22112393]]
The above visualization by nnx.display is generated using the awesome Treescope library.
Stateful computation#
Implementing layers, such as BatchNorm, requires performing state updates during a forward pass. In Flax NNX, you just need to create a Variable and update its .value during the forward pass.
class Count(nnx.Variable): pass

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))

  def __call__(self):
    self.count += 1
counter = Counter()
print(f'{counter.count.value = }')
counter()
print(f'{counter.count.value = }')
counter.count.value = Array(0, dtype=int32, weak_type=True)
counter.count.value = Array(1, dtype=int32, weak_type=True)
Mutable references are usually avoided in JAX. But Flax NNX provides sound mechanisms to handle them, as demonstrated in later sections of this guide.
Nested Modules#
Flax Modules can be used to compose other Modules in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a list, dict, tuple, and so on.
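As a hedged illustration of the pytree-container case (the Stack class and its sizes below are made up for this sketch), sub-Modules stored inside a plain Python list are part of the model just like direct attributes:

class Stack(nnx.Module):
  def __init__(self, sizes: list[int], *, rngs: nnx.Rngs):
    # Linear sub-Modules held inside a list attribute.
    self.layers = [
      Linear(din, dout, rngs=rngs)
      for din, dout in zip(sizes[:-1], sizes[1:])
    ]

  def __call__(self, x: jax.Array):
    for layer in self.layers:
      x = nnx.relu(layer(x))
    return x

stack = Stack([2, 8, 4], rngs=nnx.Rngs(params=0))
y = stack(jnp.ones((1, 2)))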
The example below shows how to define a simple MLP by subclassing Module. The model consists of two Linear layers, a Dropout layer, and a BatchNorm layer:
class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)
model = MLP(2, 16, 5, rngs=nnx.Rngs(0))
y = model(x=jnp.ones((3, 2)))
nnx.display(model)
In Flax, Dropout is a stateful module that stores an Rngs object, so that it can generate new masks during the forward pass without the need for the user to pass a new key each time.
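A hedged sketch of what this means in practice (assuming the train()/eval() mode switches on nnx.Module behave as in recent Flax releases): repeated calls draw fresh dropout masks with no explicit keys, while eval mode makes the forward pass deterministic:

mlp = MLP(2, 16, 5, rngs=nnx.Rngs(0))
x = jnp.ones((3, 2))
y1 = mlp(x)
y2 = mlp(x)   # a new dropout mask is drawn internally; no key is passed

mlp.eval()    # disable dropout and use BatchNorm running statistics
y3 = mlp(x)
y4 = mlp(x)
assert jnp.allclose(y3, y4)   # deterministic in eval mode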
Model surgery#
Flax Modules are mutable by default. This means that their structure can be changed at any time, which makes model surgery quite easy, as any sub-Module attribute can be replaced with anything else, such as new Modules, existing shared Modules, Modules of different types, and so on. Moreover, Variables can also be modified or replaced/shared.
The following example shows how to replace the Linear layers in the MLP model from the previous example with LoraLinear layers:
class LoraParam(nnx.Param): pass

class LoraLinear(nnx.Module):
  def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):
    self.linear = linear
    self.A = LoraParam(rngs.normal((linear.din, rank)))
    self.B = LoraParam(rngs.normal((rank, linear.dout)))

  def __call__(self, x: jax.Array):
    return self.linear(x) + x @ self.A @ self.B
rngs = nnx.Rngs(0)
model = MLP(2, 32, 5, rngs=rngs)
# Model surgery.
model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)
model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)
y = model(x=jnp.ones((3, 2)))
nnx.display(model)
Flax transformations#
Flax NNX transformations (transforms) extend JAX transforms to support Modules and other objects. They serve as supersets of their equivalent JAX counterparts, with the addition of being aware of the object's state and providing additional APIs to transform it.
One of the main features of Flax transforms is the preservation of reference semantics, meaning that any mutation of the object graph that occurs inside the transform is propagated outside, as long as it is legal within the transform rules. In practice this means that Flax programs can be expressed using imperative code, greatly simplifying the user experience.
In the following example, you define a train_step function that takes an MLP model, an Optimizer, and a batch of data, and returns the loss for that step. The loss and the gradients are computed using the nnx.value_and_grad transform over the loss_fn. The gradients are passed to the optimizer's update method to update the model's parameters.
import optax
# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.
model = MLP(2, 16, 10, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
@nnx.jit  # Automatic state management
def train_step(model, optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)  # In-place updates.

  return loss
x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
loss = train_step(model, optimizer, x, y)
print(f'{loss = }')
print(f'{optimizer.step.value = }')
loss = Array(1.0000602, dtype=float32)
optimizer.step.value = Array(1, dtype=uint32)
There are two things happening in this example that are worth mentioning:
- The updates to the BatchNorm and Dropout layers' state are automatically propagated from within loss_fn to train_step, all the way to the model reference outside (see the sketch after this list).
- The optimizer holds a mutable reference to the model. This relationship is preserved inside the train_step function, making it possible to update the model's parameters using the optimizer alone.
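A hedged sketch for checking the first point (nnx.BatchStat is the Variable type Flax uses for BatchNorm statistics): the running statistics held by the original model object were mutated inside the jitted train_step, even though they were never returned from it:

# Collect the BatchNorm running statistics stored on `model`; after one
# training step they should no longer be at their initial values.
batch_stats = nnx.state(model, nnx.BatchStat)
print(batch_stats)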
Note: nnx.jit has performance overhead for small models; check the Performance Considerations guide for more information.
Scan over layers#
The next example uses Flax nnx.vmap to create a stack of multiple MLP layers and nnx.scan to iteratively apply each layer of the stack to the input.
In the code below notice the following:
- The custom create_model function takes in a key and returns an MLP object. Since you create five keys and use nnx.vmap over create_model, a stack of 5 MLP objects is created.
- nnx.scan is used to iteratively apply each MLP in the stack to the input x.
- nnx.scan (consciously) deviates from jax.lax.scan and instead mimics nnx.vmap, which is more expressive. nnx.scan allows specifying multiple inputs, the scan axes of each input/output, and the position of the carry.
- State updates for the BatchNorm and Dropout layers are automatically propagated by nnx.scan.
@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
  return MLP(10, 32, 10, rngs=nnx.Rngs(key))

keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)

@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(model: MLP, x):
  x = model(x)
  return x
x = jnp.ones((3, 10))
y = forward(model, x)
print(f'{y.shape = }')
nnx.display(model)
y.shape = (3, 10)
How do Flax NNX transforms achieve this? To understand how Flax NNX objects interact with JAX transforms, the next section explains the Flax NNX Functional API.
The Flax Functional API#
The Flax NNX Functional API establishes a clear boundary between reference/object semantics and value/pytree semantics. It also allows the same amount of fine-grained control over the state that Flax Linen and Haiku users are used to. The Flax NNX Functional API consists of three basic methods: nnx.split, nnx.merge, and nnx.update.
Below is an example of a StatefulLinear Module that uses the Functional API. It contains:
- Some Param Variables; and
- A custom Count Variable type, which is used to track the integer scalar state that increases on every forward pass.
class Count(nnx.Variable): pass

class StatefulLinear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = nnx.Param(rngs.uniform((din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.count = Count(jnp.array(0, dtype=jnp.uint32))

  def __call__(self, x: jax.Array):
    self.count += 1
    return x @ self.w + self.b
model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 3)))
nnx.display(model)
State and GraphDef#
A Flax Module can be decomposed into State and GraphDef using the nnx.split function:
- State is a Mapping from strings to Variables or nested States.
- GraphDef contains all the static information needed to reconstruct a Module graph; it is analogous to JAX's PyTreeDef.
graphdef, state = nnx.split(model)
nnx.display(graphdef, state)
Split, merge, and update#
Flax's nnx.merge is the reverse of nnx.split. It takes the GraphDef + State and reconstructs the Module. The example below demonstrates this as follows:
- By using nnx.split and nnx.merge in sequence, any Module can be lifted to be used in any JAX transform.
- nnx.update can update an object in place with the content of a given State.
- This pattern is used to propagate the state from a transform back to the source object outside.
print(f'{model.count.value = }')
# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`.
graphdef, state = nnx.split(model)
@jax.jit
def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:
  # 2. Use `nnx.merge` to create a new model inside the JAX transformation.
  model = nnx.merge(graphdef, state)
  # 3. Call the `nnx.Module`.
  y = model(x)
  # 4. Use `nnx.split` to propagate `nnx.State` updates.
  _, state = nnx.split(model)
  return y, state
y, state = forward(graphdef, state, x=jnp.ones((1, 3)))
# 5. Update the state of the original `nnx.Module`.
nnx.update(model, state)
print(f'{model.count.value = }')
model.count.value = Array(1, dtype=uint32)
model.count.value = Array(2, dtype=uint32)
The key insight of this pattern is that using mutable references is fine within a transform context (including the base eager interpreter), but it is necessary to use the Functional API when crossing boundaries.
Why aren't Modules just pytrees? The main reason is that it is very easy to lose track of shared references by accident this way. For example, if you pass two Modules that have a shared Module through a JAX boundary, you will silently lose that sharing. Flax's Functional API makes this behavior explicit, and thus it is much easier to reason about.
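A hedged sketch of the sharing guarantee (the Shared class here is invented for illustration): nnx.split records the sharing in the GraphDef, so it survives a round trip through nnx.merge:

class Shared(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    shared = Linear(2, 2, rngs=rngs)
    self.a = shared   # both attributes reference the *same* Linear instance
    self.b = shared

m = Shared(rngs=nnx.Rngs(params=0))
assert m.a is m.b

graphdef, state = nnx.split(m)
m2 = nnx.merge(graphdef, state)
assert m2.a is m2.b   # the sharing is preserved after merge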
Fine-grained State control#
Experienced Flax Linen or Haiku API users may recognize that having all the state in a single structure is not always the best choice, as there are cases in which you may want to handle different subsets of the state differently. This is a common occurrence when interacting with JAX transforms.
For example:
- Not every model state can or should be differentiated when interacting with jax.grad.
- Or, sometimes, there is a need to specify what part of the model's state is a carry and what part is not when using jax.lax.scan.
To address this, the Flax NNX API has nnx.split, which allows you to pass one or more Filters to partition the Variables into mutually exclusive States. Flax NNX uses Filters to create State groups in APIs (such as nnx.split, nnx.state, and many of the NNX transforms).
The example below shows the most common Filters:
# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s.
graphdef, params, counts = nnx.split(model, nnx.Param, Count)
nnx.display(params, counts)
Note: Filters must be exhaustive: if a value is not matched, an error will be raised.
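Relatedly, here is a hedged sketch using nnx.state, which returns just the requested State group(s) without producing a GraphDef (reusing the same model and Count type as above):

# Retrieve only the Param group.
params = nnx.state(model, nnx.Param)
# Or several groups at once, returned in the same order as the Filters.
params, counts = nnx.state(model, nnx.Param, Count)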
As expected, the nnx.merge and nnx.update methods naturally consume multiple States:
# Merge multiple `State`s
model = nnx.merge(graphdef, params, counts)
# Update with multiple `State`s
nnx.update(model, params, counts)