import warnings
from collections.abc import Callable
from math import prod
import array_api_extra as xpx
from array_api._2024_12 import Array
from array_api_compat import array_namespace
[docs]
class AmbiguousBatchAxesWarning(RuntimeWarning):
    """
    Warning category for an ambiguous split between batch and solution axes.

    `btensorsolve` emits this when `num_batch_axes` is not given and the
    shapes of `a` and `b` admit more than one valid number of batch axes.
    Pass `num_batch_axes` explicitly to silence it.
    """
[docs]
def broadcast_without_repeating(
    *arrays: Array, check_same_ndim: bool = False
) -> tuple[Array, ...]:
    """
    Broadcast arrays without repeating the data.

    The shapes are checked for broadcast compatibility. Each array then gets
    length-1 axes prepended until all arrays have the same number of
    dimensions. No axis is expanded, so no data are repeated; the arrays can
    be combined later via ordinary broadcasting.

    Parameters
    ----------
    *arrays : Array
        The arrays to broadcast. They must share a common array namespace.
    check_same_ndim : bool, optional
        If True, require all arrays to already have the same number of
        dimensions and return them as-is (after ``asarray``) instead of
        prepending axes. Default is False.

    Returns
    -------
    tuple[Array, ...]
        The input arrays, each with ``max(a.ndim for a in arrays)``
        dimensions, obtained by prepending axes of length 1.

    Raises
    ------
    ValueError
        If the shapes are not broadcast-compatible, or if `check_same_ndim`
        is True and the arrays do not all have the same number of dimensions.
    """
    xp = array_namespace(*arrays)
    arrays_ = tuple(xp.asarray(a) for a in arrays)
    # Only validates compatibility; the resulting shape is not needed.
    # `broadcast_shapes` is not part of the array API standard (not every
    # namespace provides it), so use the array-api-extra implementation.
    xpx.broadcast_shapes(*(a.shape for a in arrays_))
    if check_same_ndim:
        ndims = tuple(a.ndim for a in arrays_)
        if len(set(ndims)) != 1:
            raise ValueError(
                "All arrays must have the same number of dimensions, "
                f"but got {ndims}"
            )
        return arrays_
    max_dim = max(a.ndim for a in arrays_)
    # Indexing with `None` inserts a new length-1 axis at the front.
    return tuple(a[(None,) * (max_dim - a.ndim) + (...,)] for a in arrays_)
[docs]
def btensorsolve(
    a: Array,
    b: Array,
    /,
    *,
    num_batch_axes: int | None = None,
    solve: Callable[[Array, Array], Array] | None = None,
) -> Array:
    """
    Solve the batched tensor equation ``a x = b`` for x.

    It is assumed that all non-batch indices of `x` are summed over in the
    product, together with the rightmost indices of `a`, as is done in, for
    example, ``tensordot(a, x, axes=x.ndim)``. The leading `num_batch_axes`
    axes of `a` and `b` are batch axes and are broadcast against each other.

    Parameters
    ----------
    a : Array
        Coefficient tensor, of shape ``b.shape + Q`` (up to broadcasting of
        the batch axes). `Q`, a tuple, equals the shape of that sub-tensor of
        `a` consisting of its rightmost ``a.ndim - b.ndim`` indices, and must
        be such that **there exists i that**
        ``prod(Q) == prod(b.shape[i:])``.
    b : Array
        Right-hand tensor, which can be of any shape.
    num_batch_axes : int, optional
        The number of leading batch axes, in ``[0, b.ndim]``. If None
        (default), it is inferred from the shapes of `a` and `b`.
        Let ``shape = np.broadcast_shapes(a.shape[:b.ndim], b.shape)``.
        The inference picks the largest ``i`` such that
        ``prod(shape[i:]) == prod(Q)``. It is therefore ambiguous whenever
        ``shape[i - 1] == 1``, so specifying this argument is recommended.
        For example, if `a` has shape (3, 1, 1, 2, 2) and `b` has shape
        (3, 1, 1, 2), each of the following is consistent with the shapes:

        - ``num_batch_axes=1``: the solution has shape (3, 2)
        - ``num_batch_axes=2``: the solution has shape (3, 1, 2)
        - ``num_batch_axes=3``: the solution has shape (3, 1, 1, 2)

        The inference picks ``num_batch_axes=3`` and warns.
    solve : Callable[[Array, Array], Array], optional
        A function that solves a batch of linear systems. It must have the
        same signature as `array_api.linalg.solve`, including broadcasting
        of batch dimensions. If None (default), ``xp.linalg.solve`` of the
        inputs' namespace is used.

    Returns
    -------
    x : Array
        The solution, of shape ``batch_shape + Q``, where ``batch_shape`` is
        the broadcast shape of the first `num_batch_axes` axes of `a` and `b`.

    Raises
    ------
    ValueError
        If `a` and `b` are not broadcast-compatible, if the batch axes cannot
        be separated from the solution axes, if an explicit `num_batch_axes`
        is out of range or inconsistent with the shapes, or if the non-batch
        axes of `a` would have to be broadcast (repeated).
    Exception
        Whatever `solve` raises when `a` is singular (for NumPy,
        ``numpy.linalg.LinAlgError``).

    Warns
    -----
    AmbiguousBatchAxesWarning
        If `num_batch_axes` is not specified and the axis just left of the
        inferred batch/solution split has length 1, so the number of batch
        axes cannot be inferred unambiguously.

    See Also
    --------
    numpy.linalg.tensorsolve, numpy.tensordot, numpy.einsum

    Examples
    --------
    >>> import numpy as np
    >>> rng = np.random.default_rng()
    >>> a = rng.normal(size=(2, 2*3, 4, 2, 3, 4))
    >>> b = rng.normal(size=(2, 2*3, 4))
    >>> x = btensorsolve(a, b)
    >>> x.shape
    (2, 2, 3, 4)
    >>> np.allclose(np.einsum('...ijklm,...klm->...ij', a, x), b)
    True
    """
    # Adapted from numpy.linalg.tensorsolve:
    # https://github.com/numpy/numpy/blob/
    # e7a123b2d3eca9897843791dd698c1803d9a39c2/numpy/linalg/_linalg.py#L291
    xp = array_namespace(a, b)
    a_ = xp.asarray(a)
    b_ = xp.asarray(b)

    # Axes [0, b.ndim) of `a` pair with the axes of `b`; the remaining axes
    # of `a` (Q) are contracted against the solution.
    axis_sol_last = b_.ndim
    shape = xpx.broadcast_shapes(a_.shape[:axis_sol_last], b_.shape)
    sol_shape_last = a_.shape[axis_sol_last:]
    # the dimension of the (square) linear system
    sol_size = int(prod(sol_shape_last))

    if num_batch_axes is None:
        # Accumulate axis lengths from the right until the product equals the
        # system size; every axis left of that point is a batch axis.
        sol_size_current = 1
        for num_batch_axes in range(axis_sol_last - 1, -1, -1):
            sol_size_current *= shape[num_batch_axes]
            if sol_size_current == sol_size:
                break
        else:
            raise ValueError(
                "Unable to divide batch dimensions and solution dimensions"
            )
        # A length-1 axis just left of the split could equally well be
        # counted as a solution axis.
        if num_batch_axes > 0 and shape[num_batch_axes - 1] == 1:
            warnings.warn(
                "It is impossible to infer the number of "
                "batch axes from the shapes of `a` and `b`. "
                "Consider specifying `num_batch_axes` explicitly.",
                AmbiguousBatchAxesWarning,
                stacklevel=2,
            )
    else:
        # Validate an explicit value up front instead of failing later with
        # an opaque reshape error (or silently mis-slicing if negative).
        if not 0 <= num_batch_axes <= axis_sol_last:
            raise ValueError(
                f"num_batch_axes must be in [0, {axis_sol_last}], "
                f"but got {num_batch_axes}"
            )
        if int(prod(shape[num_batch_axes:])) != sol_size:
            raise ValueError(
                f"With num_batch_axes={num_batch_axes}, the solution axes "
                f"{tuple(shape[num_batch_axes:])} of `b` do not match the "
                f"trailing axes {tuple(sol_shape_last)} of `a` in size"
            )

    # Broadcasting the row axes of the system matrix would repeat rows,
    # so `a` must already have their full length.
    if tuple(a_.shape[num_batch_axes:axis_sol_last]) != tuple(
        shape[num_batch_axes:]
    ):
        raise ValueError("Non-batch axes of `a` must not be repeated.")

    # Broadcast only the solution axes of `b`; its batch axes are kept as-is.
    b_ = xp.broadcast_to(
        b_, (*b_.shape[:num_batch_axes], *shape[num_batch_axes:])
    )

    batch_shape = shape[:num_batch_axes]
    # Batch axes along which `a` has length 1 but `b` does not. They are
    # folded into extra right-hand-side columns, so `a` is not repeated
    # along them. `!= 1` (rather than `> 1`) also routes zero-length `b`
    # axes here; as common axes they would make `a` fail to reshape.
    batch_shape_b_idx = tuple(
        i
        for i in range(num_batch_axes)
        if a_.shape[i] == 1 and b_.shape[i] != 1
    )
    batch_common_idx = tuple(
        i for i in range(num_batch_axes) if i not in batch_shape_b_idx
    )
    batch_b_count = len(batch_shape_b_idx)
    batch_b_shape = tuple(batch_shape[i] for i in batch_shape_b_idx)
    batch_common_shape = tuple(batch_shape[i] for i in batch_common_idx)
    batch_common_shape_b = tuple(b_.shape[i] for i in batch_common_idx)
    batch_b_size = int(prod(batch_b_shape))

    # Drop the length-1 `b`-only batch axes from `a` ...
    a_ = a_[
        tuple(
            0 if i in batch_shape_b_idx else slice(None)
            for i in range(num_batch_axes)
        )
    ]
    # ... and move the matching axes of `b` to the end (future RHS columns).
    if batch_b_count:
        b_ = xp.moveaxis(
            b_, batch_shape_b_idx, tuple(range(-batch_b_count, 0))
        )

    # Flatten to matrices: a -> (..., N, N), b -> (..., N, K), with
    # N = sol_size and K = number of stacked right-hand sides.
    a_ = xp.reshape(a_, (*batch_common_shape, sol_size, sol_size))
    b_ = xp.reshape(b_, (*batch_common_shape_b, sol_size, batch_b_size))

    solve_ = xp.linalg.solve if solve is None else solve
    x = solve_(a_, b_)

    # Unflatten, then put the `b`-only batch axes back in place.
    x = xp.reshape(x, (*batch_common_shape, *sol_shape_last, *batch_b_shape))
    if batch_b_count:
        x = xp.moveaxis(
            x,
            tuple(range(-batch_b_count, 0)),
            batch_shape_b_idx,
        )
    return x