# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
from __future__ import annotations
import contextlib
import dataclasses
import functools
from functools import partial
import threading
import typing as tp
from typing import Any
from flax import config
import jax
import treescope # type: ignore[import-untyped]
from flax import errors
from flax.nnx import filterlib, reprlib, tracers, visualization
from flax.typing import Missing, PathParts, SizeBytes
import jax.tree_util as jtu
import jax.numpy as jnp
from jax._src.state.types import AbstractRef
A = tp.TypeVar('A')
B = tp.TypeVar('B')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
V = tp.TypeVar('V', bound='Variable[Any]')
GetValueHook = tp.Callable[['Variable[A]', A], A]
SetValueHook = tp.Callable[['Variable[A]', A], A]
CreateValueHook = tp.Callable[['Variable[A]', A], A]
AxisName = str
AxisIndex = int
AddAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
if hasattr(jax, 'array_ref') and hasattr(jax, 'ArrayRef'):
from jax import array_ref # type: ignore[import-untyped]
from jax import ArrayRef # type: ignore[import-untyped]
elif not tp.TYPE_CHECKING:
from jax._src.core import mutable_array, MutableArray
array_ref = mutable_array
ArrayRef = MutableArray
# temp workaround future proof correct repr
MutableArray.__repr__ = lambda self: 'ArrayRef' + repr(self._buf)[5:] # type: ignore[method-assign]
@dataclasses.dataclass
class VariableContext(threading.local):
mutable_variable_stack: list[bool] = dataclasses.field(default_factory=list)
VARIABLE_CONTEXT = VariableContext()
def using_refs() -> bool:
"""Returns whether Variables are using ArrayRefs by default.
Example::
>>> from flax import nnx
...
>>> nnx.using_refs()
False
>>> nnx.use_refs(True)
<...>
>>> nnx.using_refs()
True
Returns:
A boolean indicating if Variables are using ArrayRefs by default.
"""
if VARIABLE_CONTEXT.mutable_variable_stack:
return VARIABLE_CONTEXT.mutable_variable_stack[-1]
else:
return config.flax_array_ref
def use_refs(value: bool, /):
"""Sets whether Variables should use ArrayRefs by default or not.
Example usage::
>>> from flax import nnx
>>> # Use ArrayRefs by default
>>> nnx.use_refs(True)
<...>
>>> # Variable will now use ArrayRefs
>>> v = nnx.Variable(jax.numpy.ones((2, 3)))
>>> v.has_ref
True
>>> v.raw_value
ArrayRef(...)
It can also be used as a context manager to temporarily
change the default behavior for a block of code::
>>> nnx.use_refs(False)
<...>
>>> with nnx.use_refs(True):
... v = nnx.Variable(jax.numpy.ones((2, 3)))
... v.has_ref
True
>>> # it will reset outside
>>> v = nnx.Variable(jax.numpy.ones((2, 3)))
>>> v.has_ref
False
Args:
value: A boolean indicating if Variables should use ArrayRefs by default.
Returns:
A context manager that resets the context to the previous value.
"""
# prev_value = VARIABLE_CONTEXT.mutable_variable_stack[-1] if VARIABLE_CONTEXT.mutable_variable_stack else None
# VARIABLE_CONTEXT.mutable_variable_stack.append(value)
if VARIABLE_CONTEXT.mutable_variable_stack:
prev_value = VARIABLE_CONTEXT.mutable_variable_stack[-1]
VARIABLE_CONTEXT.mutable_variable_stack[-1] = value
else:
prev_value = None
VARIABLE_CONTEXT.mutable_variable_stack.append(value)
return _clean_mutable_arrays_context(prev_value)
@contextlib.contextmanager
def _clean_mutable_arrays_context(prev_value: bool | None):
if prev_value is not None:
VARIABLE_CONTEXT.mutable_variable_stack.insert(-1, prev_value)
try:
yield
finally:
VARIABLE_CONTEXT.mutable_variable_stack.pop()
def is_array_ref(x) -> tp.TypeGuard[ArrayRef]:
return isinstance(x, jax.Array | AbstractRef | ArrayRef) and isinstance(
jax.typeof(x), AbstractRef | ArrayRef
)
[docs]class Variable(tp.Generic[A], reprlib.Representable):
"""The base class for all ``Variable`` types. Create custom ``Variable``
types by subclassing this class. Numerous NNX graph functions can filter
for specific ``Variable`` types, for example, :func:`split`, :func:`state`,
:func:`pop`, and :func:`State.filter`.
Example usage::
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> class CustomVariable(nnx.Variable):
... pass
>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.linear = nnx.Linear(2, 3, rngs=rngs)
... self.custom_variable = CustomVariable(jnp.ones((1, 3)))
... def __call__(self, x):
... return self.linear(x) + self.custom_variable
>>> model = Model(rngs=nnx.Rngs(0))
>>> linear_variables = nnx.state(model, nnx.Param)
>>> jax.tree.map(jnp.shape, linear_variables)
State({
'linear': {
'bias': Param(
value=(3,)
),
'kernel': Param(
value=(2, 3)
)
}
})
>>> custom_variable = nnx.state(model, CustomVariable)
>>> jax.tree.map(jnp.shape, custom_variable)
State({
'custom_variable': CustomVariable(
value=(1, 3)
)
})
>>> variables = nnx.state(model)
>>> jax.tree.map(jnp.shape, variables)
State({
'custom_variable': CustomVariable(
value=(1, 3)
),
'linear': {
'bias': Param(
value=(3,)
),
'kernel': Param(
value=(2, 3)
)
}
})
"""
__slots__ = ('raw_value', '_trace_state', '_var_metadata')
raw_value: A
_trace_state: tracers.TraceState
_var_metadata: dict[str, tp.Any]
def __init__(
self,
value: tp.Union[A, VariableMetadata[A]],
*,
use_ref: bool | None = None,
**metadata: tp.Any,
):
if use_ref is None:
use_ref = using_refs()
var_t = type(self)
object.__setattr__(self, '_trace_state', tracers.TraceState())
if isinstance(value, VariableMetadata):
metadata.update(value.metadata)
value = tp.cast(A, value.raw_value)
if use_ref:
if is_array_ref(value):
_value = tp.cast(A, value)
else:
_value = array_ref(jnp.asarray(value)) # type: ignore[assignment] # type: ignore[assignment]
else:
_value = value
object.__setattr__(self, 'raw_value', _value)
if hasattr(var_t, 'on_get_value') and 'on_get_value' not in metadata:
metadata['on_get_value'] = var_t.on_get_value
if hasattr(var_t, 'on_set_value') and 'on_set_value' not in metadata:
metadata['on_set_value'] = var_t.on_set_value
if hasattr(var_t, 'on_create_value') and 'on_create_value' not in metadata:
metadata['on_create_value'] = var_t.on_create_value
if hasattr(var_t, 'on_add_axis') and 'on_add_axis' not in metadata:
metadata['on_add_axis'] = var_t.on_add_axis
if hasattr(var_t, 'on_remove_axis') and 'on_remove_axis' not in metadata:
metadata['on_remove_axis'] = var_t.on_remove_axis
if 'sharding' in metadata:
metadata['sharding_names'] = metadata.pop('sharding')
object.__setattr__(self, '_var_metadata', metadata)
# run create_value hooks
object.__setattr__(self, 'raw_value', self.create_value(self.raw_value))
def __getattr__(self, name: str) -> tp.Any:
if name in object.__getattribute__(self, '_var_metadata'):
return self._var_metadata[name]
return getattr(self.raw_value, name)
def __setattr__(self, name: str, value: tp.Any):
if not self._trace_state.is_valid() and (
name != 'value' or not self.has_ref
):
raise errors.TraceContextError(
f'Cannot mutate {type(self).__name__} from a different trace level'
)
if (
name == 'value'
or name == 'raw_value'
or name == '_var_metadata'
or name == '_trace_state'
):
object.__setattr__(self, name, value)
else:
self._var_metadata[name] = value
def __delattr__(self, name: str):
if not self._trace_state.is_valid():
raise errors.TraceContextError(
f'Cannot mutate {type(self).__name__} from a different trace level'
)
if (
name == 'value'
or name == 'raw_value'
or name == '_var_metadata'
or name == '_trace_state'
):
object.__delattr__(self, name)
else:
del self._var_metadata[name]
# NOTE(cgarciae): adding this for backward compatibility with VariableState
@property
def type(self):
"""The type of the variable."""
import warnings
warnings.warn(
"'.type' is deprecated, use 'type(variable)' instead.",
DeprecationWarning,
stacklevel=2,
)
return type(self)
@property
def has_ref(self) -> bool:
return is_array_ref(self.raw_value)
def get_metadata(self):
return self._var_metadata
def copy_from(self, other: Variable[A]) -> None:
if type(self) is not type(other):
raise ValueError(
f'Cannot copy from incompatible container, '
f'expected {type(self).__name__}, got {type(other).__name__}'
)
if self is other:
return
self.raw_value = other.raw_value
self._var_metadata.clear()
self._var_metadata.update(other.get_metadata())
def update_from_state(self, variable_state: Variable[A]):
if self.has_ref and (
variable_state.has_ref or isinstance(variable_state.raw_value, jax.Array)
):
self.raw_value[...] = variable_state.raw_value[...] # type: ignore
else:
object.__setattr__(self, 'raw_value', variable_state.raw_value)
if self._var_metadata != variable_state._var_metadata:
object.__setattr__(
self, '_var_metadata', variable_state._var_metadata.copy()
)
@property
def value(self) -> A:
value = self.raw_value
if is_array_ref(value):
value = value[...]
if 'on_get_value' in self._var_metadata:
value = self._var_metadata['on_get_value'](self, value)
return value
@value.setter
def value(self, value: A):
if isinstance(value, Variable):
raise ValueError(
'Cannot set value to a Variable, use `copy_from` method instead'
)
if 'on_set_value' in self._var_metadata:
value = self._var_metadata['on_set_value'](self, value)
if self.has_ref:
self.raw_value[...] = value # type: ignore
else:
object.__setattr__(self, 'raw_value', value)
def create_value(self, value: A):
if 'on_create_value' in self._var_metadata:
value = self._var_metadata['on_create_value'](self, value)
return value
def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
if 'on_add_axis' in self._var_metadata:
self._var_metadata['on_add_axis'](self, axis_index, axis_name)
def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
if 'on_remove_axis' in self._var_metadata:
self._var_metadata['on_remove_axis'](self, axis_index, axis_name)
@tp.overload
def replace(self, value: B, **kwargs) -> Variable[B]:
...
@tp.overload
def replace(self, **kwargs) -> Variable[A]:
...
def replace(self, value: tp.Any = Missing, **kwargs) -> Variable[tp.Any]:
if value is not Missing:
kwargs['raw_value'] = value
# rename `value` to `raw_value`
if 'value' in kwargs:
kwargs['raw_value'] = kwargs.pop('value')
# return `value` if it is a Variable
if 'raw_value' in kwargs and isinstance(
value := kwargs['raw_value'], Variable
):
# remove value from kwargs
kwargs.pop('raw_value')
if type(self) is not type(value):
raise ValueError(
'Cannot replace value from incompatible container, '
f'expected {type(self).__name__}, got {type(value).__name__}'
)
# if kwargs aren't empty, recursively call replace
# else return variable value
if kwargs:
return value.replace(**kwargs)
else:
return value
# get and update attributes
# return new instance with updated attributes
obj = object.__new__(type(self))
object.__setattr__(obj, '_trace_state', self._trace_state)
object.__setattr__(obj, 'raw_value', kwargs.pop('raw_value'))
object.__setattr__(obj, '_var_metadata', self.get_metadata() | kwargs)
return obj
@classmethod
def from_metadata(cls, value: A, attributes: dict[str, tp.Any]):
obj = object.__new__(cls)
object.__setattr__(obj, '_trace_state', tracers.TraceState())
object.__setattr__(obj, 'raw_value', value)
object.__setattr__(obj, '_var_metadata', attributes)
return obj
def copy(self: Variable[A]) -> Variable[A]:
obj = object.__new__(type(self))
object.__setattr__(obj, '_trace_state', tracers.TraceState())
object.__setattr__(obj, 'raw_value', self.raw_value)
object.__setattr__(obj, '_var_metadata', self.get_metadata().copy())
return obj
to_state = copy
def __nnx_repr__(self):
stats = SizeBytes.from_any(self.raw_value)
if stats:
comment = f' # {stats}'
else:
comment = ''
yield reprlib.Object(type=type(self).__name__, comment=comment)
yield reprlib.Attr('value', self.raw_value)
for name, value in self._var_metadata.items():
yield reprlib.Attr(name, value)
def __treescope_repr__(self, path, subtree_renderer):
size_bytes = SizeBytes.from_any(self.value)
if size_bytes:
stats_repr = f' # {size_bytes}'
first_line_annotation = treescope.rendering_parts.comment_color(
treescope.rendering_parts.text(f'{stats_repr}')
)
else:
first_line_annotation = None
children = {'value': self.raw_value, **self._var_metadata}
return visualization.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
first_line_annotation=first_line_annotation,
)
# hooks API
if tp.TYPE_CHECKING:
def on_get_value(self, value: A) -> A: ...
def on_set_value(self, value: A) -> A: ...
def on_create_value(self, value: A) -> A: ...
def on_add_axis(
self: V, axis_index: AxisIndex, axis_name: AxisName | None
) -> V: ...
def on_remove_axis(
self: V, axis_index: AxisIndex, axis_name: AxisName | None
) -> V: ...
def __jax_array__(self):
return self.value
# pickle support
def __getstate__(self):
return {
'raw_value': self.raw_value,
'_trace_state': self._trace_state,
'_var_metadata': self._var_metadata,
}
def __setstate__(self, state):
object.__setattr__(self, 'raw_value', state['raw_value'])
object.__setattr__(self, '_trace_state', state['_trace_state'])
object.__setattr__(self, '_var_metadata', state['_var_metadata'])
# --------------------------------------------
# proxy methods
# --------------------------------------------
def __getitem__(self, key) -> jax.Array:
return self.value[key] # type: ignore
def __setitem__(self, key, value) -> None:
if not self.has_ref and not self._trace_state.is_valid():
raise errors.TraceContextError(
f'Cannot mutate {type(self).__name__} from a different trace level'
)
if self.has_ref:
self.raw_value[key] = value # type: ignore
elif key == ...:
self.value = value
elif isinstance(self.raw_value, jax.Array):
self.raw_value = self.raw_value.at[key].set(value) # type: ignore
else:
self.raw_value[key] = value # type: ignore
def __call__(self, *args, **kwargs) -> tp.Any:
return self.value(*args, **kwargs) # type: ignore
def __len__(self) -> int:
return len(self.value) # type: ignore
def __iter__(self) -> tp.Iterator:
return iter(self.value) # type: ignore
def __contains__(self, item) -> bool:
return item in self.value # type: ignore
def __add__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__add__(other) # type: ignore
def __sub__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__sub__(other) # type: ignore
def __mul__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__mul__(other) # type: ignore
def __matmul__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__matmul__(other) # type: ignore
def __truediv__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__truediv__(other) # type: ignore
def __floordiv__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__floordiv__(other) # type: ignore
def __mod__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__mod__(other) # type: ignore
def __divmod__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__divmod__(other) # type: ignore
def __pow__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__pow__(other) # type: ignore
def __lshift__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__lshift__(other) # type: ignore
def __rshift__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rshift__(other) # type: ignore
def __and__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__and__(other) # type: ignore
def __xor__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__xor__(other) # type: ignore
def __or__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__or__(other) # type: ignore
def __radd__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__radd__(other) # type: ignore
def __rsub__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rsub__(other) # type: ignore
def __rmul__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rmul__(other) # type: ignore
def __rmatmul__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rmatmul__(other) # type: ignore
def __rtruediv__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rtruediv__(other) # type: ignore
def __rfloordiv__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rfloordiv__(other) # type: ignore
def __rmod__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rmod__(other) # type: ignore
def __rdivmod__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rdivmod__(other) # type: ignore
def __rpow__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rpow__(other) # type: ignore
def __rlshift__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rlshift__(other) # type: ignore
def __rrshift__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rrshift__(other) # type: ignore
def __rand__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rand__(other) # type: ignore
def __rxor__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__rxor__(other) # type: ignore
def __ror__(self, other) -> A:
if isinstance(other, Variable):
other = other.value
return self.value.__ror__(other) # type: ignore
def __iadd__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__iadd__'):
value.__iadd__(other)
else:
self.value = value.__add__(other)
return self
def __isub__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__isub__'):
value.__isub__(other)
else:
self.value = value.__sub__(other)
return self
def __imul__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__imul__'):
value.__imul__(other)
else:
self.value = value.__mul__(other)
return self
def __imatmul__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__imatmul__'):
value.__imatmul__(other)
else:
self.value = value.__matmul__(other)
return self
def __itruediv__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__itruediv__'):
value.__itruediv__(other)
else:
self.value = value.__truediv__(other)
return self
def __ifloordiv__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__ifloordiv__'):
value.__ifloordiv__(other)
else:
self.value = value.__floordiv__(other)
return self
def __imod__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__imod__'):
value.__imod__(other)
else:
self.value = value.__mod__(other)
return self
def __ipow__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__ipow__'):
value.__ipow__(other)
else:
self.value = value.__pow__(other)
return self
def __ilshift__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__ilshift__'):
value.__ilshift__(other)
else:
self.value = value.__lshift__(other)
return self
def __irshift__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__irshift__'):
value.__irshift__(other)
else:
self.value = value.__rshift__(other)
return self
def __iand__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__iand__'):
value.__iand__(other)
else:
self.value = value.__and__(other)
return self
def __ixor__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__ixor__'):
value.__ixor__(other)
else:
self.value = value.__xor__(other)
return self
def __ior__(self: V, other) -> V:
if isinstance(other, Variable):
other = other.value
value = self.value
if hasattr(value, '__ior__'):
value.__ior__(other)
else:
self.value = value.__or__(other)
return self
def __neg__(self) -> A:
return self.value.__neg__() # type: ignore
def __pos__(self) -> A:
return self.value.__pos__() # type: ignore
def __abs__(self) -> A:
return self.value.__abs__() # type: ignore
def __invert__(self) -> A:
return self.value.__invert__() # type: ignore
def __complex__(self) -> A:
return self.value.__complex__() # type: ignore
def __int__(self) -> A:
return self.value.__int__() # type: ignore
def __float__(self) -> A:
return self.value.__float__() # type: ignore
def __index__(self) -> A:
return self.value.__index__() # type: ignore
def __round__(self, ndigits: int) -> A:
return self.value.__round__(ndigits) # type: ignore
def __trunc__(self) -> A:
return self.value.__trunc__() # type: ignore
def __floor__(self) -> A:
return self.value.__floor__() # type: ignore
def __ceil__(self) -> A:
return self.value.__ceil__() # type: ignore
# --------------------------------------------
def __init_subclass__(cls) -> None:
super().__init_subclass__()
jax.tree_util.register_pytree_with_keys(
cls,
flatten_with_keys=_variable_flatten_with_keys,
unflatten_func=partial(_variable_unflatten, cls), # type: ignore
flatten_func=_variable_flatten,
)
def _variable_flatten_with_keys(x: Variable[tp.Any]):
metadata = tuple(x.get_metadata().items())
node = (jtu.GetAttrKey('value'), x.raw_value)
return (node,), metadata
def _variable_flatten(x: Variable[tp.Any]):
metadata = tuple(x.get_metadata().items())
return (x.raw_value,), metadata
def _variable_unflatten(
cls: type[Variable[tp.Any]],
static: tuple[tuple[str, tp.Any], ...],
children: tuple[tp.Any],
):
return cls.from_metadata(value=children[0], attributes=dict(static))
jax.tree_util.register_pytree_with_keys(
Variable,
flatten_with_keys=_variable_flatten_with_keys,
unflatten_func=partial(_variable_unflatten, Variable), # type: ignore
flatten_func=_variable_flatten,
)
VariableState = Variable
[docs]class Param(Variable[A]):
"""The canonical learnable parameter. All learnable parameters
in NNX layer modules will have the ``Param`` :class:`Variable`
type::
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
'bias': Param(
value=(3,)
),
'kernel': Param(
value=(2, 3)
)
})
"""
pass
[docs]class BatchStat(Variable[A]):
"""The mean and variance batch statistics stored in
the :class:`BatchNorm` layer. Note, these are not the
learnable scale and bias parameters, but rather the
running average statistics that are typically used
during post-training inference::
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
'bias': Param(
value=(3,)
),
'mean': BatchStat(
value=(3,)
),
'scale': Param(
value=(3,)
),
'var': BatchStat(
value=(3,)
)
})
"""
pass
[docs]class Cache(Variable[A]):
"""Autoregressive cache in :class:`MultiHeadAttention`::
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.MultiHeadAttention(
... num_heads=2,
... in_features=3,
... qkv_features=6,
... out_features=6,
... decode=True,
... rngs=nnx.Rngs(0),
... )
>>> layer.init_cache((1, 3))
>>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
State({
'cache_index': Cache(
value=()
),
'cached_key': Cache(
value=(1, 2, 3)
),
'cached_value': Cache(
value=(1, 2, 3)
)
})
"""
pass
class Perturbation(Intermediate[A]):
""":class:`Variable` type that is typically used for
:func:`Module.perturb`::
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.linear1 = nnx.Linear(2, 3, rngs=rngs)
... self.linear2 = nnx.Linear(3, 4, rngs=rngs)
... def __call__(self, x):
... x = self.linear1(x)
... x = self.perturb('i', x)
... x = self.linear2(x)
... return x
>>> model = Model(rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = model(x)
>>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Perturbation))
State({
'i': Perturbation(
value=(1, 3)
)
})
"""
pass
def split_flat_state(
flat_state: tp.Iterable[tuple[PathParts, Variable]],
filters: tuple[filterlib.Filter, ...],
) -> tuple[list[tuple[PathParts, Variable]], ...]:
predicates = filterlib.filters_to_predicates(filters)
# we have n + 1 states, where n is the number of predicates
# the last state is for values that don't match any predicate
flat_states: tuple[list[tuple[PathParts, Variable]], ...] = (
tuple([] for _ in predicates)
)
for path, value in flat_state:
for i, predicate in enumerate(predicates):
if predicate(path, value):
flat_states[i].append((path, value))
break
else:
raise ValueError(
'Non-exhaustive filters, got a non-empty remainder: '
f'{path} -> {value}.'
'\nUse `...` to match all remaining elements.'
)
return flat_states
###################################################
### Variable type/class <-> string name mapping ###
###################################################
# Assumption: the mapping is 1-1 and unique.
VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {}
[docs]def variable_type_from_name(
name: str,
/,
*,
base: type[Variable[tp.Any]] = Variable,
allow_register: bool = False,
) -> tp.Type[Variable[tp.Any]]:
"""Given a Linen-style collection name, get or create its NNX Variable class."""
if name not in VariableTypeCache:
if not allow_register:
raise ValueError(
f'Name {name} is not registered in the registry. '
'To register a new name, use register_variable_name() '
'or set allow_register=True.'
)
VariableTypeCache[name] = type(name, (base,), {})
return VariableTypeCache[name]
[docs]def variable_name_from_type(
typ: tp.Type[Variable[tp.Any]], /, *, allow_register: bool = False
) -> str:
"""Given an NNX Variable type, get its Linen-style collection name.
Should output the exact inversed result of `variable_type_from_name()`."""
for name, t in VariableTypeCache.items():
if typ == t:
return name
if not allow_register:
raise ValueError(
f'Type {typ} is not registered in the registry. '
'To register a new type, use register_variable_name() '
'or set allow_register=True.'
)
name = typ.__name__
if name in VariableTypeCache:
raise ValueError(
'Name {name} is already registered in the registry as {VariableTypeCache[name]}. '
'It cannot be linked with this type {typ}.'
)
register_variable_name(name, typ)
return name
class _Missing:
pass
_MISSING = _Missing()
@tp.overload
def register_variable_name(
name: str,
typ: type[Variable[tp.Any]],
*,
overwrite: bool = False,
) -> type[Variable[tp.Any]]: ...
@tp.overload
def register_variable_name(
name: str,
*,
overwrite: bool = False,
) -> tp.Callable[[type[Variable[tp.Any]]], type[Variable[tp.Any]]]: ...
[docs]def register_variable_name(
name: str,
typ: type[Variable[A]] | _Missing = _MISSING,
*,
overwrite=False,
) -> type[Variable[A]] | tp.Callable[[type[Variable[A]]], type[Variable[A]]]:
"""Register a pair of Linen collection name and its NNX type."""
if typ is _MISSING:
return partial(register_variable_name, name, overwrite=overwrite)
typ = tp.cast(type[Variable[A]], typ)
if not overwrite and name in VariableTypeCache:
raise ValueError(
f'Name {name} already mapped to type {VariableTypeCache[name]}. '
'To overwrite, call register_variable_name() with `overwrite=True`.'
)
VariableTypeCache[name] = typ
return typ
# add known variable type names
register_variable_name('params', Param)
register_variable_name('batch_stats', BatchStat)
register_variable_name('cache', Cache)
register_variable_name('intermediates', Intermediate)
register_variable_name('perturbations', Perturbation)