batch_tensorsolve package

exception batch_tensorsolve.AmbiguousBatchAxesWarning

Bases: RuntimeWarning

batch_tensorsolve.broadcast_without_repeating(*arrays: Array, check_same_ndim: bool = False) → tuple[Array, ...]

Broadcast arrays without repeating the data.

Parameters:
  • arrays (TArray) – The arrays to broadcast.

  • check_same_ndim (bool, optional) – Whether to check if all arrays have the same number of dimensions. Default is False.

Returns:

The broadcasted arrays.

Return type:

tuple[TArray, ...]
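The docstring does not spell out what "without repeating" means. One plausible reading, assumed here and not confirmed by the source, is that the arrays are given a common number of dimensions by prepending size-1 axes, while existing size-1 axes are left alone instead of being expanded the way np.broadcast_arrays expands them. The name broadcast_without_repeating_sketch and the choice to raise ValueError when the check_same_ndim test fails are hypothetical.

```python
import numpy as np


def broadcast_without_repeating_sketch(*arrays, check_same_ndim=False):
    """Hypothetical illustration, not the library's implementation.

    Assumption: "without repeating" means arrays only gain leading size-1
    axes; no axis is expanded to the full broadcast size.
    """
    if check_same_ndim and len({a.ndim for a in arrays}) > 1:
        # Assumed behaviour: reject inputs whose ndim differs.
        raise ValueError("all arrays must have the same number of dimensions")
    # np.broadcast_shapes raises ValueError if the shapes are incompatible.
    target = np.broadcast_shapes(*(a.shape for a in arrays))
    # Prepend singleton axes only; reshape returns a view here, so no data is copied.
    return tuple(a.reshape((1,) * (len(target) - a.ndim) + a.shape) for a in arrays)


x = np.zeros((3, 1))
y = np.arange(4.0)
bx, by = broadcast_without_repeating_sketch(x, y)
print(bx.shape, by.shape)  # (3, 1) (1, 4)
print(np.broadcast_arrays(x, y)[0].shape)  # (3, 4): full broadcast, by contrast
```

Under this reading the outputs still broadcast against each other, but size-1 axes stay size 1, so downstream code can tell which axes were genuinely shared.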

batch_tensorsolve.btensorsolve(a: Array, b: Array, /, *, num_batch_axes: int | None = None, solve: Callable[[Array, Array], Array] | None = None) → Array

Solve the tensor equation a x = b for x.

It is assumed that all indices of x are summed over in the product, together with the rightmost indices of a, as is done in, for example, tensordot(a, x, axes=x.ndim).
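This contraction convention is the one used by NumPy's unbatched numpy.linalg.tensorsolve. The short check below uses only NumPy (not batch_tensorsolve) to show that every axis of x is summed against the rightmost x.ndim axes of a, and that Q need not equal b.shape; it only needs the same number of elements.

```python
import numpy as np

rng = np.random.default_rng(0)
# b has shape (2, 3); the rightmost axes of a give Q = (3, 2).
# prod(Q) == prod(b.shape) == 6, so the flattened system is 6 x 6.
a = rng.normal(size=(2, 3, 3, 2))
b = rng.normal(size=(2, 3))

x = np.linalg.tensorsolve(a, b)
print(x.shape)  # (3, 2)

# All x.ndim axes of x are contracted with the last x.ndim axes of a.
print(np.allclose(np.tensordot(a, x, axes=x.ndim), b))  # True
```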

Parameters:
  • a (array_like) – Coefficient tensor, of shape b.shape + Q. Q, a tuple, equals the shape of the sub-tensor of a consisting of the appropriate number of its rightmost indices, and must be such that there exists an i for which prod(Q) == prod(b.shape[i:]).

  • b (array_like) – Right-hand tensor, which can be of any shape.

  • solve (Callable[[Array, Array], Array], optional) – A function that solves a batch of linear systems. It must have the same signature as array_api.linalg.solve. If None (default), array_api.linalg.solve is used.

  • num_batch_axes (int, optional) –

    The number of batch dimensions. If None (default), the number of batch dimensions is inferred from the shapes of a and b.

    Let shape = np.broadcast_shapes(a.shape[:b.ndim], b.shape) + a.shape[b.ndim:].

    It is recommended to specify this argument, as the inference might be wrong if there exist i >= 0 and j > 0 such that prod(shape[:i]) == prod(shape[i+j:]); in that case there are j + 1 possible numbers of batch axes.

    For example, if a has shape (3, 1, 1, 2, 2) and b has shape (3, 1, 1, 2), any of the following is possible:

    - num_batch_axes=1 (batch axis 0), and the desired output shape is (3, 2)
    - num_batch_axes=2 (batch axes 0 and 1), and the desired output shape is (3, 1, 2)
    - num_batch_axes=3 (batch axes 0, 1 and 2), and the desired output shape is (3, 1, 1, 2)
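To see where the ambiguity comes from, the hypothetical helper candidate_num_batch_axes below lists every number of batch axes n for which the non-batch part of b, shape[n:b.ndim], has as many elements as Q = shape[b.ndim:], with shape defined as above. It is a sketch under that assumption, not the library's inference routine.

```python
import math

import numpy as np


def candidate_num_batch_axes(a_shape, b_shape):
    """Hypothetical helper: every num_batch_axes consistent with the shapes."""
    b_ndim = len(b_shape)
    # shape = broadcast(a.shape[:b.ndim], b.shape) + Q, as defined above.
    shape = np.broadcast_shapes(tuple(a_shape[:b_ndim]), tuple(b_shape)) + tuple(
        a_shape[b_ndim:]
    )
    q_size = math.prod(shape[b_ndim:])
    # With n batch axes, shape[n:b_ndim] are the "row" axes of each system;
    # the flattened system is square only if they hold q_size elements.
    return [n for n in range(b_ndim + 1) if math.prod(shape[n:b_ndim]) == q_size]


print(candidate_num_batch_axes((3, 1, 1, 2, 2), (3, 1, 1, 2)))  # [1, 2, 3]
print(candidate_num_batch_axes((2, 6, 4, 2, 3, 4), (2, 6, 4)))  # [1]
```

More than one candidate is the situation in which AmbiguousBatchAxesWarning is described as being issued; passing num_batch_axes explicitly sidesteps the inference entirely.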

Returns:

x

Return type:

Array, of shape batch_shape + Q, where batch_shape is the broadcast shape of the batch axes of a and b

Raises:

LinAlgError – If a is singular or not ‘square’ (in the above sense).

Warning

AmbiguousBatchAxesWarning

Issued if num_batch_axes is not specified and the number of batch axes cannot be uniquely inferred from the shapes of a and b.

See also

numpy.tensordot, numpy.linalg.tensorinv, numpy.einsum

Examples

>>> import numpy as np
>>> from batch_tensorsolve import btensorsolve
>>> rng = np.random.default_rng()
>>> a = rng.normal(size=(2, 2*3, 4, 2, 3, 4))
>>> b = rng.normal(size=(2, 2*3, 4))
>>> x = btensorsolve(a, b)
>>> x.shape
(2, 2, 3, 4)
>>> np.allclose(np.einsum('...ijklm,...klm->...ij', a, x), b)
True
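The reduction behind a batched tensor solve can be sketched in a few lines of NumPy: flatten the non-batch axes of each system into a square matrix, solve all systems at once with a broadcasting solver, and reshape the result back to Q. The function btensorsolve_reference is hypothetical; it requires num_batch_axes to be given, assumes the non-batch axes of a and b match exactly, and is meant to illustrate the computation rather than reproduce the library's implementation.

```python
import math

import numpy as np


def btensorsolve_reference(a, b, num_batch_axes, solve=np.linalg.solve):
    """Hypothetical NumPy-only sketch of a batched tensorsolve.

    Assumes a.shape == batch_a + M + Q and b.shape == batch_b + M, where
    batch_a and batch_b broadcast, both have num_batch_axes axes, and
    prod(M) == prod(Q).
    """
    a = np.asarray(a)
    b = np.asarray(b)
    n = num_batch_axes
    m_shape = b.shape[n:]        # non-batch axes of b
    q_shape = a.shape[b.ndim:]   # shape Q of each solution
    m, q = math.prod(m_shape), math.prod(q_shape)
    if m != q:
        raise np.linalg.LinAlgError("a is not 'square': prod(M) != prod(Q)")
    # Flatten each system into an (m, m) matrix and an (m, 1) column.
    a_mat = a.reshape(a.shape[:n] + (m, q))
    b_col = b.reshape(b.shape[:n] + (m, 1))
    # np.linalg.solve broadcasts over the leading batch axes.
    x_col = solve(a_mat, b_col)
    # Drop the trailing column axis and restore the tensor shape Q.
    return x_col[..., 0].reshape(x_col.shape[:-2] + q_shape)


rng = np.random.default_rng(0)
a = rng.normal(size=(2, 2 * 3, 4, 2, 3, 4))
b = rng.normal(size=(2, 2 * 3, 4))
x = btensorsolve_reference(a, b, num_batch_axes=1)
print(x.shape)  # (2, 2, 3, 4)
```

The solve argument mirrors the documented solve parameter, so any function with the np.linalg.solve calling convention can be swapped in.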