Welcome to the Batch Tensorsolve documentation!

Batch Tensorsolve

Documentation: https://batch-tensorsolve.readthedocs.io

Source Code: https://github.com/34j/batch-tensorsolve


Batched tensorsolve() for NumPy / PyTorch / JAX. (numpy/numpy#28099)
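For reference, NumPy's own np.linalg.tensorsolve(a, b) handles a single system. It finds x such that np.tensordot(a, x, axes=x.ndim) == b, where x.shape == a.shape[b.ndim:]. The sketch below uses plain NumPy only (no batch_tensorsolve) and arbitrary illustrative shapes. It shows that one system solves fine, but a stack of systems is rejected, because the leading axis is treated as part of one larger, non-square system.

```python
import numpy as np

rng = np.random.default_rng(0)

# One system: b has shape (2, 3), so x has shape a.shape[b.ndim:] == (6,).
# tensorsolve needs prod(a.shape[:b.ndim]) == prod(a.shape[b.ndim:]), here 6 == 6.
a0 = rng.standard_normal((2, 3, 6))
b0 = rng.standard_normal((2, 3))
x0 = np.linalg.tensorsolve(a0, b0)
print(x0.shape)  # (6,)
print(np.allclose(np.tensordot(a0, x0, axes=x0.ndim), b0))  # True

# A stack of two such systems: tensorsolve sees 12 equations but only 6 unknowns.
a = rng.standard_normal((2, 2, 3, 6))
b = rng.standard_normal((2, 2, 3))
try:
    np.linalg.tensorsolve(a, b)
    rejected = False
except np.linalg.LinAlgError:
    rejected = True
print(rejected)  # True
```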

Installation

Install this via pip (or your favourite package manager):

pip install batch-tensorsolve

Usage

import numpy as np
from numpy.testing import assert_allclose

from batch_tensorsolve import btensorsolve

a = np.random.randn(2, 2, 3, 6)
b = np.random.randn(2, 2, 3)
assert_allclose(np.einsum("...ijk,...k->...ij", a, btensorsolve(a, b)), b)
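Conceptually, each batch element is an ordinary tensorsolve problem. The following is a minimal NumPy-only sketch of that idea, not this package's implementation. The helper naive_btensorsolve is hypothetical. It assumes a and b already share the same leading batch shape (no broadcasting) and flattens every system into a square matrix so that np.linalg.solve can process the whole batch in one call.

```python
import numpy as np


def naive_btensorsolve(a, b, num_batch_axes):
    """Hypothetical helper: batched tensorsolve via reshape + np.linalg.solve.

    Assumes a.shape[:num_batch_axes] == b.shape[:num_batch_axes] (no broadcasting).
    """
    batch_shape = b.shape[:num_batch_axes]
    rhs_shape = b.shape[num_batch_axes:]  # axes of b that belong to one system
    x_shape = a.shape[b.ndim:]            # axes of the unknown x
    m = int(np.prod(rhs_shape))
    n = int(np.prod(x_shape))
    if m != n:
        raise ValueError(f"each system must be square, got {m} equations and {n} unknowns")
    # Flatten every system to an (n, n) matrix and its right-hand side to an (n, 1) column.
    a_mat = a.reshape(batch_shape + (n, n))
    b_col = b.reshape(batch_shape + (n, 1))
    x = np.linalg.solve(a_mat, b_col)     # solve loops over the leading batch axes
    return x.reshape(batch_shape + x_shape)


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 2, 3, 6))
b = rng.standard_normal((2, 2, 3))
x = naive_btensorsolve(a, b, num_batch_axes=1)
print(x.shape)  # (2, 6)

# Same check as in the usage example, plus a per-element comparison with tensorsolve.
print(np.allclose(np.einsum("...ijk,...k->...ij", a, x), b))  # True
print(all(np.allclose(x[i], np.linalg.tensorsolve(a[i], b[i])) for i in range(2)))  # True
```

Reshaping rather than looping keeps the work inside a single batched LAPACK call; the per-element tensorsolve comparison is only there as a sanity check.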

Advanced Usage

It is recommended to specify the number of batch axes explicitly with num_batch_axes: when some axes have size 1, the split between batch axes and system axes, and therefore the output shape, is ambiguous.

import numpy as np

from batch_tensorsolve import btensorsolve

a = np.random.randn(2, 1, 2, 2)
b = np.random.randn(2, 1, 2)
# 2 possibilities:
assert btensorsolve(a, b, num_batch_axes=1).shape == (2, 2) # 1st axis is batch
assert btensorsolve(a, b, num_batch_axes=2).shape == (2, 1, 2) # 1st and 2nd axes are batch
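One way to see why there are exactly two possibilities here: if the first k axes are batch axes, each system has prod(b.shape[k:]) equations and prod(a.shape[b.ndim:]) unknowns, and these counts must match for the system to be square. The helper below is a hypothetical sketch of that rule, inferred from the shapes asserted above; it ignores broadcasting and is not part of batch_tensorsolve. It lists every k that yields a square system together with the resulting solution shape.

```python
import math


def candidate_solution_shapes(a_shape, b_shape):
    """Hypothetical helper: map each valid number of batch axes k to the shape of x.

    Assumes a_shape and b_shape share their batch axes exactly (no broadcasting).
    """
    x_shape = tuple(a_shape[len(b_shape):])  # unknowns per system
    n_unknowns = math.prod(x_shape)
    candidates = {}
    for k in range(len(b_shape) + 1):
        n_equations = math.prod(b_shape[k:])  # equations per system
        if n_equations == n_unknowns:
            candidates[k] = tuple(a_shape[:k]) + x_shape
    return candidates


print(candidate_solution_shapes((2, 1, 2, 2), (2, 1, 2)))  # {1: (2, 2), 2: (2, 1, 2)}
```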

Broadcasting-like behavior is also supported:

import numpy as np
from numpy.testing import assert_allclose

from batch_tensorsolve import btensorsolve

a = np.random.randn(1, 2, 3, 6) # -> (2, 2, 3, 6)
b = np.random.randn(2, 1, 1) # -> (2, 2, 3)
left = np.einsum("...ijk,...k->...ij", a, btensorsolve(a, b))
assert_allclose(left, np.broadcast_to(b, left.shape))
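The target shapes in the comments above can be derived with np.broadcast_shapes and np.broadcast_to. The sketch below assumes, as those comments suggest, that the example behaves like first broadcasting a to (2, 2, 3, 6) and b to (2, 2, 3) and then solving each of the two batch elements separately. It uses only NumPy and a plain loop over np.linalg.tensorsolve, and it relies on x having a single axis (the last axis of a) in this particular example.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((1, 2, 3, 6))
b = rng.standard_normal((2, 1, 1))

# Broadcast everything except the last axis of a, which holds the 6 unknowns.
b_shape = np.broadcast_shapes(a.shape[:-1], b.shape)  # (2, 2, 3)
a_full = np.broadcast_to(a, b_shape + a.shape[-1:])   # (2, 2, 3, 6)
b_full = np.broadcast_to(b, b_shape)                  # (2, 2, 3)

# Solve each batch element (leading axis) as an ordinary tensorsolve problem.
x = np.stack([np.linalg.tensorsolve(a_full[i], b_full[i]) for i in range(b_shape[0])])
print(x.shape)  # (2, 6)
print(np.allclose(np.einsum("...ijk,...k->...ij", a_full, x), b_full))  # True
```

Here a is repeated only along its batch axis, so every system stays invertible; b may be repeated along any axis because that only duplicates right-hand-side values.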

Note that broadcasting (repeating) a along its non-batch axes results in numpy.linalg.LinAlgError: Singular matrix, because the matrix representation of a then contains duplicate rows.
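The duplicate-row argument can be checked directly with NumPy. In this hypothetical setup, a single system of shape (2, 3, 6) is built by repeating a (1, 3, 6) array along its first non-batch axis. Flattened to the (6, 6) matrix that tensorsolve works with, rows 3 to 5 copy rows 0 to 2, so the matrix has rank 3 and cannot be inverted.

```python
import numpy as np

rng = np.random.default_rng(0)

# Repeat a (1, 3, 6) array along its first non-batch axis to get one (2, 3, 6) system.
a_sys = np.broadcast_to(rng.standard_normal((1, 3, 6)), (2, 3, 6))

# Matrix representation used by tensorsolve: equations as rows, unknowns as columns.
mat = a_sys.reshape(6, 6)
print(np.array_equal(mat[:3], mat[3:]))  # True: rows 3-5 duplicate rows 0-2
print(np.linalg.matrix_rank(mat))        # 3, so the (6, 6) matrix is singular
```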

Contributors ✨

Thanks go to these wonderful people:

This project follows the all-contributors specification. Contributions of any kind are welcome!

Credits

Copier

This package was created with Copier and the browniebroke/pypackage-template project template.