object

class flax.nnx.Pytree(*args, **kwargs)

Base class for all NNX objects.

class flax.nnx.Object(self, /, *args, **kwargs)

Base class for NNX objects that are not pytrees.

flax.nnx.data(value, /)

Annotates an attribute as pytree data.

The value returned by nnx.data must be assigned directly to an Object attribute, which is then registered as a pytree data attribute.

Example:

from flax import nnx
import jax

class Foo(nnx.Pytree):
  def __init__(self):
    self.data_attr = nnx.data(42)  # pytree data
    self.static_attr = "hello"     # static attribute

foo = Foo()

assert jax.tree.leaves(foo) == [42]

Parameters

value – The value to annotate as data.

Returns

A value which will register the attribute as data on assignment.

flax.nnx.Data

Data marks attributes of a class as pytree data using type annotations.

Data annotations must be used at the class level and apply to all instances. Using Data is recommended when type annotations are already present or required, e.g. for dataclasses.

Example:

from flax import nnx
import jax
import dataclasses

@dataclasses.dataclass
class Foo(nnx.Pytree):
  a: nnx.Data[int]  # Annotates `a` as pytree data
  b: str            # `b` is not pytree data

foo = Foo(a=42, b='hello')

assert jax.tree.leaves(foo) == [42]

flax.nnx.static(value, /)

Annotates an attribute as static.

The value returned by nnx.static must be assigned directly to an Object attribute, which is then registered as a static attribute.

Example:

from flax import nnx
import jax

class Foo(nnx.Pytree):
  def __init__(self, a, b):
    self.a = nnx.static(a)  # pytree metadata
    self.b = nnx.data(b)    # pytree data

foo = Foo("one", "two")

assert jax.tree.leaves(foo) == ["two"]

By default nnx.Pytree will …

flax.nnx.Static

Static marks attributes of a class as static using type annotations. Static annotations must be used at the class level and apply to all instances. Using Static is recommended when type annotations are already present or required, e.g. for dataclasses.


flax.nnx.is_data(value, /)

Checks if a value is a registered data type.

This function checks whether the value is of a registered data type, meaning it is automatically recognized as data when assigned to an nnx.Pytree attribute.

Data types are:

- jax.Arrays
- np.ndarrays
- ArrayRefs
- Variables (Param, BatchStat, RngState, etc.)
- All graph nodes (Object, Module, Rngs, etc.)
- Any type registered with nnx.register_data_type
- Any pytree that contains at least one node or leaf element of the above

Example:

>>> from flax import nnx
>>> import jax.numpy as jnp
... # ------ DATA ------------
>>> assert nnx.is_data( jnp.array(0) )                      # Arrays
>>> assert nnx.is_data( nnx.Param(1) )                      # Variables
>>> assert nnx.is_data( nnx.Rngs(2) )                       # nnx.Pytrees
>>> assert nnx.is_data( nnx.Linear(1, 1, rngs=nnx.Rngs(0)) ) # Modules
... # ------ STATIC ------------
>>> assert not nnx.is_data( 'hello' )                       # strings, arbitrary objects
>>> assert not nnx.is_data( 42 )                            # int, float, bool, complex, etc.
>>> assert not nnx.is_data( [1, 2.0, 3j, jnp.array(1)] )    # list, dict, tuple, pytrees

Parameters

value – The value to check.

Returns

True if the value is of a registered data type, False otherwise.

flax.nnx.register_data_type(type_, /)

Registers a type as a pytree data type recognized by Object.

Custom types registered as data are automatically recognized as data attributes when assigned to an Object attribute. This means values of this type do not need to be wrapped in nnx.data(…) for Object to mark the attribute they are assigned to as data.

Example:

from flax import nnx
from dataclasses import dataclass
import jax

@dataclass(frozen=True)
class MyType:
  value: int

nnx.register_data_type(MyType)

class Foo(nnx.Pytree):
  def __init__(self, a):
    self.a = MyType(a)  # Automatically registered as data
    self.b = "hello"     # str not registered as data

foo = Foo(42)

assert nnx.is_data(foo.a)  # True
assert jax.tree.leaves(foo) == [MyType(value=42)]

flax.nnx.check_pytree(pytree)

Checks if a pytree is valid.