# Source code for flax.nnx.helpers

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import inspect
import typing as tp

import jax
import jax.numpy as jnp
import optax

from flax.nnx import graph, reprlib
from flax.nnx.graph import GraphDef
from flax.nnx.module import Module
from flax.nnx.proxy_caller import ApplyCaller
from flax.nnx.rnglib import Rngs
from flax.nnx.statelib import State
from flax.training.train_state import struct

A = tp.TypeVar('A')
M = tp.TypeVar('M', bound=Module)
TS = tp.TypeVar('TS', bound='TrainState')

class Dict(reprlib.MappingReprMixin, Module, tp.MutableMapping[str, A]):
  """A Module that implements a mutable mapping.

  This class provides a way to store and manipulate a mapping of keys to
  values containing a mixed set of data (e.g. Array, Variables, Modules) and
  static (e.g. functions, strings) types.

  Items are stored as attributes of the instance: ``d[key]`` reads and writes
  go through ``getattr`` / ``setattr`` / ``delattr``, and iteration yields the
  names found in ``vars(self)`` that do not start with ``'_pytree__'``.

  Example:

    >>> from flax import nnx
    ...
    >>> rngs = nnx.Rngs(0)
    >>> layers = nnx.Dict({
    ...   'kernel1': nnx.Linear(1, 32, rngs=rngs),   # data
    ...   'activation1': nnx.relu,                   # static
    ...   'kernel2': nnx.Linear(32, 1, rngs=rngs),   # data
    ... })
  """

  @tp.overload
  def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /): ...

  @tp.overload
  def __init__(
    self, mapping: tp.Optional[tp.Mapping[str, A]] = None, /, **kwargs: A
  ): ...

  def __init__(self, *args, **kwargs):
    """Populate the mapping using the same argument forms as ``dict``.

    Args:
      *args: At most one positional argument: a mapping, an iterable of
        ``(key, value)`` pairs, or ``None`` for "no initial items".
      **kwargs: Additional items given as keyword arguments.
    """
    # The second overload advertises an explicit ``None`` mapping, but
    # ``dict(None)`` raises TypeError, so treat it as "no positional argument".
    if len(args) == 1 and args[0] is None:
      args = ()
    for name, value in dict(*args, **kwargs).items():
      setattr(self, name, value)

  def __getitem__(self, key) -> A:
    """Return the value stored under ``key`` (an attribute lookup)."""
    # NOTE(review): unless a base class intervenes, a missing key surfaces as
    # AttributeError from ``getattr`` rather than KeyError, so the Mapping
    # mixins (``get``, ``in``, ``pop``) will not translate it -- confirm this
    # is intended.
    return getattr(self, key)

  def __setitem__(self, key, value):
    """Store ``value`` under ``key`` (an attribute assignment)."""
    setattr(self, key, value)

  def __iter__(self) -> tp.Iterator[str]:
    """Iterate over the stored keys, skipping ``'_pytree__'``-prefixed names."""
    return (k for k in vars(self) if not k.startswith('_pytree__'))

  def __len__(self) -> int:
    """Return the number of keys produced by iterating over this mapping."""
    return sum(1 for _ in self)

  def __hash__(self) -> int:
    """Hash by object identity."""
    return id(self)

  def __delitem__(self, key: str) -> None:
    """Remove ``key`` (an attribute deletion)."""
    delattr(self, key)

  if tp.TYPE_CHECKING:

    def __getattr__(self, key: str) -> A: ...

    def __setattr__(self, key: str, value: A) -> None: ...
class List(reprlib.SequenceReprMixin, Module, tp.MutableSequence[A]):
  """A Module that implements a mutable sequence.

  This class provides a way to store and manipulate a sequence of values
  containing a mixed set of data (e.g. Array, Variables, Modules) and static
  (e.g. functions, strings) types.

  Elements are read from ``vars(self)`` under their integer position, and the
  number of elements is tracked in ``self._length``.

  Example:

    >>> from flax import nnx
    ...
    >>> rngs = nnx.Rngs(0)
    >>> layers = nnx.List([
    ...   nnx.Linear(1, 32, rngs=rngs),  # data
    ...   nnx.relu,                      # static
    ...   nnx.Linear(32, 1, rngs=rngs),  # data
    ... ])
  """

  def __init__(self, it: tp.Iterable[A] | None = None, /):
    """
    Args:
      it: An iterable of values to initialize the list.
    """
    self._length: int = 0
    if it is not None:
      for value in it:
        self.append(value)

  def _getattr(self, key) -> A:
    """Return the raw entry stored under ``key`` in the instance dict."""
    return vars(self)[key]  # type: ignore[unsupported-operands]

  def _delattr(self, key) -> None:
    """Remove the raw entry stored under ``key`` from the instance dict."""
    vars(self).pop(key)

  def __len__(self) -> int:
    """Return the number of elements."""
    return self._length

  def append(self, value: A) -> None:
    """Add ``value`` at the end of the sequence."""
    # ``_setattr`` is not defined in this class; it is expected from a base.
    self._setattr(self._length, value)
    self._length += 1

  def insert(self, index: int, value: A) -> None:
    """Insert ``value`` before position ``index``.

    Out-of-range indices are clamped to the start or the end of the sequence
    instead of raising.
    """
    if index < 0:
      index = max(index + self._length, 0)
    index = min(index, self._length)

    # Shift elements to the right, starting from the end so nothing is
    # overwritten before it has been moved.
    for i in range(self._length, index, -1):
      self._setattr(i, self._getattr(i - 1))

    # Insert the new value
    self._setattr(index, value)
    self._length += 1

  def __iter__(self) -> tp.Iterator[A]:
    """Yield the elements in positional order."""
    for i in range(self._length):
      yield self._getattr(i)

  @tp.overload
  def __getitem__(self, index: int) -> A: ...

  @tp.overload
  def __getitem__(self, index: slice) -> tp.List[A]: ...

  def __getitem__(self, index: int | slice) -> A | tp.List[A]:
    """Return one element, or a plain Python ``list`` for a slice.

    Raises:
      IndexError: If an integer index is out of bounds.
      TypeError: If ``index`` is neither an ``int`` nor a ``slice``.
    """
    if isinstance(index, int):
      if index < 0:
        index += self._length
      if index < 0 or index >= self._length:
        raise IndexError('Index out of bounds')
      return self._getattr(index)
    elif isinstance(index, slice):
      return [self._getattr(i) for i in range(self._length)[index]]
    else:
      raise TypeError('Invalid index type')

  def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None:
    """Replace one element, or the elements selected by a slice.

    Slice assignment never changes the length of the sequence: the number of
    supplied values must equal the number of selected positions.

    Raises:
      IndexError: If an integer index is out of bounds.
      TypeError: If ``index`` has an unsupported type, or a slice is assigned
        a non-iterable value.
      ValueError: If a slice is assigned a different number of values than it
        selects.
    """
    if isinstance(index, int):
      if index < 0:
        index += self._length
      if index < 0 or index >= self._length:
        raise IndexError('Index out of bounds')
      self._setattr(index, value)
    elif isinstance(index, slice):
      if not isinstance(value, tp.Iterable):
        raise TypeError('Expected an iterable')
      values = list(value)
      idxs = range(self._length)[index]
      if len(idxs) != len(values):
        raise ValueError('Length mismatch')
      for i, v in zip(idxs, values):
        self._setattr(i, v)
    else:
      raise TypeError('Invalid index type')

  def __delitem__(self, index: int | slice) -> None:
    """Delete one element, or every element selected by a slice.

    Raises:
      IndexError: If an integer index is out of bounds.
      TypeError: If ``index`` is neither an ``int`` nor a ``slice``.
    """
    if isinstance(index, int):
      if index < 0:
        index += self._length
      if index < 0 or index >= self._length:
        raise IndexError('Index out of bounds')
      last = self._length - 1
      # Shift the tail one slot to the left, overwriting the deleted element...
      for i in range(index, last):
        self._setattr(i, self._getattr(i + 1))
      # ...then drop the last slot, which is now a duplicate. Leaving it in
      # place would keep a stale entry in ``vars(self)`` beyond ``_length``.
      self._delattr(last)
      self._length = last
    elif isinstance(index, slice):
      # Always delete from the highest position down so earlier deletions do
      # not shift the positions still to be deleted. ``reversed`` alone is not
      # enough: a negative-step slice already yields descending indices.
      for i in sorted(range(self._length)[index], reverse=True):
        del self[i]
    else:
      raise TypeError('Invalid index type')

  @staticmethod
  def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
    """Sort key for ``(key, value)`` pairs: integer keys first, then strings.

    Raises:
      ValueError: If the key is neither an ``int`` nor a ``str``.
    """
    key, _ = item
    if isinstance(key, int):
      return (0, key)
    elif isinstance(key, str):
      return (1, key)
    else:
      raise ValueError(f'Unsupported key type: {type(key)!r}')
class Sequential(Module):
  """A Module that applies a sequence of callables.

  This class provides a way to store and manipulate a sequence of callables
  (e.g. layers, activation functions) and apply them in order.

  Example::

    >>> from flax import nnx
    >>> import jax.numpy as jnp
    ...
    >>> rngs = nnx.Rngs(0)
    >>> model = nnx.Sequential(
    ...   nnx.Linear(1, 4, rngs=rngs),  # data
    ...   nnx.relu,                     # static
    ...   nnx.Linear(4, 2, rngs=rngs),  # data
    ... )
    >>> x = jnp.ones((1, 1))
    >>> y = model(x)
    >>> y.shape
    (1, 2)
  """

  def __init__(self, *fns: tp.Callable[..., tp.Any]):
    """
    Args:
      *fns: A sequence of callables to apply.
    """
    self.layers = List(fns)

  def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any:
    """Apply the stored callables in order.

    The first callable receives ``*args`` and ``**kwargs``. Each later callable
    receives the previous output: a ``tuple`` is unpacked as positional
    arguments, a ``dict`` as keyword arguments, and anything else is passed as
    a single positional argument.

    Args:
      *args: Positional arguments for the first callable.
      rngs: If not ``None``, forwarded as the ``rngs`` keyword argument to every
        callable whose signature has a parameter named ``rngs``.
      **kwargs: Keyword arguments for the first callable.

    Returns:
      The output of the last callable. With no callables: the single positional
      argument, else the tuple of positional arguments, else the keyword
      arguments dict, else ``None``.

    Raises:
      TypeError: If one of the stored items is not callable.
    """
    if len(self.layers) == 0:
      if len(args) == 1:
        return args[0]
      elif len(args) > 0:
        return args
      elif len(kwargs) > 0:
        return kwargs
      else:
        return None

    output: tp.Any = None

    for i, f in enumerate(self.layers):
      if not callable(f):
        raise TypeError(f'Sequence[{i}] is not callable: {f}')
      if i > 0:
        if isinstance(output, tuple):
          args = output
          kwargs = {}
        elif isinstance(output, dict):
          args = ()
          # Copy: ``rngs`` may be inserted below, and that must not mutate the
          # dict object the previous callable returned.
          kwargs = dict(output)
        else:
          args = (output,)
          kwargs = {}
      if rngs is not None and has_keyword_arg(f, 'rngs'):
        kwargs['rngs'] = rngs

      output = f(*args, **kwargs)

    return output
class ModuleDefApply(tp.Protocol, tp.Generic[M]):
  """Structural type for a callable that takes one or more ``State`` objects.

  A conforming object is called with a ``State`` plus any number of additional
  ``State`` arguments and is annotated as returning an ``ApplyCaller``.
  """

  # NOTE(review): the tuple order in the return annotation here is
  # (State, GraphDef) while ``TrainState.apply`` is annotated with
  # (GraphDef, State) -- confirm which order is intended.
  def __call__(
    self, state: State, *states: State
  ) -> ApplyCaller[tuple[State, GraphDef[M]]]: ...
class TrainState(tp.Generic[M], struct.PyTreeNode):
  """Bundle of a graph definition, parameters and optimizer state.

  ``tx`` is declared with ``pytree_node=False``; the remaining fields are
  regular fields of the ``struct.PyTreeNode``.
  """

  graphdef: graph.GraphDef[M]
  params: State
  opt_state: optax.OptState
  step: jax.Array
  tx: optax.GradientTransformation = struct.field(pytree_node=False)

  @classmethod
  def create(
    cls,
    graphdef: graph.GraphDef[M],
    *,
    params: State,
    tx: optax.GradientTransformation,
    step: int = 0,
    **kwargs,
  ):
    """Build a new instance, deriving the optimizer state from ``params``.

    Args:
      graphdef: Graph definition to store.
      params: Parameters to store; also passed to ``tx.init``.
      tx: Gradient transformation used to create and update ``opt_state``.
      step: Initial step count, converted with ``jnp.asarray``.
      **kwargs: Extra fields forwarded to the class constructor.

    Returns:
      A new instance of ``cls``.
    """
    initial_opt_state = tx.init(params)
    initial_step = jnp.asarray(step)
    return cls(
      graphdef=graphdef,
      params=params,
      opt_state=initial_opt_state,
      step=initial_step,
      tx=tx,
      **kwargs,
    )

  if tp.TYPE_CHECKING:

    def __getattr__(self, key: str) -> tp.Any: ...

  def apply(
    self, state: tp.Union[State, str], *states: tp.Union[State, str]
  ) -> ApplyCaller[tuple[GraphDef[M], State]]:
    """Call ``self.graphdef.apply`` with the given states.

    Args:
      state: A ``State``, or the name of an attribute of this object that
        holds one.
      *states: Further states or attribute names, handled the same way.

    Returns:
      Whatever ``self.graphdef.apply`` returns for the resolved states.

    Raises:
      TypeError: If an attribute looked up by name is not a ``State``.
    """
    _states: list[State] = []
    for _state in (state, *states):
      if not isinstance(_state, str):
        # Already a state object: pass it through untouched.
        _states.append(_state)
        continue
      # A string names an attribute of this object to look up.
      _state_key = _state
      _state = getattr(self, _state_key)
      if not isinstance(_state, State):
        raise TypeError(
          f'Expected {self.__class__.__name__}.{_state_key} to be a State, got {type(_state)}'
        )
      _states.append(_state)
    return self.graphdef.apply(*_states)

  def apply_gradients(self: TS, grads: State, **kwargs) -> TS:
    """Run one optimizer update and return the resulting copy.

    Args:
      grads: Gradients passed to ``self.tx.update``.
      **kwargs: Extra field replacements forwarded to ``self.replace``.

    Returns:
      A copy with updated ``params`` and ``opt_state`` and ``step`` increased
      by one.
    """
    updates, next_opt_state = self.tx.update(
      grads, self.opt_state, self.params
    )
    next_params = optax.apply_updates(self.params, updates)  # type: ignore
    next_step = self.step + 1
    return self.replace(
      params=next_params,
      opt_state=next_opt_state,
      step=next_step,
      **kwargs,
    )
def has_keyword_arg(func: tp.Callable[..., tp.Any], name: str) -> bool:
  """Return True if ``func`` declares a parameter ``name`` passable by keyword.

  Both keyword-only and positional-or-keyword parameters count. Positional-only
  parameters and a catch-all ``**kwargs`` do not.

  Args:
    func: The callable whose signature is inspected.
    name: The parameter name to look for.

  Returns:
    ``True`` if such a parameter exists. ``False`` otherwise, including when
    the signature of ``func`` cannot be introspected.
  """
  try:
    parameters = inspect.signature(func).parameters
  except (TypeError, ValueError):
    # ``inspect.signature`` raises for callables without an introspectable
    # signature (e.g. some builtins); nothing can be said about them.
    return False
  param = parameters.get(name)
  return param is not None and param.kind in (
    param.KEYWORD_ONLY,
    param.POSITIONAL_OR_KEYWORD,
  )