# Source code for flax.nnx.spmd

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing as tp

import flax.core.spmd as core_spmd
from flax.nnx import variablelib, graph
from flax.nnx.transforms.transforms import eval_shape
from flax.typing import (
  Sharding,
)
import jax
from jax.sharding import PartitionSpec

# JAX version compatibility
if hasattr(jax.sharding, 'use_mesh'):
  set_mesh = jax.sharding.use_mesh
else:
  set_mesh = jax.set_mesh

A = tp.TypeVar('A')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
PARTITION_NAME = 'partition_name'


# Transform axis change helpers
# ------------------------------------------------------------------------------


def add_axis(tree: A, index: int, transform_metadata: tp.Mapping) -> A:
  """Insert the transform axis at ``index`` into every ``Variable`` of ``tree``.

  Args:
    tree: a tree that may contain ``Variable`` objects; anything that is not a
      ``Variable`` is passed through untouched.
    index: position at which the new axis entry is inserted.
    transform_metadata: mapping that must contain ``"partition_name"``; every
      other key names an attribute whose tuple value gets the mapped value
      inserted at ``index`` as well.

  Returns:
    The result of ``jax.tree.map`` over ``tree``. ``Variable`` objects are
    updated in place and returned as-is.

  Raises:
    ValueError: if ``"partition_name"`` is missing from ``transform_metadata``.
  """
  axis_name, extra_meta = _get_partition_name_and_metadata(transform_metadata)

  def _with_entry(names, position, entry):
    # Pad with ``None`` up to ``position`` so the new entry lands exactly there
    # even when the existing tuple is shorter than ``position``.
    padded = list(names)
    padded.extend([None] * (position - len(padded)))
    padded.insert(position, entry)
    return tuple(padded)

  def _update_variable(node: tp.Any):
    if not isinstance(node, variablelib.Variable):
      return node

    # Only a non-empty ``sharding_names`` entry in the metadata is extended.
    names = node.get_metadata().get('sharding_names')
    if names:
      node.sharding_names = _with_entry(names, index, axis_name)

    # Apply the same insertion to every other non-empty tuple attribute named
    # in the transform metadata.
    for key, entry in extra_meta.items():
      current = getattr(node, key, None)
      if current and isinstance(current, tuple):
        setattr(node, key, _with_entry(current, index, entry))

    node.add_axis(index, axis_name)
    return node

  return jax.tree.map(
    _update_variable,
    tree,
    is_leaf=lambda node: isinstance(node, variablelib.Variable),
  )


def remove_axis(
  tree: A, index: int, transform_metadata: tp.Mapping[tp.Any, tp.Any]
) -> A:
  """Remove the transform axis at ``index`` from every ``Variable`` of ``tree``.

  Args:
    tree: a tree that may contain ``Variable`` objects; anything that is not a
      ``Variable`` is passed through untouched.
    index: position of the axis entry to drop.
    transform_metadata: mapping that must contain ``"partition_name"``; every
      other key names an attribute whose tuple value has the entry at ``index``
      removed as well.

  Returns:
    The result of ``jax.tree.map`` over ``tree``. ``Variable`` objects are
    updated in place and returned as-is.

  Raises:
    ValueError: if ``"partition_name"`` is missing from ``transform_metadata``.
    AssertionError: if the entry removed at ``index`` is not the expected
      value (only checked when assertions are enabled).
    IndexError: if ``index`` is out of range for a tuple being updated.
  """
  axis_name, other_meta = _get_partition_name_and_metadata(transform_metadata)

  def remove_field(fields, index, value):
    iterable = list(fields)
    # The pop must happen outside the ``assert``: assert statements are
    # stripped under ``python -O``, which would silently skip the removal.
    removed = iterable.pop(index)
    assert removed == value, (
      f'Expected to remove {value!r} at index {index}, found {removed!r}'
    )
    return tuple(iterable)

  def _remove_axis(x: tp.Any):
    if isinstance(x, variablelib.Variable):
      # Truthiness check (rather than ``is not None``) so an empty
      # ``sharding_names`` is skipped, mirroring ``add_axis``, instead of
      # raising ``IndexError`` from popping an empty list.
      if hasattr(x, 'sharding_names') and x.sharding_names:
        x.sharding_names = remove_field(x.sharding_names, index, axis_name)

      for k, v in other_meta.items():
        if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
          setattr(x, k, remove_field(t, index, v))

      x.remove_axis(index, axis_name)
    return x

  return jax.tree.map(
    _remove_axis,
    tree,
    is_leaf=lambda x: isinstance(x, variablelib.Variable),
  )


def _get_partition_name_and_metadata(
  transform_metadata: tp.Mapping[tp.Any, tp.Any],
) -> tuple[str, tp.Mapping[tp.Any, tp.Any]]:
  """Split ``transform_metadata`` into the partition name and the rest.

  Args:
    transform_metadata: mapping expected to contain the ``PARTITION_NAME`` key.

  Returns:
    A ``(partition_name, other_meta)`` pair, where ``other_meta`` is a new dict
    holding every entry of ``transform_metadata`` except ``PARTITION_NAME``.

  Raises:
    ValueError: if ``PARTITION_NAME`` is not a key of ``transform_metadata``.
  """
  if PARTITION_NAME in transform_metadata:
    # Work on a shallow copy so the caller's mapping is left untouched.
    remaining = dict(transform_metadata)
    partition_name = remaining.pop(PARTITION_NAME)
    return partition_name, remaining

  raise ValueError(
    'Trying to transform a Partitioned variable but "partition_name" '
    f'is not specified in transform_metadata: {transform_metadata}'
  )


# Annotation handling
# ------------------------------------------------------------------------------


def with_partitioning(
  initializer: F,
  sharding: Sharding,
  mesh: tp.Optional[jax.sharding.Mesh] = None,
  **metadata: tp.Any,
) -> F:
  """A wrapper over any initializer to add sharding annotation data to a `Variable`.

  Args:
    initializer: the initializer to wrap.
    sharding: sharding annotation, forwarded as the ``sharding_names``
      metadata entry.
    mesh: optional mesh, forwarded as the ``mesh`` metadata entry.
    **metadata: any additional metadata entries to forward unchanged.

  Returns:
    The result of ``variablelib.with_metadata`` applied to ``initializer`` with
    the metadata above.
  """
  return variablelib.with_metadata(
    initializer,
    sharding_names=sharding,
    mesh=mesh,
    **metadata,
  )
def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None:
  """Given an `nnx.Variable`, return its `PartitionSpec`.

  Args:
    v: the variable whose metadata is inspected.

  Returns:
    * If the metadata has a non-empty ``sharding_names`` entry: a
      ``PartitionSpec`` built from it. When logical axis rules are active in
      the context, or the metadata has a ``sharding_rules`` entry, the names
      are first translated through the combined rules.
    * Otherwise, an empty ``PartitionSpec()`` if ``v`` has a ``shape``
      attribute.
    * Otherwise ``None``.
  """
  metadata = v.get_metadata()
  if 'sharding_names' in metadata and metadata['sharding_names']:
    sharding = metadata['sharding_names']
    # Fetch the context rules once; they are needed both for the check and for
    # building the composite rules.
    context_rules = core_spmd.get_logical_axis_rules()
    if context_rules or 'sharding_rules' in metadata:
      local_rules = metadata.get('sharding_rules', ())
      rules = core_spmd.composite_rules(context_rules, local_rules)
      return PartitionSpec(*core_spmd.from_sharding_rules(sharding, rules))
    return PartitionSpec(*sharding)
  elif hasattr(v, 'shape'):
    return PartitionSpec()
  return None
def get_partition_spec(tree: A) -> A:
  """Extracts a PartitionSpec tree from a PyTree containing ``Variable`` values.

  Args:
    tree: a tree whose leaves may be ``Variable`` objects.

  Returns:
    The result of ``jax.tree.map`` over ``tree`` where each ``Variable`` is
    replaced by ``variable.replace(get_var_pspec(variable))``, any other leaf
    with a ``shape`` attribute becomes an empty ``PartitionSpec()``, and every
    remaining leaf becomes ``None``.
  """

  def _to_pspec(x):
    if isinstance(x, variablelib.Variable):
      return x.replace(get_var_pspec(x))
    elif hasattr(x, 'shape'):
      return PartitionSpec()
    return None

  # Treat ``Variable`` objects as leaves so they are handled as a whole.
  return jax.tree.map(
    _to_pspec, tree, is_leaf=lambda x: isinstance(x, variablelib.Variable)
  )
def get_named_sharding(tree: A, mesh: jax.sharding.Mesh) -> A:
  """Build a tree of ``NamedSharding`` objects for ``tree`` on ``mesh``.

  Args:
    tree: a tree from which partition specs are extracted with
      ``get_partition_spec``.
    mesh: the mesh every resulting ``NamedSharding`` is bound to.

  Returns:
    The partition-spec tree of ``tree`` with each leaf ``p`` mapped to
    ``jax.sharding.NamedSharding(mesh, p)``.
  """
  spec = get_partition_spec(tree)
  sharding = jax.tree.map(lambda p: jax.sharding.NamedSharding(mesh, p), spec)
  return sharding
# Other utilities
# ------------------------------------------------------------------------------


def get_abstract_model(init_fn, mesh):
  """Build an abstract model from ``init_fn`` with shardings attached.

  Args:
    init_fn: callable passed to ``eval_shape`` to obtain the abstract model.
    mesh: mesh made active via ``set_mesh`` while ``init_fn`` is evaluated and
      used to build the ``NamedSharding`` of every state leaf.

  Returns:
    A ``(gdef, abs_state)`` pair as produced by ``graph.split``, where every
    leaf of ``abs_state`` is a ``jax.ShapeDtypeStruct`` carrying the leaf's
    shape, dtype and its sharding from ``get_named_sharding``.
  """
  # NOTE(review): the extent of this ``with`` block was lost when the source
  # was collapsed onto one line; only the ``eval_shape`` call is kept inside
  # it here. Confirm against upstream.
  with set_mesh(mesh):
    abs_model = eval_shape(init_fn)
  gdef, abs_state = graph.split(abs_model)
  abs_state = jax.tree.map(
    lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
    abs_state,
    get_named_sharding(abs_state, mesh),
  )
  return gdef, abs_state