# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared neural network activations and other functions."""
from __future__ import annotations
from collections.abc import Sequence
from functools import partial
import operator
import math
import numpy as np
from typing import Any, Literal
import warnings
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
from jax._src import deprecations
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src import util
from jax._src.core import AxisName
from jax._src.cudnn.fused_attention_stablehlo import (
dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.cudnn.scaled_matmul_stablehlo import (
scaled_matmul_wrapper as cudnn_scaled_matmul,
scaled_dot_general_wrapper as cudnn_scaled_dot_general,
BlockScaleConfig)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.numpy import einsum as jnp_einsum
from jax._src.numpy import util as numpy_util
from jax._src.numpy.reductions import _count
from jax._src.numpy.reductions import Axis
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.ops.special import logsumexp as _logsumexp
# activations
@api.jit
def identity(x: ArrayLike) -> Array:
r"""Identity activation function.
Returns the argument unmodified.
Args:
x : input array
Returns:
The argument `x` unmodified.
Examples:
>>> jax.nn.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([-2. , -1. , -0.5, 0. , 0.5, 1. , 2. ], dtype=float32)
"""
numpy_util.check_arraylike("identity", x)
return jnp.asarray(x)
@custom_derivatives.custom_jvp
@api.jit
def relu(x: ArrayLike) -> Array:
r"""Rectified linear unit activation function.
Computes the element-wise function:
.. math::
\mathrm{relu}(x) = \max(x, 0)
except under differentiation, we take:
.. math::
\nabla \mathrm{relu}(0) = 0
For more information see
`Numerical influence of ReLU’(0) on backpropagation
<https://dl.acm.org/doi/10.5555/3540261.3540297>`_.
Args:
x : input array
Returns:
An array.
Examples:
>>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
See also:
:func:`relu6`
"""
return jnp.maximum(x, 0)
# For behavior at 0, see https://dl.acm.org/doi/10.5555/3540261.3540297
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
@api.jit
def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Array:
r"""Squareplus activation function.
Computes the element-wise function
.. math::
\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}
as described in https://arxiv.org/abs/2112.11687.
Args:
x : input array
b : smoothness parameter
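Examples:
Illustrative values with the default smoothness ``b = 4`` (approximate;
doctest skipped):

>>> jax.nn.squareplus(jax.numpy.array([-2., -1., 0., 1., 2.]))  # doctest: +SKIP
Array([0.4142, 0.618 , 1.    , 1.618 , 2.4142], dtype=float32)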
"""
numpy_util.check_arraylike("squareplus", x)
numpy_util.check_arraylike("squareplus", b)
x = jnp.asarray(x)
b = jnp.asarray(b)
y = x + jnp.sqrt(jnp.square(x) + b)
return y / 2
@api.jit
def softplus(x: ArrayLike) -> Array:
r"""Softplus activation function.
Computes the element-wise function
.. math::
\mathrm{softplus}(x) = \log(1 + e^x)
Args:
x : input array
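Examples:
Illustrative values (approximate; doctest skipped):

>>> jax.nn.softplus(jax.numpy.array([-1., 0., 1.]))  # doctest: +SKIP
Array([0.3133, 0.6931, 1.3133], dtype=float32)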
"""
return jnp.logaddexp(x, 0)
@api.jit
def sparse_plus(x: ArrayLike) -> Array:
r"""Sparse plus function.
Computes the function:
.. math::
\mathrm{sparse\_plus}(x) = \begin{cases}
0, & x \leq -1\\
\frac{1}{4}(x+1)^2, & -1 < x < 1 \\
x, & 1 \leq x
\end{cases}
This is the twin function of the softplus activation: it outputs exactly zero
for inputs less than or equal to -1 and is exactly linear for inputs greater
than or equal to 1, while remaining smooth, convex, and monotonic between
-1 and 1.
Args:
x: input (float)
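Examples:
Illustrative values showing the zero, quadratic, and linear regions
(approximate; doctest skipped):

>>> jax.nn.sparse_plus(jax.numpy.array([-2., -0.5, 0., 0.5, 2.]))  # doctest: +SKIP
Array([0.    , 0.0625, 0.25  , 0.5625, 2.    ], dtype=float32)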
"""
numpy_util.check_arraylike("sparse_plus", x)
x = jnp.asarray(x)
return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))
@api.jit
def soft_sign(x: ArrayLike) -> Array:
r"""Soft-sign activation function.
Computes the element-wise function
.. math::
\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}
Args:
x : input array
"""
numpy_util.check_arraylike("soft_sign", x)
x_arr = jnp.asarray(x)
return x_arr / (jnp.abs(x_arr) + 1)
@partial(api.jit, inline=True)
def sigmoid(x: ArrayLike) -> Array:
r"""Sigmoid activation function.
Computes the element-wise function:
.. math::
\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}
Args:
x : input array
Returns:
An array.
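Examples:
Illustrative values (approximate; doctest skipped):

>>> jax.nn.sigmoid(jax.numpy.array([-1., 0., 1.]))  # doctest: +SKIP
Array([0.2689, 0.5   , 0.7311], dtype=float32)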
See also:
:func:`log_sigmoid`
"""
return lax.logistic(x)
@api.jit
def sparse_sigmoid(x: ArrayLike) -> Array:
r"""Sparse sigmoid activation function.
Computes the function:
.. math::
\mathrm{sparse\_sigmoid}(x) = \begin{cases}
0, & x \leq -1\\
\frac{1}{2}(x+1), & -1 < x < 1 \\
1, & 1 \leq x
\end{cases}
This is the twin function of the ``sigmoid`` activation: it outputs exactly
zero for inputs less than or equal to -1, exactly one for inputs greater than
or equal to 1, and is linear in between. It is the derivative of ``sparse_plus``.
For more information, see `Learning with Fenchel-Young Losses (section 6.2)
<https://arxiv.org/abs/1901.02324>`_.
Args:
x : input array
Returns:
An array.
See also:
:func:`sigmoid`
"""
return 0.5 * jnp.clip(x + 1.0, 0.0, 2.0)
@api.jit
def silu(x: ArrayLike) -> Array:
r"""SiLU (aka swish) activation function.
Computes the element-wise function:
.. math::
\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}
:func:`swish` and :func:`silu` are both aliases for the same function.
Args:
x : input array
Returns:
An array.
See also:
:func:`sigmoid`
"""
numpy_util.check_arraylike("silu", x)
x_arr = jnp.asarray(x)
return x_arr * sigmoid(x_arr)
swish = silu
@api.jit
def mish(x: ArrayLike) -> Array:
r"""Mish activation function.
Computes the element-wise function:
.. math::
\mathrm{mish}(x) = x \cdot \mathrm{tanh}(\mathrm{softplus}(x))
For more information, see
`Mish: A Self Regularized Non-Monotonic Activation Function
<https://arxiv.org/abs/1908.08681>`_.
Args:
x : input array
Returns:
An array.
"""
numpy_util.check_arraylike("mish", x)
x_arr = jnp.asarray(x)
return x_arr * jnp.tanh(softplus(x_arr))
@api.jit
def log_sigmoid(x: ArrayLike) -> Array:
r"""Log-sigmoid activation function.
Computes the element-wise function:
.. math::
\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})
Args:
x : input array
Returns:
An array.
See also:
:func:`sigmoid`
"""
numpy_util.check_arraylike("log_sigmoid", x)
x_arr = jnp.asarray(x)
return -softplus(-x_arr)
@api.jit
def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
r"""Exponential linear unit activation function.
Computes the element-wise function:
.. math::
\mathrm{elu}(x) = \begin{cases}
x, & x > 0\\
\alpha \left(\exp(x) - 1\right), & x \le 0
\end{cases}
Args:
x : input array
alpha : scalar or array of alpha values (default: 1.0)
Returns:
An array.
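Examples:
Negative inputs saturate toward :math:`-\alpha` (values approximate;
doctest skipped):

>>> jax.nn.elu(jax.numpy.array([-2., -1., 0., 1., 2.]))  # doctest: +SKIP
Array([-0.8647, -0.6321, 0., 1., 2.], dtype=float32)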
See also:
:func:`selu`
"""
numpy_util.check_arraylike("elu", x)
x_arr = jnp.asarray(x)
return jnp.where(x_arr > 0,
x_arr,
alpha * jnp.expm1(jnp.where(x_arr > 0, 0., x_arr)))
@api.jit
def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Array:
r"""Leaky rectified linear unit activation function.
Computes the element-wise function:
.. math::
\mathrm{leaky\_relu}(x) = \begin{cases}
x, & x \ge 0\\
\alpha x, & x < 0
\end{cases}
where :math:`\alpha` = :code:`negative_slope`.
Args:
x : input array
negative_slope : array or scalar specifying the negative slope (default: 0.01)
Returns:
An array.
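Examples:
With the default ``negative_slope=0.01`` (values approximate; doctest skipped):

>>> jax.nn.leaky_relu(jax.numpy.array([-2., -1., 0., 1., 2.]))  # doctest: +SKIP
Array([-0.02, -0.01, 0., 1., 2.], dtype=float32)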
See also:
:func:`relu`
"""
numpy_util.check_arraylike("leaky_relu", x)
x_arr = jnp.asarray(x)
return jnp.where(x_arr >= 0, x_arr, negative_slope * x_arr)
@api.jit
def hard_tanh(x: ArrayLike) -> Array:
r"""Hard :math:`\mathrm{tanh}` activation function.
Computes the element-wise function:
.. math::
\mathrm{hard\_tanh}(x) = \begin{cases}
-1, & x < -1\\
x, & -1 \le x \le 1\\
1, & 1 < x
\end{cases}
Args:
x : input array
Returns:
An array.
"""
numpy_util.check_arraylike("hard_tanh", x)
x_arr = jnp.asarray(x)
return jnp.where(x_arr > 1, 1, jnp.where(x_arr < -1, -1, x_arr))
@api.jit
def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
r"""Continuously-differentiable exponential linear unit activation.
Computes the element-wise function:
.. math::
\mathrm{celu}(x) = \begin{cases}
x, & x > 0\\
\alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
\end{cases}
For more information, see
`Continuously Differentiable Exponential Linear Units
<https://arxiv.org/abs/1704.07483>`_.
Args:
x : input array
alpha : array or scalar (default: 1.0)
Returns:
An array.
"""
return jnp.maximum(x, 0.0) + alpha * jnp.expm1(jnp.minimum(x, 0.0) / alpha)
@api.jit
def selu(x: ArrayLike) -> Array:
r"""Scaled exponential linear unit activation.
Computes the element-wise function:
.. math::
\mathrm{selu}(x) = \lambda \begin{cases}
x, & x > 0\\
\alpha e^x - \alpha, & x \le 0
\end{cases}
where :math:`\lambda = 1.0507009873554804934193349852946` and
:math:`\alpha = 1.6732632423543772848170429916717`.
For more information, see
`Self-Normalizing Neural Networks
<https://arxiv.org/abs/1706.02515>`_.
Args:
x : input array
Returns:
An array.
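Examples:
Illustrative values (approximate; doctest skipped):

>>> jax.nn.selu(jax.numpy.array([-1., 0., 1.]))  # doctest: +SKIP
Array([-1.1113, 0., 1.0507], dtype=float32)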
See also:
:func:`elu`
"""
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
return scale * elu(x, alpha)
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
# @partial(api.jit, static_argnames=("approximate",))
def gelu(x: ArrayLike, approximate: bool = True) -> Array:
r"""Gaussian error linear unit activation function.
If ``approximate=False``, computes the element-wise function:
.. math::
\mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left(
\frac{-x}{\sqrt{2}} \right) \right)
If ``approximate=True``, uses the approximate formulation of GELU:
.. math::
\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
\sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
For more information, see `Gaussian Error Linear Units (GELUs)
<https://arxiv.org/abs/1606.08415>`_, section 2.
Args:
x: input array
approximate: whether to use the approximate or exact formulation.
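Examples:
Illustrative values with the default ``approximate=True`` (values
approximate; doctest skipped):

>>> jax.nn.gelu(jax.numpy.array([-1., 0., 1.]))  # doctest: +SKIP
Array([-0.1588, 0., 0.8412], dtype=float32)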
"""
[x_arr] = numpy_util.promote_args_inexact("gelu", x)
if approximate:
sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x_arr.dtype)
cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x_arr + 0.044715 * (x_arr ** 3))))
return x_arr * cdf
else:
sqrt_half = np.sqrt(0.5).astype(x_arr.dtype)
return jnp.array(
0.5 * x_arr * (lax.erfc(-x_arr * sqrt_half)), dtype=x_arr.dtype
)
@partial(api.jit, static_argnames=("axis",))
def glu(x: ArrayLike, axis: int = -1) -> Array:
r"""Gated linear unit activation function.
Computes the function:
.. math::
\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
\mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right]
\right)
where the array is split into two along ``axis``. The size of the ``axis``
dimension must be divisible by two.
Args:
x : input array
axis: the axis along which the split should be computed (default: -1)
Returns:
An array.
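Examples:
The last axis is split in half; the first half is gated by the sigmoid of
the second half (values approximate; doctest skipped):

>>> jax.nn.glu(jax.numpy.array([[1., 2., 3., 4.]]))  # doctest: +SKIP
Array([[0.9526, 1.964 ]], dtype=float32)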
See also:
:func:`sigmoid`
"""
numpy_util.check_arraylike("glu", x)
x_arr = jnp.asarray(x)
size = x_arr.shape[axis]
assert size % 2 == 0, "axis size must be divisible by 2"
x1, x2 = jnp.split(x_arr, 2, axis)
return x1 * sigmoid(x2)
# other functions
logsumexp = _logsumexp
@partial(api.jit, static_argnames=("axis", "keepdims"))
def logmeanexp(
x: ArrayLike,
axis: int | tuple[int, ...] | None = None,
where: ArrayLike | None = None,
keepdims: bool = False,
) -> Array:
r"""Log mean exp.
Computes the function:
.. math::
\text{logmeanexp}(x) = \log \frac{1}{n} \sum_{i=1}^n \exp x_i = \text{logsumexp}(x) - \log n
Args:
x: Input array.
axis: Axis or axes along which to reduce.
where: Elements to include in the reduction. Optional.
keepdims: Preserve the dimensions of the input.
Returns:
An array.
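Examples:
Illustrative value, equal to ``logsumexp(x) - log(3)`` for three elements
(approximate; doctest skipped):

>>> jax.nn.logmeanexp(jax.numpy.array([1., 2., 3.]))  # doctest: +SKIP
Array(2.309, dtype=float32)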
See also:
:func:`jax.nn.logsumexp`
"""
lse = _logsumexp(x, axis=axis, where=where, keepdims=keepdims)
count = _count(x, axis=axis, where=where, keepdims=keepdims, dtype=lse.dtype)
return lse - jnp.log(count)
@partial(api.jit, static_argnames=("axis",))
def log_softmax(x: ArrayLike,
axis: Axis = -1,
where: ArrayLike | None = None) -> Array:
r"""Log-Softmax function.
Computes the logarithm of the :code:`softmax` function, which rescales
elements to the range :math:`[-\infty, 0)`.
.. math ::
\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
\right)
Args:
x : input array
axis: the axis or axes along which the :code:`log_softmax` should be
computed. Either an integer or a tuple of integers.
where: Elements to include in the :code:`log_softmax`. The output for any
masked-out element is minus infinity.
Returns:
An array.
Note:
If any input values are ``+inf``, the result will be all ``NaN``: this reflects the
fact that ``inf / inf`` is not well-defined in the context of floating-point math.
See also:
:func:`softmax`
"""
numpy_util.check_arraylike("log_softmax", x)
x_arr = jnp.asarray(x)
x_max = jnp.max(x_arr, axis, where=where, initial=-np.inf, keepdims=True)
x_safe = x_arr if where is None else jnp.where(where, x_arr, -np.inf)
shifted = x_safe - lax.stop_gradient(x_max)
shifted_logsumexp = jnp.log(
jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
result = shifted - shifted_logsumexp
if where is not None:
return jnp.where(where, result, -np.inf)
return result
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
# @partial(api.jit, static_argnames=("axis",))
def softmax(x: ArrayLike,
axis: Axis = -1,
where: ArrayLike | None = None) -> Array:
r"""Softmax function.
Computes the function which rescales elements to the range :math:`[0, 1]`
such that the elements along :code:`axis` sum to :math:`1`.
.. math ::
\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
Args:
x : input array
axis: the axis or axes along which the softmax should be computed. The
softmax output summed across these dimensions should sum to :math:`1`.
Either an integer or a tuple of integers.
where: Elements to include in the :code:`softmax`. The output for any
masked-out element is zero.
Returns:
An array.
Note:
If any input values are ``+inf``, the result will be all ``NaN``: this reflects the
fact that ``inf / inf`` is not well-defined in the context of floating-point math.
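Examples:
The outputs along the last axis sum to 1 (values approximate; doctest skipped):

>>> jax.nn.softmax(jax.numpy.array([1., 2., 3.]))  # doctest: +SKIP
Array([0.09  , 0.2447, 0.6652], dtype=float32)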
See also:
:func:`log_softmax`
"""
if config.softmax_custom_jvp.value:
# mypy is confused by the `functools.partial` application in the definition
# of `_softmax` and incorrectly concludes that `_softmax` returns
# `ReturnValue` -- the unsubstituted type parameter of `custom_jvp`.
return _softmax(x, axis, where)
else:
return _softmax_deprecated(x, axis, where)
# TODO(mattjj): replace softmax with _softmax when deprecation flag is removed
@partial(custom_derivatives.custom_jvp, nondiff_argnums=(1,))
def _softmax(
x: ArrayLike,
axis: Axis = -1,
where: ArrayLike | None = None,
initial: ArrayLike | None = -np.inf) -> Array:
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
x_safe = x if where is None else jnp.where(where, x, initial)
unnormalized = jnp.exp(x_safe - x_max)
result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
if where is not None:
result = jnp.where(where, result, 0)
return result
@_softmax.defjvp
def _softmax_jvp(axis, primals, tangents):
(x, where, initial), (x_dot, _, _) = primals, tangents
y = _softmax(x, axis, where, initial)
return y, y * (x_dot - (y * x_dot).sum(axis, where=where, keepdims=True))
def _softmax_deprecated(
x: ArrayLike,
axis: Axis = -1,
where: ArrayLike | None = None,
initial: ArrayLike | None = -np.inf) -> Array:
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
x_safe = x if where is None else jnp.where(where, x, initial)
unnormalized = jnp.exp(x_safe - lax.stop_gradient(x_max))
result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
if where is not None:
result = jnp.where(where, result, 0)
return result
@partial(api.jit, static_argnames=("axis",))
def standardize(x: ArrayLike,
axis: Axis = -1,
mean: ArrayLike | None = None,
variance: ArrayLike | None = None,
epsilon: ArrayLike = 1e-5,
where: ArrayLike | None = None) -> Array:
r"""Standardizes input to zero mean and unit variance.
The standardization is given by:
.. math::
x_{std} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}}
where :math:`\langle x\rangle` indicates the mean of :math:`x`, and :math:`\epsilon` is
a small correction factor introduced to avoid division by zero.
Args:
x: input array to be standardized.
axis: integer or tuple of integers representing the axes along which
to standardize. Defaults to the last axis (``-1``).
mean: optionally specify the mean used for standardization. If not specified,
then ``x.mean(axis, where=where)`` will be used.
variance: optionally specify the variance used for standardization. If not
specified, then ``x.var(axis, where=where)`` will be used.
epsilon: correction factor added to variance to avoid division by zero; defaults
to ``1E-5``.
where: optional boolean mask specifying which elements to use when computing
the mean and variance.
Returns:
An array of the same shape as ``x`` containing the standardized input.
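Examples:
The output has approximately zero mean and unit variance along the reduced
axis (values approximate; doctest skipped):

>>> jax.nn.standardize(jax.numpy.array([1., 2., 3.]))  # doctest: +SKIP
Array([-1.2247, 0., 1.2247], dtype=float32)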
"""
numpy_util.check_arraylike("standardize", x)
numpy_util.check_arraylike_or_none("standardize", mean, variance, where)
if mean is None:
mean = jnp.mean(x, axis, keepdims=True, where=where)
if variance is None:
# this definition is traditionally seen as less accurate than jnp.var's
# mean((x - mean(x))**2) but may be faster and even, given typical
# activation distributions and low-precision arithmetic, more accurate
# when used in neural network normalization layers
variance = jnp.mean(
jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean)
return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.asarray(variance) + epsilon)
# TODO(slebedev): Change the type of `x` to `ArrayLike`.
@partial(api.jit, static_argnames=("num_classes", "dtype", "axis"))
def _one_hot(x: Array, num_classes: int, *,
dtype: DTypeLike, axis: int | AxisName) -> Array:
num_classes = core.concrete_dim_or_error(
num_classes,
"The error arose in jax.nn.one_hot argument `num_classes`.")
try:
output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1) # type: ignore[arg-type]
except TypeError:
axis_size = lax.axis_size(axis)
if num_classes != axis_size:
raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
f"but {num_classes} != {axis_size}") from None
axis_idx = lax.axis_index(axis)
return jnp.asarray(x == axis_idx, dtype=dtype)
axis = operator.index(axis) # type: ignore[arg-type]
lhs = lax.expand_dims(x, (axis,))
rhs_shape = [1] * x.ndim
rhs_shape.insert(output_pos_axis, num_classes)
# TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too?
rhs_sharding = NamedSharding(x.aval.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error
rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
out_sharding=rhs_sharding)
return (lhs == rhs).astype(dtype)
# TODO(slebedev): Change the type of `x` to `ArrayLike`.
def one_hot(x: Any, num_classes: int, *,
dtype: Any | None = None, axis: int | AxisName = -1) -> Array:
"""One-hot encodes the given indices.
Each index in the input ``x`` is encoded as a vector of zeros of length
``num_classes`` with the element at ``index`` set to one::
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], dtype=float32)
Indices outside the range [0, num_classes) will be encoded as zeros::
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)
Args:
x: A tensor of indices.
num_classes: Number of classes in the one-hot dimension.
dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
axis: the axis or axes along which the function should be
computed.
"""
num_classes = core.concrete_dim_or_error(
num_classes,
"The error arose in jax.nn.one_hot argument `num_classes`.")
x_arr = jnp.asarray(x)
if not dtypes.isdtype(x_arr.dtype, "integral"):
# Deprecated 2024-12-18
deprecations.warn(
'jax-nn-one-hot-float-input',
f"jax.nn.one_hot input should be integer-typed; got dtype={x_arr.dtype}",
stacklevel=1)
dtype = dtypes.default_float_dtype() if dtype is None else dtype
return _one_hot(x_arr, num_classes, dtype=dtype, axis=axis)
@custom_derivatives.custom_jvp
@api.jit
def relu6(x: ArrayLike) -> Array:
r"""Rectified Linear Unit 6 activation function.
Computes the element-wise function
.. math::
\mathrm{relu6}(x) = \min(\max(x, 0), 6)
except under differentiation, we take:
.. math::
\nabla \mathrm{relu6}(0) = 0
and
.. math::
\nabla \mathrm{relu6}(6) = 0
Args:
x : input array
Returns:
An array.
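Examples:
Values are clipped to the interval [0, 6] (doctest skipped):

>>> jax.nn.relu6(jax.numpy.array([-1., 0., 3., 6., 7.]))  # doctest: +SKIP
Array([0., 0., 3., 6., 6.], dtype=float32)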
See also:
:func:`relu`
"""
return jnp.minimum(jnp.maximum(x, 0), 6.)
relu6.defjvps(lambda g, ans, x:
lax.select((x > 0) & (x < 6), g, lax.full_like(g, 0)))
@api.jit
def hard_sigmoid(x: ArrayLike) -> Array:
r"""Hard Sigmoid activation function.
Computes the element-wise function
.. math::
\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}
Args:
x : input array
Returns:
An array.
See also:
:func:`relu6`
"""
return relu6(x + 3.) / 6.
@api.jit
def hard_silu(x: ArrayLike) -> Array:
r"""Hard SiLU (swish) activation function
Computes the element-wise function
.. math::
\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)
Both :func:`hard_silu` and :func:`hard_swish` are aliases for the same
function.
Args:
x : input array
Returns:
An array.
See also:
:func:`hard_sigmoid`
"""
numpy_util.check_arraylike("hard_silu", x)
x_arr = jnp.asarray(x)
return x_arr * hard_sigmoid(x_arr)
hard_swish = hard_silu
def _get_large_negative(dtype):
dtype_max = dtypes.finfo(dtype).max
return jnp.asarray(-0.7 * dtype_max, dtype=dtype)
def _get_causal_mask(T, S):
mask = jnp.tril(jnp.ones((T, S), dtype=bool))
return mask[None, None, :, :]
def _get_window_mask(T: int, S: int, local_window_size: tuple[int, int]):
query_pos = jnp.array(range(T))
key_pos = jnp.array(range(S))
left_window, right_window = local_window_size
left_mask = query_pos[..., None] <= key_pos[..., None, :] + left_window
right_mask = query_pos[..., None] >= key_pos[..., None, :] - right_window
return jnp.logical_and(right_mask, left_mask)[None, None, :, :]
def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen):
q_mask = True
kv_mask = True
if q_seqlen is not None:
q_indices = jnp.arange(0, T)[None, :, None]
q_mask = q_indices < q_seqlen[:, None, None]
if kv_seqlen is not None:
kv_indices = jnp.arange(0, S)[None, None, :]
kv_mask = kv_indices < kv_seqlen[:, None, None]
mask = jnp.logical_and(q_mask, kv_mask)
return mask[:, None, :, :]
def _get_padding_mask_encoded(T, q_seqlen):
q_indices = jnp.arange(0, T)[None, :]
mask = q_indices < q_seqlen[:, None]
return mask[:, :, None, None]
def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
local_window_size):
if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
return logits
combined_mask = jnp.ones_like(logits, dtype=bool)
if mask is not None:
assert mask.dtype == np.dtype(bool)
combined_mask = jnp.logical_and(combined_mask, mask)
T, S = logits.shape[2], logits.shape[3]
if is_causal:
mask = _get_causal_mask(T, S)
combined_mask = jnp.logical_and(combined_mask, mask)
if local_window_size is not None:
mask = _get_window_mask(T, S, local_window_size)
combined_mask = jnp.logical_and(combined_mask, mask)
if q_seqlen is not None or kv_seqlen is not None:
mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen)
combined_mask = jnp.logical_and(combined_mask, mask)
large_negative_number = _get_large_negative(logits.dtype)
padded_logits = jnp.where(combined_mask, logits, large_negative_number)
return padded_logits
def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
scale, q_seqlen, kv_seqlen, local_window_size):
logits_dtype = jnp.promote_types(query.dtype, np.float32)
# If the query and logits dtypes are different, then the default precision
# can use inconsistent types in the backwards pass
# (see https://github.com/jax-ml/jax/issues/24047).
if query.dtype == dtypes.bfloat16:
precision = lax.DotAlgorithmPreset.BF16_BF16_F32
elif query.dtype == np.float16:
precision = lax.DotAlgorithmPreset.F16_F16_F32
# TODO(sbodenstein): Implement this fix for all dtypes.
else:
precision = None
# Explicit precision will fail on platforms that don't support it. For example,
# some GPUs do not support BF16_BF16_F32, and TPU does not support F16_F16_F32.
# Use the default precision as a fallback in these cases.
try:
logits = jnp_einsum.einsum(
"BTNH,BSNH->BNTS",
query,
key,
precision=precision,
preferred_element_type=logits_dtype,
)
except: # pylint: disable=bare-except
logits = jnp_einsum.einsum(
"BTNH,BSNH->BNTS",
query,
key,
precision=None,
preferred_element_type=logits_dtype,
)
logits *= jnp.array(scale, dtype=logits.dtype)
if bias is not None:
logits = (logits + bias).astype(logits.dtype)
padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
local_window_size)
# Softmax is always carried out in fp32.
padded_logits = padded_logits.astype(np.float32)
probs = softmax(padded_logits, axis=-1).astype(key.dtype)
encoded = jnp_einsum.einsum('BNTS,BSNH->BTNH', probs, value)
if q_seqlen is not None and kv_seqlen is not None:
mask = _get_padding_mask_encoded(encoded.shape[1], q_seqlen)
encoded *= mask.astype(encoded.dtype)
return encoded
def _dot_product_attention_xla(
query: Array,
key: Array,
value: Array,
bias: Array | None,
mask: Array | None,
is_causal: bool,
scale: float,
q_seqlen: Array | None,
kv_seqlen: Array | None,
local_window_size: tuple[int, int] | None):
B, T, N, H = query.shape
_, S, K, _ = key.shape
G = N // K
query = jnp.reshape(query, (B, T, K, G, H))
def _reshape_to_grouped(t):
if t is not None:
tB, tN, tT, tS = t.shape
if tN == 1:
t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
else:
assert tN == N
t = jnp.reshape(t, (tB, K, G, tT, tS))
return t
bias = _reshape_to_grouped(bias)
mask = _reshape_to_grouped(mask)
vmapped_fn = api.vmap(
_dot_product_attention_core,
in_axes=(3, None, None, 2, 2, None, None, None, None, None),
out_axes=3,
)
encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale,
q_seqlen, kv_seqlen, local_window_size)
encoded = jnp.reshape(encoded, (B, T, N, H))
return encoded
def bias_fwd_rule(a, query_head_num):
return bias_fwd_p.bind(a, query_head_num), a
def bias_bwd_rule(query_head_num, res, g):
a = res
if a.shape[0] > 1 or a.shape[-3] != query_head_num:
raise ValueError("cuDNN only supports bias gradient when the batch size is "
f"1 and the head number matches the query, but got "
f"B={a.shape[0]}, N={a.shape[-3]}.")
return (bias_bwd_p.bind(g, a, query_head_num),)
# This function uses two custom primitives, `bias_fwd` and `bias_bwd`, to work
# around a cuDNN issue where bias gradients are only supported when the batch
# size is 1 and the number of heads matches the query.
# TODO(kaixih@nvidia): Remove this workaround once cuDNN resolves the issue.
@partial(custom_derivatives.custom_vjp, nondiff_argnums=(1,))
def check_valid_bias_batch(x, query_head_num):
output, _ = bias_fwd_rule(x, query_head_num)
return output
check_valid_bias_batch.defvjp(bias_fwd_rule, bias_bwd_rule)
bias_fwd_p = core.Primitive('bias_fwd')
bias_fwd_p.multiple_results = False
bias_bwd_p = core.Primitive('bias_bwd')
bias_bwd_p.multiple_results = False
def bias_fwd_impl(a, query_head_num):
return a
def bias_bwd_impl(g, a, query_head_num):
return g
bias_fwd_p.def_impl(bias_fwd_impl)
bias_bwd_p.def_impl(bias_bwd_impl)
def bias_fwd_abstract_eval(a, query_head_num):
return core.ShapedArray(a.shape, a.dtype)
def bias_bwd_abstract_eval(g, a, query_head_num):
return core.ShapedArray(g.shape, g.dtype)
bias_fwd_p.def_abstract_eval(bias_fwd_abstract_eval)
bias_bwd_p.def_abstract_eval(bias_bwd_abstract_eval)
def bias_fwd_lowering(ctx, a, query_head_num):
return [a]
def bias_bwd_lowering(ctx, g, a, query_head_num):
return [g]
mlir.register_lowering(bias_fwd_p, bias_fwd_lowering)
mlir.register_lowering(bias_bwd_p, bias_bwd_lowering)
def bias_fwd_batch_rule(batched_args, batch_dims):
x, query_head_num = batched_args
a = batch_dims[0]
output, _ = bias_fwd_rule(x, query_head_num)
return output, a
def bias_bwd_batch_rule(batched_args, batch_dims):
g, x, query_head_num = batched_args
b = batch_dims[0]
*Bs, _, _, _ = x.shape
B = math.prod(Bs)
x = jnp.reshape(x, (B,) + x.shape[-3:])
output, = bias_bwd_rule(query_head_num, x, g)
return output, b
batching.primitive_batchers[bias_fwd_p] = bias_fwd_batch_rule
batching.primitive_batchers[bias_bwd_p] = bias_bwd_batch_rule
def dot_product_attention(
query: ArrayLike,
key: ArrayLike,
value: ArrayLike,
bias: ArrayLike | None = None,
mask: ArrayLike | None = None,
*,
scale: float | None = None,
is_causal: bool = False,
query_seq_lengths: ArrayLike | None = None,
key_value_seq_lengths: ArrayLike | None = None,
local_window_size: int | tuple[int, int] | None = None,
implementation: Literal['xla', 'cudnn'] | None = None) -> Array:
r"""Scaled dot product attention function.
Computes the attention function on Query, Key, and Value tensors:
.. math::
\mathrm{Attention}(Q, K, V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V
Here we refer to the output of :math:`QK^T` as :code:`logits` and to the
output of the softmax as :code:`probs`.
Throughout this function, we use the following uppercase letters to
represent array shapes::
B = batch size
S = length of the key/value (source)
T = length of the query (target)
N = number of attention heads
H = dimensions of each attention head
K = number of key/value heads
G = number of groups, which equals N // K
Args:
query: query array; shape :code:`(BTNH|TNH)`
key: key array: shape :code:`(BSKH|SKH)`. When `K` equals `N`, multi-headed
attention (MHA https://arxiv.org/abs/1706.03762) is performed. Otherwise,
grouped query attention (GQA https://arxiv.org/abs/2305.13245) is
performed if `N` is a multiple of `K`, and multi-query attention (MQA
https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case
of GQA).
value: value array, should have the same shape as the `key` array.
bias: optional, bias array to be added to the logits; the shape must be 4D
and broadcastable to :code:`(BNTS|NTS)`.
mask: optional, mask array used to filter out logits. It is a boolean mask
where `True` indicates the element should take part in attention. For an
additive mask, users should pass it to `bias`. The shape must be 4D and be
broadcastable to :code:`(BNTS|NTS)`.
scale: scale for the logits. If None, the scale will be set to 1 divided by
the square root of query's head dimension (i.e. H).
is_causal: If true, causal attention will be applied. Note, some
implementations like `xla` will generate a mask tensor and apply it to the
logits to mask out the non-causal parts of the attention matrix, but other
implementations like `cudnn` will avoid computing the non-causal regions,
providing speedups.
query_seq_lengths: `int32` array of sequence lengths for query; shape
:code:`(B)`
key_value_seq_lengths: `int32` array of sequence lengths for key and value;
shape :code:`(B)`
local_window_size: Window sizes restricting self-attention to each token's
local window. If set, this specifies the (left_window_size,
right_window_size) for each token. E.g., if local_window_size == (3, 2)
and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend
to [3, 4, 5, c, 7, 8]. If a single int is given, it is interpreted as a
symmetric window (window_size, window_size).
implementation: A string to control which implementation backend to use.
Supported strings are `xla` and `cudnn` (cuDNN flash attention). It defaults
to `None`, which will automatically select the best available backend.
Note that `cudnn` supports only a subset of shapes/dtypes, and an exception
will be raised if the inputs are not supported.
Returns:
An array of the attention output with the same shape as :code:`query`.
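Examples:
A minimal self-attention call with query/key/value of shape (B, T, N, H),
shown for output shape only (an illustrative sketch; doctest skipped):

>>> q = k = v = jnp.ones((2, 8, 4, 16))  # doctest: +SKIP
>>> jax.nn.dot_product_attention(q, k, v).shape  # doctest: +SKIP
(2, 8, 4, 16)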
"""
output_shape = jnp.asarray(query).shape
def _ensure_4d(t):
t = jnp.asarray(t)
dims_to_add = 4 - t.ndim
if dims_to_add > 0:
return jnp.expand_dims(t, axis=tuple(range(dims_to_add)))
return t
query_arr = _ensure_4d(query)
key_arr = _ensure_4d(key)
value_arr = _ensure_4d(value)
bias = _ensure_4d(bias) if bias is not None else None
mask = _ensure_4d(mask) if mask is not None else None
if query_seq_lengths is not None:
query_seq_lengths = jnp.asarray(query_seq_lengths)
if key_value_seq_lengths is not None:
key_value_seq_lengths = jnp.asarray(key_value_seq_lengths)
if isinstance(local_window_size, int):
local_window_size = (local_window_size, local_window_size)
def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
dtype: DType | None, name: str) -> None:
if t is None:
return
if t.ndim != len(shape):
raise ValueError(f"{name} ndim should be {len(shape)}, but got {t.ndim}")
if dtype is not None and t.dtype != dtype:
raise ValueError(f"{name} dtype should be {dtype}, but got {t.dtype}")
for i in range(t.ndim):
if shape[i] != -1 and t.shape[i] != shape[i]:
raise ValueError(f"{name} shape should be {shape}: but got {t.shape}")
B, S, K, H = key_arr.shape
_check_shape_and_dtype(value_arr, [B, S, K, H], key_arr.dtype, 'value')
_check_shape_and_dtype(query_arr, [B, -1, -1, H], key_arr.dtype, 'query')
_check_shape_and_dtype(mask, [-1] * 4, np.dtype(bool), 'mask')
_check_shape_and_dtype(bias, [-1] * 4, None, 'bias')
_check_shape_and_dtype(query_seq_lengths, [B], np.dtype('int32'),
'query_seq_lengths')
_check_shape_and_dtype(key_value_seq_lengths, [B], np.dtype('int32'),
'key_value_seq_lengths')
if query_arr.shape[-2] % K != 0:
raise ValueError(f"The number of query heads must be a multiple of "
f"key/value heads, but got {query_arr.shape[-2]} vs {K}")
scale_val = (1.0 / np.sqrt(H)) if scale is None else scale
match implementation:
case 'xla':
out = _dot_product_attention_xla(
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
scale=scale_val, q_seqlen=query_seq_lengths,
kv_seqlen=key_value_seq_lengths,
local_window_size=local_window_size,
)
case 'cudnn':
if bias is not None:
bias = check_valid_bias_batch(bias, query_arr.shape[-2])
bias = jnp.asarray(bias)
use_padding = (
query_seq_lengths is not None or key_value_seq_lengths is not None
)
if use_padding:
if query_seq_lengths is None:
T = query_arr.shape[1]
query_seq_lengths = jnp.full((B,), T, dtype=np.int32)
if key_value_seq_lengths is None:
key_value_seq_lengths = jnp.full((B,), S, dtype=np.int32)
mask_type = MaskType.NO_MASK
if use_padding and is_causal:
mask_type = MaskType.PADDING_CAUSAL
elif is_causal:
mask_type = MaskType.CAUSAL
elif use_padding:
mask_type = MaskType.PADDING
# CuDNN supports only the left window with an exclusive boundary when
# causal mask is enabled.
sliding_window = None
if local_window_size is not None:
l_window, r_window = local_window_size
if r_window == 0 or mask_type == MaskType.CAUSAL:
sliding_window = l_window + 1
else:
raise ValueError(f"cuDNN doesn't support right window: {r_window} "
"when causal mask is not used.")
out = cudnn_dot_product_attention(
query_arr, key_arr, value_arr, bias, mask, query_seq_lengths,
key_value_seq_lengths, scale=scale_val, mask_type=mask_type,
sliding_window_length=sliding_window,
)
case None:
# TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select
# best backend.
out = _dot_product_attention_xla(
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
scale=scale_val, q_seqlen=query_seq_lengths,
kv_seqlen=key_value_seq_lengths,
local_window_size=local_window_size,
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")
return jnp.reshape(out, output_shape)
def scaled_matmul(
lhs: Array,
rhs: Array,
lhs_scales: Array,
rhs_scales: Array,
preferred_element_type: DTypeLike = np.float32,
) -> Array:
r"""Scaled matrix multiplication function.
Performs a block-scaled matmul of ``lhs`` (``a``) and ``rhs`` (``b``) using
``lhs_scales`` (``a_scales``) and ``rhs_scales`` (``b_scales``). The last
dimension is the contracting dimension, and the block size is inferred.
Mathematically, this operation is equivalent to::
a_block_size = a.shape[-1] // a_scales.shape[-1]
b_block_size = b.shape[-1] // b_scales.shape[-1]
a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1)
b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1)
jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)
Args:
lhs (Array): Operand a, shape (B, M, K).
rhs (Array): Operand b, shape (B, N, K).
lhs_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`.
rhs_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`.
preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`.
Returns:
Array of shape (B, M, N).
Notes:
- We currently do not support user-defined `precision` for customizing the
compute data type. It is fixed to `jnp.float32`.
- Block size is inferred as `K // K_a` for `a` and `K // K_b` for `b`.
- To use cuDNN with Nvidia Blackwell GPUs, inputs must match::
# mxfp8
a, b: jnp.float8_e4m3fn | jnp.float8_e5m2
a_scales, b_scales: jnp.float8_e8m0fnu
block_size: 32
# nvfp4
a, b: jnp.float4_e2m1fn
a_scales, b_scales: jnp.float8_e4m3fn
block_size: 16
Examples:
Basic case:
>>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3))
>>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3))
>>> a_scales = jnp.array([0.5]).reshape((1, 1, 1))
>>> b_scales = jnp.array([0.5]).reshape((1, 1, 1))
>>> scaled_matmul(a, b, a_scales, b_scales) # doctest: +SKIP
Array([[[8.]]], dtype=float32)
Using fused cuDNN call on Blackwell GPUs:
>>> dtype = jnp.float8_e4m3fn
>>> a = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64), dtype=dtype)
>>> b = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64), dtype=dtype)
>>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
>>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
>>> scaled_matmul(a, b, a_scales, b_scales) # doctest: +SKIP
"""
a, b, a_scales, b_scales = lhs, rhs, lhs_scales, rhs_scales
if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)):
raise ValueError(
"scaled_matmul requires all inputs to be 3-dimensional arrays"
)
B_a, M_a, K_a = a.shape
B_b, N_b, K_b = b.shape
if K_a != K_b or B_a != B_b:
raise ValueError(
"scaled_matmul requires inputs a and b to have matching batch (B) "
f"and contract (K) dimensions, but got shapes {a.shape} and "
f"{b.shape}"
)
B_as, M_as, K_as = a_scales.shape
B_bs, N_bs, K_bs = b_scales.shape
if K_as != K_bs or B_as != B_bs:
raise ValueError(
"scaled_matmul requires scales to have matching batch (B) and "
f"contract (K) dimensions, but got shapes {a_scales.shape} and "
f"{b_scales.shape}"
)
if M_as != M_a or N_bs != N_b:
raise ValueError(
"scaled_matmul requires scales to match non-contract dimensions of "
f"inputs, but got shapes a: {a.shape}, b: {b.shape}, a_scales: "
f"{a_scales.shape}, b_scales: {b_scales.shape}"
)
preferred_element_type = dtypes.canonicalize_dtype(
np.dtype(preferred_element_type)
)
out = cudnn_scaled_matmul(
a,
b,
a_scales,
b_scales,
preferred_element_type=preferred_element_type,
)
return out
def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'],
global_scale: Array | None = None):
r"""Get quantization configs for scaled_dot_general.
Create quantization configs for the `jax.nn.scaled_dot_general`.
See Also:
- :func:`jax.nn.scaled_dot_general`: Scaled dot general function.
"""
if mode == 'nvfp4':
one = jnp.ones((1,), dtype=np.float32)
return BlockScaleConfig(
mode='nvfp4',
block_size=16,
data_type=dtypes.float4_e2m1fn,
scale_type=dtypes.float8_e4m3fn,
global_scale=one if global_scale is None else global_scale,
infer_only=False
)
elif mode == 'mxfp8':
return BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=dtypes.float8_e4m3fn,
scale_type=dtypes.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
else:
raise ValueError(f"Unsupported mode: {mode}")
def scaled_dot_general(
lhs, rhs,
dimension_numbers,
preferred_element_type=np.float32,
configs: list[BlockScaleConfig] | None = None,
implementation: Literal['cudnn'] | None = None,
):
r"""Scaled dot general operation.
Performs a generalized dot product with block-scaled quantization on the
lhs and rhs inputs. This operation extends `lax.dot_general` to support
user-defined scaling configurations.
Essentially, the operation follows::
a, a_scales = quantize(lhs, configs[0])
b, b_scales = quantize(rhs, configs[1])
c = jax.nn.scaled_matmul(a, b, a_scales, b_scales)
Args:
lhs (ArrayLike): Input array.
rhs (ArrayLike): Input array.
dimension_numbers (DotDimensionNumbers): A tuple of two tuples specifying
the contraction and batch dimensions:
`((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
preferred_element_type (DTypeLike, optional): Output data type of the dot
product. Defaults to `jnp.float32`. Other valid types include
`jnp.bfloat16` and `jnp.float16`.
configs (list of BlockScaleConfig, optional): Scaling configurations for
lhs, rhs, and gradients. Users can obtain valid configurations via
`jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8`
are supported. If `None`, falls back to `lax.dot_general`.
implementation (str, optional): (Deprecated) Backend selector, now ignored.
The system chooses the backend automatically. Scheduled for removal in
future releases.
Returns:
Array: The resulting tensor, with batch dimensions first, followed by
non-contracting/non-batch dimensions of lhs, and then those of rhs.
See Also:
- :func:`jax.nn.scaled_matmul`: Scaled matmul function.
- :func:`jax.lax.dot_general`: General dot product operator.
Notes:
- Unlike `nn.scaled_matmul`, which assumes quantized low-precision
inputs with explicit scaling factors, this operator takes high-precision
inputs, applies quantization internally, and handles the backward pass.
Examples:
Creating config for mxfp8:
>>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3
Creating config for nvfp4:
>>> global_scale = jnp.array([0.5], jnp.float32)
>>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3
Using scaled_dot_general with the configs:
>>> import functools
>>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs)
>>> lhs = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64))
>>> rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64))
>>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,)))) # doctest: +SKIP
"""
if implementation is not None:
warnings.warn("Backend selector, now ignored. The system chooses the "
"backend automatically.", DeprecationWarning)
if configs is None:
return lax.dot_general(lhs, rhs, dimension_numbers,
preferred_element_type=preferred_element_type)
out = cudnn_scaled_dot_general(
lhs, rhs, dimension_numbers,
preferred_element_type=preferred_element_type,
configs=configs
)
return out
@custom_derivatives.custom_jvp
@api.jit
def log1mexp(x: ArrayLike) -> Array:
r"""Numerically stable calculation of :math:`\log(1 - \exp(-x))`.
This function is undefined for :math:`x < 0`.
Based on `TensorFlow's implementation <https://www.tensorflow.org/probability/api_docs/python/tfp/math/log1mexp>`_.
References:
.. [1] Martin Mächler. `Accurately Computing log(1 − exp(−|a|)) Assessed by the Rmpfr package.
<https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf>`_.
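Examples:
Illustrative values (approximate; doctest skipped):

>>> jax.nn.log1mexp(jax.numpy.array([0.5, 1., 2.]))  # doctest: +SKIP
Array([-0.9327, -0.4587, -0.1454], dtype=float32)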
"""
numpy_util.check_arraylike("log1mexp", x)
x = jnp.asarray(x)
c = jnp.log(2.0)
return jnp.where(
x < c,
jnp.log(-jnp.expm1(-x)),
jnp.log1p(-jnp.exp(-x)),
)
log1mexp.defjvps(lambda g, ans, x: g / jnp.expm1(x))