Welcome to the Batch Tensorsolve documentation!
Batch Tensorsolve
Documentation: https://batch-tensorsolve.readthedocs.io
Source Code: https://github.com/34j/batch-tensorsolve
Batched tensorsolve() for NumPy / PyTorch / JAX. (numpy/numpy#28099)
Installation
Install this via pip (or your favourite package manager):
pip install batch-tensorsolve
Usage
import numpy as np
from numpy.testing import assert_allclose
from batch_tensorsolve import btensorsolve
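a = np.random.randn(2, 2, 3, 6)  # batch axis (2,), then b's axes (2, 3), then the solution axis (6,)
b = np.random.randn(2, 2, 3)  # batch axis (2,), then (2, 3)
# contracting a with the solution over its last axis must reproduce b
assert_allclose(np.einsum("...ijk,...k->...ij", a, btensorsolve(a, b)), b)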
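To see what the batched call computes, the sketch below reproduces the example above with plain NumPy, assuming the leading axis of size 2 is the only batch axis: it loops over the batch and calls np.linalg.tensorsolve on each (2, 3, 6) / (2, 3) pair. The helper loop_tensorsolve is hypothetical and only a reference loop for illustration; it is not how btensorsolve is implemented.

```python
import numpy as np
from numpy.testing import assert_allclose


def loop_tensorsolve(a: np.ndarray, b: np.ndarray, num_batch_axes: int) -> np.ndarray:
    """Hypothetical reference: solve each batch element with np.linalg.tensorsolve.

    Assumes a and b share the same leading batch shape (no broadcasting).
    """
    batch_shape = a.shape[:num_batch_axes]
    # Shape of each per-batch solution: the trailing axes of a not covered by b.
    x_shape = a.shape[b.ndim:]
    x = np.empty(batch_shape + x_shape, dtype=np.result_type(a, b))
    for idx in np.ndindex(*batch_shape):
        x[idx] = np.linalg.tensorsolve(a[idx], b[idx])
    return x


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 2, 3, 6))
b = rng.standard_normal((2, 2, 3))
x = loop_tensorsolve(a, b, num_batch_axes=1)
print(x.shape)  # (2, 6)
assert_allclose(np.einsum("...ijk,...k->...ij", a, x), b, atol=1e-10)
```

A Python loop like this is slow for large batches, which is the motivation for a vectorized batched solver.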
Advanced Usage
It is recommended to specify the number of batch axes explicitly with num_batch_axes, because the intended output shape is ambiguous when axes of size 1 are present.
import numpy as np
from batch_tensorsolve import btensorsolve
a = np.random.randn(2, 1, 2, 2)
b = np.random.randn(2, 1, 2)
# 2 possibilities:
assert btensorsolve(a, b, num_batch_axes=1).shape == (2, 2) # 1st axis is batch
assert btensorsolve(a, b, num_batch_axes=2).shape == (2, 1, 2) # 1st and 2nd axes are batch
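The ambiguity follows from shape arithmetic alone. The sketch below uses two hypothetical helpers, candidate_batch_axes and solution_shape, which are not part of batch_tensorsolve. Assuming no broadcasting and leading batch axes, it lists every number of batch axes for which the non-batch part of b has as many elements as the solution (the condition np.linalg.tensorsolve needs for a square matrix form). For the shapes above, both 1 and 2 qualify.

```python
from __future__ import annotations

from math import prod


def candidate_batch_axes(a_shape: tuple[int, ...], b_shape: tuple[int, ...]) -> list[int]:
    """Hypothetical helper: every num_batch_axes value consistent with the shapes.

    Assumes no broadcasting (a_shape[:len(b_shape)] == b_shape) and that the
    batch axes are the leading axes of both arrays.
    """
    x_shape = a_shape[len(b_shape):]  # trailing axes of a that the solution has
    candidates = []
    for n in range(len(b_shape) + 1):
        # The non-batch part of b must have as many elements as the solution,
        # so that the matrix form of each batch element of a is square.
        if prod(b_shape[n:]) == prod(x_shape):
            candidates.append(n)
    return candidates


def solution_shape(a_shape: tuple[int, ...], b_shape: tuple[int, ...], num_batch_axes: int) -> tuple[int, ...]:
    """Hypothetical helper: batch axes followed by the solution axes."""
    return a_shape[:num_batch_axes] + a_shape[len(b_shape):]


for n in candidate_batch_axes((2, 1, 2, 2), (2, 1, 2)):
    print(n, solution_shape((2, 1, 2, 2), (2, 1, 2), n))
# 1 (2, 2)
# 2 (2, 1, 2)
```

For the shapes in the Usage section, (2, 2, 3, 6) and (2, 2, 3), only one batch axis qualifies, so there is no ambiguity there.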
Broadcasting-like behavior is also supported. In the example below, a is repeated along the batch axis and b along its non-batch axes:
import numpy as np
from numpy.testing import assert_allclose
from batch_tensorsolve import btensorsolve
a = np.random.randn(1, 2, 3, 6) # -> (2, 2, 3, 6)
b = np.random.randn(2, 1, 1) # -> (2, 2, 3)
left = np.einsum("...ijk,...k->...ij", a, btensorsolve(a, b))
assert_allclose(left, np.broadcast_to(b, left.shape))
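A minimal sketch of what the broadcasting example amounts to in plain NumPy, assuming one leading batch axis: expand a and b with np.broadcast_to to the shapes given in the comments, then solve each batch element with np.linalg.tensorsolve. Repeating b along its non-batch axes only changes the right-hand side, so each system stays solvable. This illustrates the semantics only and is not how btensorsolve is implemented.

```python
import numpy as np
from numpy.testing import assert_allclose

rng = np.random.default_rng(0)
a = rng.standard_normal((1, 2, 3, 6))
b = rng.standard_normal((2, 1, 1))

# Expand the shapes explicitly, as in the comments of the example above.
a_full = np.broadcast_to(a, (2, 2, 3, 6))  # repeats a along the batch axis only
b_full = np.broadcast_to(b, (2, 2, 3))  # repeats b along its two non-batch axes

# Solve each batch element separately (one leading batch axis assumed).
x = np.stack([np.linalg.tensorsolve(a_full[i], b_full[i]) for i in range(2)])

left = np.einsum("...ijk,...k->...ij", a, x)
assert_allclose(left, b_full, atol=1e-10)
print(x.shape)  # (2, 6)
```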
Note that broadcasting (repeating) a along non-batch axes raises numpy.linalg.LinAlgError: Singular matrix, because the matrix representation of a then has duplicate rows.
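The duplicate rows can be inspected directly. The sketch below, with no batch axes and an arbitrary right-hand side of ones (both assumptions for illustration), repeats a along its first non-batch axis and builds the square matrix that np.linalg.tensorsolve works with: the axes of b, (2, 3), are flattened into rows and the solution axis, (6,), into columns.

```python
import numpy as np

rng = np.random.default_rng(0)
# Repeat a along a non-batch axis: (1, 3, 6) -> (2, 3, 6).
a = np.broadcast_to(rng.standard_normal((1, 3, 6)), (2, 3, 6))

# Matrix representation: rows indexed by b's axes (2, 3), columns by x's axis (6,).
m = a.reshape(6, 6)
print(np.array_equal(m[:3], m[3:]))  # True: rows 3-5 duplicate rows 0-2

# Duplicate rows make the matrix rank-deficient, so the system has no unique solution.
print("rank:", np.linalg.matrix_rank(m), "of", m.shape[0])

try:
    np.linalg.tensorsolve(a, np.ones((2, 3)))
except np.linalg.LinAlgError as err:
    print("LinAlgError:", err)
```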
Contributors ✨
Thanks go to these wonderful people (emoji key):
This project follows the all-contributors specification. Contributions of any kind welcome!
Credits
This package was created with Copier and the browniebroke/pypackage-template project template.