Import tensorflow

This commit is contained in:
2026-02-15 21:45:42 -08:00
parent f3e8b90764
commit c530630153
20524 changed files with 9017694 additions and 25 deletions
@@ -0,0 +1,462 @@
from typing import Set
import pytest
from opt_einsum import backends, contract, contract_expression, sharing
from opt_einsum.contract import ArrayShaped, infer_backend, parse_backend
from opt_einsum.testing import build_views
try:
# needed so tensorflow doesn't allocate all gpu mem
try:
from tensorflow import ConfigProto # type: ignore
from tensorflow import Session as TFSession
except ImportError:
from tensorflow.compat.v1 import ConfigProto # type: ignore
from tensorflow.compat.v1 import Session as TFSession
_TF_CONFIG = ConfigProto()
_TF_CONFIG.gpu_options.allow_growth = True
except ImportError:
pass
tests = [
"ab,bc->ca",
"abc,bcd,dea",
"abc,def->fedcba",
"abc,bcd,df->fa",
# test 'prefer einsum' ops
"ijk,ikj",
"i,j->ij",
"ijk,k->ij",
"AB,BC->CA",
]
@pytest.mark.parametrize("string", tests)
def test_tensorflow(string: str) -> None:
np = pytest.importorskip("numpy")
pytest.importorskip("tensorflow")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
opt = np.empty_like(ein)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
sess = TFSession(config=_TF_CONFIG)
with sess.as_default():
expr(*views, backend="tensorflow", out=opt)
sess.close()
assert np.allclose(ein, opt)
# test non-conversion mode
tensorflow_views = [backends.to_tensorflow(view) for view in views]
expr(*tensorflow_views)
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_tensorflow_with_constants(constants: Set[int]) -> None:
np = pytest.importorskip("numpy")
tf = pytest.importorskip("tensorflow")
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
var = np.random.rand(*shapes[non_const])
res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))
expr = contract_expression(eq, *ops, constants=constants)
# check tensorflow
with TFSession(config=_TF_CONFIG).as_default():
res_got = expr(var, backend="tensorflow")
assert all(
array is None or infer_backend(array) == "tensorflow" for array in expr._evaluated_constants["tensorflow"]
)
assert np.allclose(res_exp, res_got)
# check can call with numpy still
res_got2 = expr(var, backend="numpy")
assert np.allclose(res_exp, res_got2)
# check tensorflow call returns tensorflow still
res_got3 = expr(backends.to_tensorflow(var))
assert isinstance(res_got3, tf.Tensor)
@pytest.mark.parametrize("string", tests)
def test_tensorflow_with_sharing(string: str) -> None:
np = pytest.importorskip("numpy")
tf = pytest.importorskip("tensorflow")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
sess = TFSession(config=_TF_CONFIG)
with sess.as_default(), sharing.shared_intermediates() as cache:
tfl1 = expr(*views, backend="tensorflow")
assert sharing.get_sharing_cache() is cache
cache_sz = len(cache)
assert cache_sz > 0
tfl2 = expr(*views, backend="tensorflow")
assert len(cache) == cache_sz
assert all(isinstance(t, tf.Tensor) for t in cache.values())
assert np.allclose(ein, tfl1)
assert np.allclose(ein, tfl2)
@pytest.mark.parametrize("string", tests)
def test_theano(string: str) -> None:
np = pytest.importorskip("numpy")
theano = pytest.importorskip("theano")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
opt = expr(*views, backend="theano")
assert np.allclose(ein, opt)
# test non-conversion mode
theano_views = [backends.to_theano(view) for view in views]
theano_opt = expr(*theano_views)
assert isinstance(theano_opt, theano.tensor.TensorVariable)
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_theano_with_constants(constants: Set[int]) -> None:
np = pytest.importorskip("numpy")
theano = pytest.importorskip("theano")
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
var = np.random.rand(*shapes[non_const])
res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))
expr = contract_expression(eq, *ops, constants=constants)
# check theano
res_got = expr(var, backend="theano")
assert all(array is None or infer_backend(array) == "theano" for array in expr._evaluated_constants["theano"])
assert np.allclose(res_exp, res_got)
# check can call with numpy still
res_got2 = expr(var, backend="numpy")
assert np.allclose(res_exp, res_got2)
# check theano call returns theano still
res_got3 = expr(backends.to_theano(var))
assert isinstance(res_got3, theano.tensor.TensorVariable)
@pytest.mark.parametrize("string", tests)
def test_theano_with_sharing(string: str) -> None:
np = pytest.importorskip("numpy")
theano = pytest.importorskip("theano")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
with sharing.shared_intermediates() as cache:
thn1 = expr(*views, backend="theano")
assert sharing.get_sharing_cache() is cache
cache_sz = len(cache)
assert cache_sz > 0
thn2 = expr(*views, backend="theano")
assert len(cache) == cache_sz
assert all(isinstance(t, theano.tensor.TensorVariable) for t in cache.values())
assert np.allclose(ein, thn1)
assert np.allclose(ein, thn2)
@pytest.mark.parametrize("string", tests)
def test_cupy(string: str) -> None:
np = pytest.importorskip("numpy") # pragma: no cover
cupy = pytest.importorskip("cupy")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
opt = expr(*views, backend="cupy")
assert np.allclose(ein, opt)
# test non-conversion mode
cupy_views = [backends.to_cupy(view) for view in views]
cupy_opt = expr(*cupy_views)
assert isinstance(cupy_opt, cupy.ndarray)
assert np.allclose(ein, cupy.asnumpy(cupy_opt))
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_cupy_with_constants(constants: Set[int]) -> None:
np = pytest.importorskip("numpy") # pragma: no cover
cupy = pytest.importorskip("cupy")
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
var = np.random.rand(*shapes[non_const])
res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))
expr = contract_expression(eq, *ops, constants=constants)
# check cupy
res_got = expr(var, backend="cupy")
# check cupy versions of constants exist
assert all(array is None or infer_backend(array) == "cupy" for array in expr._evaluated_constants["cupy"])
assert np.allclose(res_exp, res_got)
# check can call with numpy still
res_got2 = expr(var, backend="numpy")
assert np.allclose(res_exp, res_got2)
# check cupy call returns cupy still
res_got3 = expr(cupy.asarray(var))
assert isinstance(res_got3, cupy.ndarray)
assert np.allclose(res_exp, res_got3.get())
@pytest.mark.parametrize("string", tests)
def test_jax(string: str) -> None:
np = pytest.importorskip("numpy") # pragma: no cover
pytest.importorskip("jax")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
opt = expr(*views, backend="jax")
assert np.allclose(ein, opt)
assert isinstance(opt, np.ndarray)
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_jax_with_constants(constants: Set[int]) -> None:
jax = pytest.importorskip("jax")
key = jax.random.PRNGKey(42)
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
ops = [jax.random.uniform(key, shp) if i in constants else shp for i, shp in enumerate(shapes)]
var = jax.random.uniform(key, shapes[non_const])
res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))
expr = contract_expression(eq, *ops, constants=constants)
# check jax
res_got = expr(var, backend="jax")
# check jax versions of constants exist
assert all(array is None or infer_backend(array).startswith("jax") for array in expr._evaluated_constants["jax"])
assert jax.numpy.sum(jax.numpy.abs(res_exp - res_got)) < 1e-8
def test_jax_jit_gradient() -> None:
jax = pytest.importorskip("jax")
key = jax.random.PRNGKey(42)
eq = "ij,jk,kl->"
shapes = (2, 3), (3, 4), (4, 2)
views = [jax.random.uniform(key, s) for s in shapes]
expr = contract_expression(eq, *shapes)
x0 = expr(*views)
jit_expr = jax.jit(expr)
x1 = jit_expr(*views).item()
assert x1 == pytest.approx(x0, rel=1e-5)
# jax only takes gradient w.r.t first argument
grad_expr = jax.jit(jax.grad(lambda views: expr(*views)))
view_grads = grad_expr(views)
assert all(v1.shape == v2.shape for v1, v2 in zip(views, view_grads))
# taking a step along the gradient should reduce our 'loss'
new_views = [v - 0.001 * dv for v, dv in zip(views, view_grads)]
x2 = jit_expr(*new_views).item()
assert x2 < x1
def test_autograd_gradient() -> None:
np = pytest.importorskip("numpy")
autograd = pytest.importorskip("autograd")
eq = "ij,jk,kl->"
shapes = (2, 3), (3, 4), (4, 2)
views = [np.random.randn(*s) for s in shapes]
expr = contract_expression(eq, *shapes)
x0 = expr(*views)
# autograd only takes gradient w.r.t first argument
grad_expr = autograd.grad(lambda views: expr(*views))
view_grads = grad_expr(views)
assert all(v1.shape == v2.shape for v1, v2 in zip(views, view_grads))
# taking a step along the gradient should reduce our 'loss'
new_views = [v - 0.001 * dv for v, dv in zip(views, view_grads)]
x1 = expr(*new_views)
assert x1 < x0
@pytest.mark.parametrize("string", tests)
def test_dask(string: str) -> None:
np = pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
# test non-conversion mode
da_views = [da.from_array(x, chunks=(2)) for x in views]
da_opt = expr(*da_views)
# check type is maintained when not using numpy arrays
assert isinstance(da_opt, da.Array)
assert np.allclose(ein, np.array(da_opt))
# try raw contract
da_opt = contract(string, *da_views)
assert isinstance(da_opt, da.Array)
assert np.allclose(ein, np.array(da_opt))
@pytest.mark.parametrize("string", tests)
def test_sparse(string: str) -> None:
np = pytest.importorskip("numpy")
sparse = pytest.importorskip("sparse")
views = build_views(string)
# sparsify views so they don't become dense during contraction
for view in views:
np.random.seed(42)
mask = np.random.choice([False, True], view.shape, True, [0.05, 0.95])
view[mask] = 0
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
# test non-conversion mode
sparse_views = [sparse.COO.from_numpy(x) for x in views]
sparse_opt = expr(*sparse_views)
# If the expression returns a float, stop here
if not ein.shape:
assert pytest.approx(ein) == 0.0
return
# check type is maintained when not using numpy arrays
assert isinstance(sparse_opt, sparse.COO)
assert np.allclose(ein, sparse_opt.todense())
# try raw contract
sparse_opt = contract(string, *sparse_views)
assert isinstance(sparse_opt, sparse.COO)
assert np.allclose(ein, sparse_opt.todense())
@pytest.mark.parametrize("string", tests)
def test_torch(string: str) -> None:
torch = pytest.importorskip("torch")
views = build_views(string, array_function=torch.rand)
ein = torch.einsum(string, *views)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
opt = expr(*views, backend="torch")
torch.testing.assert_close(ein, opt)
# test non-conversion mode
torch_views = [backends.to_torch(view) for view in views]
torch_opt = expr(*torch_views)
assert isinstance(torch_opt, torch.Tensor)
torch.testing.assert_close(ein, torch_opt)
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_torch_with_constants(constants: Set[int]) -> None:
torch = pytest.importorskip("torch")
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
ops = [torch.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
var = torch.rand(*shapes[non_const])
res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)), backend="torch")
expr = contract_expression(eq, *ops, constants=constants)
# check torch
res_got = expr(var, backend="torch")
assert all(array is None or infer_backend(array) == "torch" for array in expr._evaluated_constants["torch"])
torch.testing.assert_close(res_exp, res_got)
# check can call with numpy still
res_got2 = expr(var, backend="torch")
torch.testing.assert_close(res_exp, res_got2)
# check torch call returns torch still
res_got3 = expr(backends.to_torch(var))
assert isinstance(res_got3, torch.Tensor)
torch.testing.assert_close(res_exp, res_got3)
def test_auto_backend_custom_array_no_tensordot() -> None:
x = ArrayShaped((1, 2, 3))
# Shaped is an array-like object defined by opt_einsum - which has no TDOT
assert infer_backend(x) == "opt_einsum"
assert parse_backend([x], "auto") == "numpy"
assert parse_backend([x], None) == "numpy"
@pytest.mark.parametrize("string", tests)
def test_object_arrays_backend(string: str) -> None:
np = pytest.importorskip("numpy")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
assert ein.dtype != object
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
obj_views = [view.astype(object) for view in views]
# try raw contract
obj_opt = contract(string, *obj_views, backend="object")
assert obj_opt.dtype == object
assert np.allclose(ein, obj_opt.astype(float))
# test expression
obj_opt = expr(*obj_views, backend="object")
assert obj_opt.dtype == object
assert np.allclose(ein, obj_opt.astype(float))
@@ -0,0 +1,81 @@
"""
Tests the BLAS capability for the opt_einsum module.
"""
from typing import Any
import pytest
from opt_einsum import blas, contract
blas_tests = [
# DOT
((["k", "k"], "", set("k")), "DOT"), # DDOT
((["ijk", "ijk"], "", set("ijk")), "DOT"), # DDOT
# GEMV?
# GEMM
((["ij", "jk"], "ik", set("j")), "GEMM"), # GEMM N N
((["ijl", "jlk"], "ik", set("jl")), "GEMM"), # GEMM N N Tensor
((["ij", "kj"], "ik", set("j")), "GEMM"), # GEMM N T
((["ijl", "kjl"], "ik", set("jl")), "GEMM"), # GEMM N T Tensor
((["ji", "jk"], "ik", set("j")), "GEMM"), # GEMM T N
((["jli", "jlk"], "ik", set("jl")), "GEMM"), # GEMM T N Tensor
((["ji", "kj"], "ik", set("j")), "GEMM"), # GEMM T T
((["jli", "kjl"], "ik", set("jl")), "GEMM"), # GEMM T T Tensor
# GEMM with final transpose
((["ij", "jk"], "ki", set("j")), "GEMM"), # GEMM N N
((["ijl", "jlk"], "ki", set("jl")), "GEMM"), # GEMM N N Tensor
((["ij", "kj"], "ki", set("j")), "GEMM"), # GEMM N T
((["ijl", "kjl"], "ki", set("jl")), "GEMM"), # GEMM N T Tensor
((["ji", "jk"], "ki", set("j")), "GEMM"), # GEMM T N
((["jli", "jlk"], "ki", set("jl")), "GEMM"), # GEMM T N Tensor
((["ji", "kj"], "ki", set("j")), "GEMM"), # GEMM T T
((["jli", "kjl"], "ki", set("jl")), "GEMM"), # GEMM T T Tensor
# Tensor Dot (requires copy), lets not deal with this for now
((["ilj", "jlk"], "ik", set("jl")), "TDOT"), # FT GEMM N N Tensor
((["ijl", "ljk"], "ik", set("jl")), "TDOT"), # ST GEMM N N Tensor
((["ilj", "kjl"], "ik", set("jl")), "TDOT"), # FT GEMM N T Tensor
((["ijl", "klj"], "ik", set("jl")), "TDOT"), # ST GEMM N T Tensor
((["lji", "jlk"], "ik", set("jl")), "TDOT"), # FT GEMM T N Tensor
((["jli", "ljk"], "ik", set("jl")), "TDOT"), # ST GEMM T N Tensor
((["lji", "jlk"], "ik", set("jl")), "TDOT"), # FT GEMM T N Tensor
((["jli", "ljk"], "ik", set("jl")), "TDOT"), # ST GEMM T N Tensor
# Tensor Dot (requires copy), lets not deal with this for now with transpose
((["ilj", "jlk"], "ik", set("lj")), "TDOT"), # FT GEMM N N Tensor
((["ijl", "ljk"], "ik", set("lj")), "TDOT"), # ST GEMM N N Tensor
((["ilj", "kjl"], "ik", set("lj")), "TDOT"), # FT GEMM N T Tensor
((["ijl", "klj"], "ik", set("lj")), "TDOT"), # ST GEMM N T Tensor
((["lji", "jlk"], "ik", set("lj")), "TDOT"), # FT GEMM T N Tensor
((["jli", "ljk"], "ik", set("lj")), "TDOT"), # ST GEMM T N Tensor
((["lji", "jlk"], "ik", set("lj")), "TDOT"), # FT GEMM T N Tensor
((["jli", "ljk"], "ik", set("lj")), "TDOT"), # ST GEMM T N Tensor
# Other
((["ijk", "ikj"], "", set("ijk")), "DOT/EINSUM"), # Transpose DOT
((["i", "j"], "ij", set()), "OUTER/EINSUM"), # Outer
((["ijk", "ik"], "j", set("ik")), "GEMV/EINSUM"), # Matrix-vector
((["ijj", "jk"], "ik", set("j")), False), # Double index
((["ijk", "j"], "ij", set()), False), # Index sum 1
((["ij", "ij"], "ij", set()), False), # Index sum 2
]
@pytest.mark.parametrize("inp,benchmark", blas_tests)
def test_can_blas(inp: Any, benchmark: bool) -> None:
result = blas.can_blas(*inp)
assert result == benchmark
def test_blas_out() -> None:
np = pytest.importorskip("numpy")
a = np.random.rand(4, 4)
b = np.random.rand(4, 4)
c = np.random.rand(4, 4)
d = np.empty((4, 4))
contract("ij,jk->ik", a, b, out=d)
np.testing.assert_allclose(d, np.dot(a, b))
assert np.allclose(d, np.dot(a, b))
contract("ij,jk,kl->il", a, b, c, out=d)
np.testing.assert_allclose(d, np.dot(a, b).dot(c))
@@ -0,0 +1,279 @@
"""
Tets a series of opt_einsum contraction paths to ensure the results are the same for different paths
"""
from typing import Any, List
import pytest
from opt_einsum import contract, contract_expression, contract_path
from opt_einsum.paths import _PATH_OPTIONS, linear_to_ssa, ssa_to_linear
from opt_einsum.testing import build_views, rand_equation
from opt_einsum.typing import OptimizeKind
# NumPy is required for the majority of this file
np = pytest.importorskip("numpy")
tests = [
# Test scalar-like operations
"a,->a",
"ab,->ab",
",ab,->ab",
",,->",
# Test hadamard-like products
"a,ab,abc->abc",
"a,b,ab->ab",
# Test index-transformations
"ea,fb,gc,hd,abcd->efgh",
"ea,fb,abcd,gc,hd->efgh",
"abcd,ea,fb,gc,hd->efgh",
# Test complex contractions
"acdf,jbje,gihb,hfac,gfac,gifabc,hfac",
"acdf,jbje,gihb,hfac,gfac,gifabc,hfac",
"cd,bdhe,aidb,hgca,gc,hgibcd,hgac",
"abhe,hidj,jgba,hiab,gab",
"bde,cdh,agdb,hica,ibd,hgicd,hiac",
"chd,bde,agbc,hiad,hgc,hgi,hiad",
"chd,bde,agbc,hiad,bdi,cgh,agdb",
"bdhe,acad,hiab,agac,hibd",
# Test collapse
"ab,ab,c->",
"ab,ab,c->c",
"ab,ab,cd,cd->",
"ab,ab,cd,cd->ac",
"ab,ab,cd,cd->cd",
"ab,ab,cd,cd,ef,ef->",
# Test outer prodcuts
"ab,cd,ef->abcdef",
"ab,cd,ef->acdf",
"ab,cd,de->abcde",
"ab,cd,de->be",
"ab,bcd,cd->abcd",
"ab,bcd,cd->abd",
# Random test cases that have previously failed
"eb,cb,fb->cef",
"dd,fb,be,cdb->cef",
"bca,cdb,dbf,afc->",
"dcc,fce,ea,dbf->ab",
"fdf,cdd,ccd,afe->ae",
"abcd,ad",
"ed,fcd,ff,bcf->be",
"baa,dcf,af,cde->be",
"bd,db,eac->ace",
"fff,fae,bef,def->abd",
"efc,dbc,acf,fd->abe",
# Inner products
"ab,ab",
"ab,ba",
"abc,abc",
"abc,bac",
"abc,cba",
# GEMM test cases
"ab,bc",
"ab,cb",
"ba,bc",
"ba,cb",
"abcd,cd",
"abcd,ab",
"abcd,cdef",
"abcd,cdef->feba",
"abcd,efdc",
# Inner than dot
"aab,bc->ac",
"ab,bcc->ac",
"aab,bcc->ac",
"baa,bcc->ac",
"aab,ccb->ac",
# Randomly build test caes
"aab,fa,df,ecc->bde",
"ecb,fef,bad,ed->ac",
"bcf,bbb,fbf,fc->",
"bb,ff,be->e",
"bcb,bb,fc,fff->",
"fbb,dfd,fc,fc->",
"afd,ba,cc,dc->bf",
"adb,bc,fa,cfc->d",
"bbd,bda,fc,db->acf",
"dba,ead,cad->bce",
"aef,fbc,dca->bde",
]
@pytest.mark.parametrize("optimize", (True, False, None))
def test_contract_plain_types(optimize: OptimizeKind) -> None:
expr = "ij,jk,kl->il"
ops = [np.random.rand(2, 2), np.random.rand(2, 2), np.random.rand(2, 2)]
path = contract_path(expr, *ops, optimize=optimize)
assert len(path) == 2
result = contract(expr, *ops, optimize=optimize)
assert result.shape == (2, 2)
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare(optimize: OptimizeKind, string: str) -> None:
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
opt = contract(string, *views, optimize=optimize, use_blas=False)
assert np.allclose(ein, opt)
@pytest.mark.parametrize("string", tests)
def test_drop_in_replacement(string: str) -> None:
views = build_views(string)
opt = contract(string, *views)
assert np.allclose(opt, np.einsum(string, *views))
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare_greek(optimize: OptimizeKind, string: str) -> None:
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
# convert to greek
string = "".join(chr(ord(c) + 848) if c not in ",->." else c for c in string)
opt = contract(string, *views, optimize=optimize, use_blas=False)
assert np.allclose(ein, opt)
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare_blas(optimize: OptimizeKind, string: str) -> None:
views = build_views(string)
ein = contract(string, *views, optimize=False)
opt = contract(string, *views, optimize=optimize)
assert np.allclose(ein, opt)
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare_blas_greek(optimize: OptimizeKind, string: str) -> None:
views = build_views(string)
ein = contract(string, *views, optimize=False)
# convert to greek
string = "".join(chr(ord(c) + 848) if c not in ",->." else c for c in string)
opt = contract(string, *views, optimize=optimize)
assert np.allclose(ein, opt)
def test_some_non_alphabet_maintains_order() -> None:
# 'c beta a' should automatically go to -> 'a c beta'
string = "c" + chr(ord("b") + 848) + "a"
# but beta will be temporarily replaced with 'b' for which 'cba->abc'
# so check manual output kicks in:
x = np.random.rand(2, 3, 4)
assert np.allclose(contract(string, x), contract("cxa", x))
def test_printing():
string = "bbd,bda,fc,db->acf"
views = build_views(string)
ein = contract_path(string, *views)
assert len(str(ein[1])) == 728
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
@pytest.mark.parametrize("use_blas", [False, True])
@pytest.mark.parametrize("out_spec", [False, True])
def test_contract_expressions(string: str, optimize: OptimizeKind, use_blas: bool, out_spec: bool) -> None:
views = build_views(string)
shapes = [view.shape if hasattr(view, "shape") else () for view in views]
expected = contract(string, *views, optimize=False, use_blas=False)
expr = contract_expression(string, *shapes, optimize=optimize, use_blas=use_blas)
if out_spec and ("->" in string) and (string[-2:] != "->"):
(out,) = build_views(string.split("->")[1])
expr(*views, out=out)
else:
out = expr(*views)
assert np.allclose(out, expected)
# check representations
assert string in expr.__repr__()
assert string in expr.__str__()
def test_contract_expression_interleaved_input() -> None:
x, y, z = (np.random.randn(2, 2) for _ in "xyz")
expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0])
xshp, yshp, zshp = ((2, 2) for _ in "xyz")
expr = contract_expression(xshp, [0, 1], yshp, [1, 2], zshp, [2, 3], [3, 0])
out = expr(x, y, z)
assert np.allclose(out, expected)
@pytest.mark.parametrize(
"string,constants",
[
("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]),
("bdef,cdkj,ji,ikeh,hbc,lfo", [0, 1, 2, 3]),
("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]),
("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]),
("ijab,acd,bce,df,ef->ji", [1, 2, 3, 4]),
("ab,cd,ad,cb", [1, 3]),
("ab,bc,cd", [0, 1]),
],
)
def test_contract_expression_with_constants(string: str, constants: List[int]) -> None:
views = build_views(string)
expected = contract(string, *views, optimize=False, use_blas=False)
shapes = [view.shape if hasattr(view, "shape") else () for view in views]
expr_args: List[Any] = []
ctrc_args = []
for i, (shape, view) in enumerate(zip(shapes, views)):
if i in constants:
expr_args.append(view)
else:
expr_args.append(shape)
ctrc_args.append(view)
expr = contract_expression(string, *expr_args, constants=constants)
out = expr(*ctrc_args)
assert np.allclose(expected, out)
@pytest.mark.parametrize("optimize", ["greedy", "optimal"])
@pytest.mark.parametrize("n", [4, 5])
@pytest.mark.parametrize("reg", [2, 3])
@pytest.mark.parametrize("n_out", [0, 2, 4])
@pytest.mark.parametrize("global_dim", [False, True])
def test_rand_equation(optimize: OptimizeKind, n: int, reg: int, n_out: int, global_dim: bool) -> None:
eq, _, size_dict = rand_equation(n, reg, n_out, d_min=2, d_max=5, seed=42, return_size_dict=True)
views = build_views(eq, size_dict)
expected = contract(eq, *views, optimize=False)
actual = contract(eq, *views, optimize=optimize)
assert np.allclose(expected, actual)
@pytest.mark.parametrize("equation", tests)
def test_linear_vs_ssa(equation: str) -> None:
views = build_views(equation)
linear_path, _ = contract_path(equation, *views)
ssa_path = linear_to_ssa(linear_path)
linear_path2 = ssa_to_linear(ssa_path)
assert linear_path2 == linear_path
def test_contract_path_supply_shapes() -> None:
eq = "ab,bc,cd"
shps = [(2, 3), (3, 4), (4, 5)]
contract_path(eq, *shps, shapes=True)
@@ -0,0 +1,152 @@
"""
Tets a series of opt_einsum contraction paths to ensure the results are the same for different paths
"""
from typing import Any, Tuple
import pytest
from opt_einsum import contract, contract_expression, contract_path
from opt_einsum.typing import PathType
# NumPy is required for the majority of this file
np = pytest.importorskip("numpy")
def test_contract_expression_checks() -> None:
# check optimize needed
with pytest.raises(ValueError):
contract_expression("ab,bc->ac", (2, 3), (3, 4), optimize=False)
# check sizes are still checked
with pytest.raises(ValueError):
contract_expression("ab,bc->ac", (2, 3), (3, 4), (42, 42))
# check if out given
out = np.empty((2, 4))
with pytest.raises(ValueError):
contract_expression("ab,bc->ac", (2, 3), (3, 4), out=out)
# check still get errors when wrong ranks supplied to expression
expr = contract_expression("ab,bc->ac", (2, 3), (3, 4))
# too few arguments
with pytest.raises(ValueError) as err:
expr(np.random.rand(2, 3))
assert "`ContractExpression` takes exactly 2" in str(err.value)
# too many arguments
with pytest.raises(ValueError) as err:
expr(np.random.rand(2, 3), np.random.rand(2, 3), np.random.rand(2, 3))
assert "`ContractExpression` takes exactly 2" in str(err.value)
# wrong shapes
with pytest.raises(ValueError) as err:
expr(np.random.rand(2, 3, 4), np.random.rand(3, 4))
assert "Internal error while evaluating `ContractExpression`" in str(err.value)
with pytest.raises(ValueError) as err:
expr(np.random.rand(2, 4), np.random.rand(3, 4, 5))
assert "Internal error while evaluating `ContractExpression`" in str(err.value)
with pytest.raises(ValueError) as err:
expr(np.random.rand(2, 3), np.random.rand(3, 4), out=np.random.rand(2, 4, 6))
assert "Internal error while evaluating `ContractExpression`" in str(err.value)
# should only be able to specify out
with pytest.raises(TypeError) as err_type:
expr(np.random.rand(2, 3), np.random.rand(3, 4), order="F") # type: ignore
assert "got an unexpected keyword" in str(err_type.value)
def test_broadcasting_contraction() -> None:
a = np.random.rand(1, 5, 4)
b = np.random.rand(4, 6)
c = np.random.rand(5, 6)
d = np.random.rand(10)
ein_scalar = contract("ijk,kl,jl", a, b, c, optimize=False)
opt_scalar = contract("ijk,kl,jl", a, b, c, optimize=True)
assert np.allclose(ein_scalar, opt_scalar)
result = ein_scalar * d
ein = contract("ijk,kl,jl,i->i", a, b, c, d, optimize=False)
opt = contract("ijk,kl,jl,i->i", a, b, c, d, optimize=True)
assert np.allclose(ein, result)
assert np.allclose(opt, result)
def test_broadcasting_contraction2() -> None:
a = np.random.rand(1, 1, 5, 4)
b = np.random.rand(4, 6)
c = np.random.rand(5, 6)
d = np.random.rand(7, 7)
ein_scalar = contract("abjk,kl,jl", a, b, c, optimize=False)
opt_scalar = contract("abjk,kl,jl", a, b, c, optimize=True)
assert np.allclose(ein_scalar, opt_scalar)
result = ein_scalar * d
ein = contract("abjk,kl,jl,ab->ab", a, b, c, d, optimize=False)
opt = contract("abjk,kl,jl,ab->ab", a, b, c, d, optimize=True)
assert np.allclose(ein, result)
assert np.allclose(opt, result)
def test_broadcasting_contraction3() -> None:
a = np.random.rand(1, 5, 4)
b = np.random.rand(4, 1, 6)
c = np.random.rand(5, 6)
d = np.random.rand(7, 7)
ein = contract("ajk,kbl,jl,ab->ab", a, b, c, d, optimize=False)
opt = contract("ajk,kbl,jl,ab->ab", a, b, c, d, optimize=True)
assert np.allclose(ein, opt)
def test_broadcasting_contraction4() -> None:
a = np.arange(64).reshape(2, 4, 8)
ein = contract("obk,ijk->ioj", a, a, optimize=False)
opt = contract("obk,ijk->ioj", a, a, optimize=True)
assert np.allclose(ein, opt)
def test_can_blas_on_healed_broadcast_dimensions() -> None:
expr = contract_expression("ab,bc,bd->acd", (5, 4), (1, 5), (4, 20))
# first contraction involves broadcasting
assert expr.contraction_list[0][2] == "bc,ab->bca"
assert expr.contraction_list[0][-1] is False
# but then is healed GEMM is usable
assert expr.contraction_list[1][2] == "bca,bd->acd"
assert expr.contraction_list[1][-1] == "GEMM"
def test_pathinfo_for_empty_contraction() -> None:
eq = "->"
arrays = (1.0,)
path: PathType = []
_, info = contract_path(eq, *arrays, optimize=path)
# some info is built lazily, so check repr
assert repr(info)
assert info.largest_intermediate == 1
@pytest.mark.parametrize(
"expression, operands",
[
[",,->", (5, 5.0, 2.0j)],
["ab,->", ([[5, 5], [2.0, 1]], 2.0j)],
["ab,bc->ac", ([[5, 5], [2.0, 1]], [[2.0, 1], [3.0, 4]])],
["ab,->", ([[5, 5], [2.0, 1]], True)],
],
)
def test_contract_with_assumed_shapes(expression: str, operands: Tuple[Any]) -> None:
"""Test that we can contract with assumed shapes, and that the output is correct. This is required as we need to infer intermediate shape sizes."""
benchmark = np.einsum(expression, *operands)
result = contract(expression, *operands, optimize=True)
assert np.allclose(benchmark, result)
@@ -0,0 +1,279 @@
"""
Tests the input parsing for opt_einsum. Duplicates the np.einsum input tests.
"""
from typing import Any, List
import pytest
from opt_einsum import contract, contract_path
from opt_einsum.typing import ArrayType
np = pytest.importorskip("numpy")
def build_views(string: str) -> List[ArrayType]:
"""Builds random numpy arrays for testing by using a fixed size dictionary and an input string."""
chars = "abcdefghij"
sizes_array = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4])
sizes = dict(zip(chars, sizes_array))
views = []
string = string.replace("...", "ij")
terms = string.split("->")[0].split(",")
for term in terms:
dims = [sizes[x] for x in term]
views.append(np.random.rand(*dims))
return views
def test_type_errors() -> None:
# subscripts must be a string
with pytest.raises(TypeError):
contract(0, 0)
# out parameter must be an array
with pytest.raises(TypeError):
contract("", 0, out="test")
# order parameter must be a valid order
# changed in Numpy 1.19, see https://github.com/numpy/numpy/commit/35b0a051c19265f5643f6011ee11e31d30c8bc4c
with pytest.raises((TypeError, ValueError)):
contract("", 0, order="W") # type: ignore
# casting parameter must be a valid casting
with pytest.raises(ValueError):
contract("", 0, casting="blah") # type: ignore
# dtype parameter must be a valid dtype
with pytest.raises(TypeError):
contract("", 0, dtype="bad_data_type")
# other keyword arguments are rejected
with pytest.raises(TypeError):
contract("", 0, bad_arg=0)
# issue 4528 revealed a segfault with this call
with pytest.raises(TypeError):
contract(*(None,) * 63)
# Cannot have two ->
with pytest.raises(ValueError):
contract("->,->", 0, 5)
# Undefined symbol lhs
with pytest.raises(ValueError):
contract("&,a->", 0, 5)
# Undefined symbol rhs
with pytest.raises(ValueError):
contract("a,a->&", 0, 5)
with pytest.raises(ValueError):
contract("a,a->&", 0, 5)
# Catch ellipsis errors
string = "...a->...a"
views = build_views(string)
# Subscript list must contain Ellipsis or (hashable && comparable) object
with pytest.raises(TypeError):
contract(views[0], [Ellipsis, 0], [Ellipsis, ["a"]])
with pytest.raises(TypeError):
contract(views[0], [Ellipsis, {}], [Ellipsis, "a"])
@pytest.mark.parametrize("contract_fn", [contract, contract_path])
def test_value_errors(contract_fn: Any) -> None:
with pytest.raises(ValueError):
contract_fn("")
# subscripts must be a string
with pytest.raises(TypeError):
contract_fn(0, 0)
# invalid subscript character
with pytest.raises(ValueError):
contract_fn("i%...", [0, 0])
with pytest.raises(ValueError):
contract_fn("...j$", [0, 0])
with pytest.raises(ValueError):
contract_fn("i->&", [0, 0])
with pytest.raises(ValueError):
contract_fn("")
# number of operands must match count in subscripts string
with pytest.raises(ValueError):
contract_fn("", 0, 0)
with pytest.raises(ValueError):
contract_fn(",", 0, [0], [0])
with pytest.raises(ValueError):
contract_fn(",", [0])
# can't have more subscripts than dimensions in the operand
with pytest.raises(ValueError):
contract_fn("i", 0)
with pytest.raises(ValueError):
contract_fn("ij", [0, 0])
with pytest.raises(ValueError):
contract_fn("...i", 0)
with pytest.raises(ValueError):
contract_fn("i...j", [0, 0])
with pytest.raises(ValueError):
contract_fn("i...", 0)
with pytest.raises(ValueError):
contract_fn("ij...", [0, 0])
# invalid ellipsis
with pytest.raises(ValueError):
contract_fn("i..", [0, 0])
with pytest.raises(ValueError):
contract_fn(".i...", [0, 0])
with pytest.raises(ValueError):
contract_fn("j->..j", [0, 0])
with pytest.raises(ValueError):
contract_fn("j->.j...", [0, 0])
# invalid subscript character
with pytest.raises(ValueError):
contract_fn("i%...", [0, 0])
with pytest.raises(ValueError):
contract_fn("...j$", [0, 0])
with pytest.raises(ValueError):
contract_fn("i->&", [0, 0])
# output subscripts must appear in input
with pytest.raises(ValueError):
contract_fn("i->ij", [0, 0])
# output subscripts may only be specified once
with pytest.raises(ValueError):
contract_fn("ij->jij", [[0, 0], [0, 0]])
# dimensions much match when being collapsed
with pytest.raises(ValueError):
contract_fn("ii", np.arange(6).reshape(2, 3))
with pytest.raises(ValueError):
contract_fn("ii->i", np.arange(6).reshape(2, 3))
# broadcasting to new dimensions must be enabled explicitly
with pytest.raises(ValueError):
contract_fn("i", np.arange(6).reshape(2, 3))
with pytest.raises(TypeError):
contract_fn("ij->ij", [[0, 1], [0, 1]], bad_kwarg=True)
@pytest.mark.parametrize(
"string",
[
# Ellipse
"...a->...",
"a...->...",
"a...a->...a",
"...,...",
"a,b",
"...a,...b",
],
)
def test_compare(string: str) -> None:
views = build_views(string)
ein = contract(string, *views, optimize=False)
opt = contract(string, *views)
assert np.allclose(ein, opt)
opt = contract(string, *views, optimize="optimal")
assert np.allclose(ein, opt)
def test_ellipse_input1() -> None:
string = "...a->..."
views = build_views(string)
ein = contract(string, *views, optimize=False)
opt = contract(views[0], [Ellipsis, 0], [Ellipsis])
assert np.allclose(ein, opt)
def test_ellipse_input2() -> None:
string = "...a"
views = build_views(string)
ein = contract(string, *views, optimize=False)
opt = contract(views[0], [Ellipsis, 0])
assert np.allclose(ein, opt)
def test_ellipse_input3() -> None:
string = "...a->...a"
views = build_views(string)
ein = contract(string, *views, optimize=False)
opt = contract(views[0], [Ellipsis, 0], [Ellipsis, 0])
assert np.allclose(ein, opt)
def test_ellipse_input4() -> None:
string = "...b,...a->..."
views = build_views(string)
ein = contract(string, *views, optimize=False)
opt = contract(views[0], [Ellipsis, 1], views[1], [Ellipsis, 0], [Ellipsis])
assert np.allclose(ein, opt)
def test_singleton_dimension_broadcast() -> None:
# singleton dimensions broadcast (gh-10343)
p = np.ones((10, 2))
q = np.ones((1, 2))
ein = contract("ij,ij->j", p, q, optimize=False)
opt = contract("ij,ij->j", p, q, optimize=True)
assert np.allclose(ein, opt)
assert np.allclose(opt, [10.0, 10.0])
p = np.ones((1, 5))
q = np.ones((5, 5))
for optimize in (True, False):
res1 = (contract("...ij,...jk->...ik", p, p, optimize=optimize),)
res2 = contract("...ij,...jk->...ik", p, q, optimize=optimize)
assert np.allclose(res1, res2)
assert np.allclose(res2, np.full((1, 5), 5))
def test_large_int_input_format() -> None:
string = "ab,bc,cd"
x, y, z = build_views(string)
string_output = contract(string, x, y, z)
int_output = contract(x, (1000, 1001), y, (1001, 1002), z, (1002, 1003))
assert np.allclose(string_output, int_output)
for i in range(10):
transpose_output = contract(x, (i + 1, i))
assert np.allclose(transpose_output, x.T)
def test_hashable_object_input_format() -> None:
string = "ab,bc,cd"
x, y, z = build_views(string)
string_output = contract(string, x, y, z)
hash_output1 = contract(x, ("left", "bond1"), y, ("bond1", "bond2"), z, ("bond2", "right"))
hash_output2 = contract(
x,
("left", "bond1"),
y,
("bond1", "bond2"),
z,
("bond2", "right"),
("left", "right"),
)
assert np.allclose(string_output, hash_output1)
assert np.allclose(hash_output1, hash_output2)
for i in range(1, 10):
transpose_output = contract(x, ("b" * i, "a" * i))
assert np.allclose(transpose_output, x.T)
@@ -0,0 +1,74 @@
"""
Directly tests various parser utility functions.
"""
from typing import Any, Tuple
import pytest
from opt_einsum.parser import get_shape, get_symbol, parse_einsum_input
from opt_einsum.testing import build_arrays_from_tuples
def test_get_symbol() -> None:
assert get_symbol(2) == "c"
assert get_symbol(200000) == "\U00031540"
# Ensure we skip surrogates '[\uD800-\uDFFF]'
assert get_symbol(55295) == "\ud88b"
assert get_symbol(55296) == "\ue000"
assert get_symbol(57343) == "\ue7ff"
def test_parse_einsum_input() -> None:
eq = "ab,bc,cd"
ops = build_arrays_from_tuples([(2, 3), (3, 4), (4, 5)])
input_subscripts, output_subscript, operands = parse_einsum_input([eq, *ops])
assert input_subscripts == eq
assert output_subscript == "ad"
assert operands == ops
def test_parse_einsum_input_shapes_error() -> None:
eq = "ab,bc,cd"
ops = build_arrays_from_tuples([(2, 3), (3, 4), (4, 5)])
with pytest.raises(ValueError):
_ = parse_einsum_input([eq, *ops], shapes=True)
def test_parse_einsum_input_shapes() -> None:
eq = "ab,bc,cd"
shapes = [(2, 3), (3, 4), (4, 5)]
input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True)
assert input_subscripts == eq
assert output_subscript == "ad"
assert shapes == operands
def test_parse_with_ellisis() -> None:
eq = "...a,ab"
shapes = [(2, 3), (3, 4)]
input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True)
assert input_subscripts == "da,ab"
assert output_subscript == "db"
assert shapes == operands
@pytest.mark.parametrize(
"array, shape",
[
[[5], (1,)],
[[5, 5], (2,)],
[(5, 5), (2,)],
[[[[[[5, 2]]]]], (1, 1, 1, 1, 2)],
[[[[[["abcdef", "b"]]]]], (1, 1, 1, 1, 2)],
["A", ()],
[b"A", ()],
[True, ()],
[5, ()],
[5.0, ()],
[5.0 + 0j, ()],
],
)
def test_get_shapes(array: Any, shape: Tuple[int]) -> None:
assert get_shape(array) == shape
@@ -0,0 +1,534 @@
"""
Tests the accuracy of the opt_einsum paths in addition to unit tests for
the various path helper functions.
"""
import itertools
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Dict, List, Optional
import pytest
import opt_einsum as oe
from opt_einsum.testing import build_shapes, rand_equation
from opt_einsum.typing import ArrayIndexType, OptimizeKind, PathType, TensorShapeType
explicit_path_tests = {
"GEMM1": (
[set("abd"), set("ac"), set("bdc")],
set(""),
{"a": 1, "b": 2, "c": 3, "d": 4},
),
"Inner1": (
[set("abcd"), set("abc"), set("bc")],
set(""),
{"a": 5, "b": 2, "c": 3, "d": 4},
),
}
# note that these tests have no unique solution due to the chosen dimensions
path_edge_tests = [
["greedy", "eb,cb,fb->cef", ((0, 2), (0, 1))],
["branch-all", "eb,cb,fb->cef", ((0, 2), (0, 1))],
["branch-2", "eb,cb,fb->cef", ((0, 2), (0, 1))],
["optimal", "eb,cb,fb->cef", ((0, 2), (0, 1))],
["dp", "eb,cb,fb->cef", ((1, 2), (0, 1))],
["greedy", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
["branch-all", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
["branch-2", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
["optimal", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
["optimal", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
["dp", "dd,fb,be,cdb->cef", ((0, 3), (0, 2), (0, 1))],
["greedy", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
["branch-all", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
["branch-2", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
["optimal", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
["dp", "bca,cdb,dbf,afc->", ((1, 2), (1, 2), (0, 1))],
["greedy", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 1), (0, 1))],
["branch-all", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
["branch-2", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
["optimal", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
["dp", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
]
# note that these tests have no unique solution due to the chosen dimensions
path_scalar_tests = [
[
"a,->a",
1,
],
["ab,->ab", 1],
[",a,->a", 2],
[",,a,->a", 3],
[",,->", 2],
]
def check_path(test_output: PathType, benchmark: PathType, bypass: bool = False) -> bool:
if not isinstance(test_output, list):
return False
if len(test_output) != len(benchmark):
return False
ret = True
for pos in range(len(test_output)):
ret &= isinstance(test_output[pos], tuple)
ret &= test_output[pos] == list(benchmark)[pos]
return ret
def assert_contract_order(func: Any, test_data: Any, max_size: int, benchmark: PathType) -> None:
test_output = func(test_data[0], test_data[1], test_data[2], max_size)
assert check_path(test_output, benchmark)
def test_size_by_dict() -> None:
sizes_dict = {}
for ind, val in zip("abcdez", [2, 5, 9, 11, 13, 0]):
sizes_dict[ind] = val
path_func = oe.helpers.compute_size_by_dict
assert 1 == path_func("", sizes_dict)
assert 2 == path_func("a", sizes_dict)
assert 5 == path_func("b", sizes_dict)
assert 0 == path_func("z", sizes_dict)
assert 0 == path_func("az", sizes_dict)
assert 0 == path_func("zbc", sizes_dict)
assert 104 == path_func("aaae", sizes_dict)
assert 12870 == path_func("abcde", sizes_dict)
def test_flop_cost() -> None:
size_dict = {v: 10 for v in "abcdef"}
# Loop over an array
assert 10 == oe.helpers.flop_count("a", False, 1, size_dict)
# Hadamard product (*)
assert 10 == oe.helpers.flop_count("a", False, 2, size_dict)
assert 100 == oe.helpers.flop_count("ab", False, 2, size_dict)
# Inner product (+, *)
assert 20 == oe.helpers.flop_count("a", True, 2, size_dict)
assert 200 == oe.helpers.flop_count("ab", True, 2, size_dict)
# Inner product x3 (+, *, *)
assert 30 == oe.helpers.flop_count("a", True, 3, size_dict)
# GEMM
assert 2000 == oe.helpers.flop_count("abc", True, 2, size_dict)
def test_bad_path_option() -> None:
with pytest.raises(KeyError):
oe.contract("a,b,c", [1], [2], [3], optimize="optimall", shapes=True) # type: ignore
def test_explicit_path() -> None:
pytest.importorskip("numpy")
x = oe.contract("a,b,c", [1], [2], [3], optimize=[(1, 2), (0, 1)])
assert x.item() == 6
def test_path_optimal() -> None:
test_func = oe.paths.optimal
test_data = explicit_path_tests["GEMM1"]
assert_contract_order(test_func, test_data, 5000, [(0, 2), (0, 1)])
assert_contract_order(test_func, test_data, 0, [(0, 1, 2)])
def test_path_greedy() -> None:
test_func = oe.paths.greedy
test_data = explicit_path_tests["GEMM1"]
assert_contract_order(test_func, test_data, 5000, [(0, 2), (0, 1)])
assert_contract_order(test_func, test_data, 0, [(0, 1, 2)])
def test_memory_paths() -> None:
expression = "abc,bdef,fghj,cem,mhk,ljk->adgl"
views = build_shapes(expression)
# Test tiny memory limit
path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=5, shapes=True)
assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)])
path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=5, shapes=True)
assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)])
# Check the possibilities, greedy is capped
path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=-1, shapes=True)
assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])
path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=-1, shapes=True)
assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])
@pytest.mark.parametrize("alg,expression,order", path_edge_tests)
def test_path_edge_cases(alg: OptimizeKind, expression: str, order: PathType) -> None:
views = build_shapes(expression)
# Test tiny memory limit
path_ret = oe.contract_path(expression, *views, optimize=alg, shapes=True)
assert check_path(path_ret[0], order)
@pytest.mark.parametrize("expression,order", path_scalar_tests)
@pytest.mark.parametrize("alg", oe.paths._PATH_OPTIONS)
def test_path_scalar_cases(alg: OptimizeKind, expression: str, order: PathType) -> None:
views = build_shapes(expression)
# Test tiny memory limit
path_ret = oe.contract_path(expression, *views, optimize=alg, shapes=True)
# print(path_ret[0])
assert len(path_ret[0]) == order
def test_optimal_edge_cases() -> None:
# Edge test5
expression = "a,ac,ab,ad,cd,bd,bc->"
edge_test4 = build_shapes(expression, dimension_dict={"a": 20, "b": 20, "c": 20, "d": 20})
path, _ = oe.contract_path(expression, *edge_test4, optimize="greedy", memory_limit="max_input", shapes=True)
assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)])
path, _ = oe.contract_path(expression, *edge_test4, optimize="optimal", memory_limit="max_input", shapes=True)
assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)])
def test_greedy_edge_cases() -> None:
expression = "abc,cfd,dbe,efa"
dim_dict = {k: 20 for k in expression.replace(",", "")}
tensors = build_shapes(expression, dimension_dict=dim_dict)
path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit="max_input", shapes=True)
assert check_path(path, [(0, 1, 2, 3)])
path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit=-1, shapes=True)
assert check_path(path, [(0, 1), (0, 2), (0, 1)])
def test_dp_edge_cases_dimension_1() -> None:
eq = "nlp,nlq,pl->n"
shapes = [(1, 1, 1), (1, 1, 1), (1, 1)]
info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")[1]
assert max(info.scale_list) == 3
def test_dp_edge_cases_all_singlet_indices() -> None:
eq = "a,bcd,efg->"
shapes = [(2,), (2, 2, 2), (2, 2, 2)]
info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")[1]
assert max(info.scale_list) == 3
def test_custom_dp_can_optimize_for_outer_products() -> None:
eq = "a,b,abc->c"
da, db, dc = 2, 2, 3
shapes = [(da,), (db,), (da, db, dc)]
opt1 = oe.DynamicProgramming(search_outer=False)
opt2 = oe.DynamicProgramming(search_outer=True)
info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]
assert info2.opt_cost < info1.opt_cost
def test_custom_dp_can_optimize_for_size() -> None:
eq, shapes = rand_equation(10, 4, seed=43)
opt1 = oe.DynamicProgramming(minimize="flops")
opt2 = oe.DynamicProgramming(minimize="size")
info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]
assert info1.opt_cost < info2.opt_cost
assert info1.largest_intermediate > info2.largest_intermediate
def test_custom_dp_can_set_cost_cap() -> None:
eq, shapes = rand_equation(5, 3, seed=42)
opt1 = oe.DynamicProgramming(cost_cap=True)
opt2 = oe.DynamicProgramming(cost_cap=False)
opt3 = oe.DynamicProgramming(cost_cap=100)
info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]
info3 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt3)[1]
assert info1.opt_cost == info2.opt_cost == info3.opt_cost
@pytest.mark.parametrize(
"minimize,cost,width,path",
[
("flops", 663054, 18900, [(4, 5), (2, 5), (2, 7), (5, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
("size", 1114440, 2016, [(2, 7), (3, 8), (3, 7), (2, 6), (1, 5), (1, 4), (1, 3), (1, 2), (0, 1)]),
("write", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
("combo", 973518, 2016, [(4, 5), (2, 5), (6, 7), (2, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
("limit", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
("combo-256", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
("limit-256", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
],
)
def test_custom_dp_can_set_minimize(minimize: str, cost: int, width: int, path: PathType) -> None:
eq, shapes = rand_equation(10, 4, seed=43)
opt = oe.DynamicProgramming(minimize=minimize)
info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)[1]
assert info.path == path
assert info.opt_cost == cost
assert info.largest_intermediate == width
def test_dp_errors_when_no_contractions_found() -> None:
eq, shapes = rand_equation(10, 3, seed=42)
# first get the actual minimum cost
opt = oe.DynamicProgramming(minimize="size")
_, info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)
mincost = info.largest_intermediate
# check we can still find it without minimizing size explicitly
oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost, optimize="dp")
# but check just below this threshold raises
with pytest.raises(RuntimeError):
oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost - 1, optimize="dp")
@pytest.mark.parametrize("optimize", ["greedy", "branch-2", "branch-all", "optimal", "dp"])
def test_can_optimize_outer_products(optimize: OptimizeKind) -> None:
a, b, c = ((10, 10) for _ in range(3))
d = (10, 2)
assert oe.contract_path("ab,cd,ef,fg", a, b, c, d, optimize=optimize, shapes=True)[0] == [
(2, 3),
(0, 2),
(0, 1),
]
@pytest.mark.parametrize("num_symbols", [2, 3, 26, 26 + 26, 256 - 140, 300])
def test_large_path(num_symbols: int) -> None:
symbols = "".join(oe.get_symbol(i) for i in range(num_symbols))
dimension_dict = dict(zip(symbols, itertools.cycle([2, 3, 4])))
expression = ",".join(symbols[t : t + 2] for t in range(num_symbols - 1))
tensors = build_shapes(expression, dimension_dict=dimension_dict)
# Check that path construction does not crash
oe.contract_path(expression, *tensors, optimize="greedy", shapes=True)
def test_custom_random_greedy() -> None:
np = pytest.importorskip("numpy")
eq, shapes = rand_equation(10, 4, seed=42)
views = list(map(np.ones, shapes))
with pytest.raises(ValueError):
oe.RandomGreedy(minimize="something")
optimizer = oe.RandomGreedy(max_repeats=10, minimize="flops")
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
assert len(optimizer.costs) == 10
assert len(optimizer.sizes) == 10
assert path == optimizer.path
assert optimizer.best["flops"] == min(optimizer.costs)
assert path_info.largest_intermediate == optimizer.best["size"]
assert path_info.opt_cost == optimizer.best["flops"]
# check can change settings and run again
optimizer.temperature = 0.0
optimizer.max_repeats = 6
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
assert len(optimizer.costs) == 16
assert len(optimizer.sizes) == 16
assert path == optimizer.path
assert optimizer.best["size"] == min(optimizer.sizes)
assert path_info.largest_intermediate == optimizer.best["size"]
assert path_info.opt_cost == optimizer.best["flops"]
# check error if we try and reuse the optimizer on a different expression
eq, shapes = rand_equation(10, 4, seed=41)
views = list(map(np.ones, shapes))
with pytest.raises(ValueError):
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
def test_custom_branchbound() -> None:
np = pytest.importorskip("numpy")
eq, shapes = rand_equation(8, 4, seed=42)
views = list(map(np.ones, shapes))
optimizer = oe.BranchBound(nbranch=2, cutoff_flops_factor=10, minimize="size")
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
assert path == optimizer.path
assert path_info.largest_intermediate == optimizer.best["size"]
assert path_info.opt_cost == optimizer.best["flops"]
# tweak settings and run again
optimizer.nbranch = 3
optimizer.cutoff_flops_factor = 4
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
assert path == optimizer.path
assert path_info.largest_intermediate == optimizer.best["size"]
assert path_info.opt_cost == optimizer.best["flops"]
# check error if we try and reuse the optimizer on a different expression
eq, shapes = rand_equation(8, 4, seed=41)
views = list(map(np.ones, shapes))
with pytest.raises(ValueError):
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
def test_branchbound_validation() -> None:
with pytest.raises(ValueError):
oe.BranchBound(nbranch=0)
def test_parallel_random_greedy() -> None:
np = pytest.importorskip("numpy")
pool = ProcessPoolExecutor(2)
eq, shapes = rand_equation(10, 4, seed=42)
views = list(map(np.ones, shapes))
optimizer = oe.RandomGreedy(max_repeats=10, parallel=pool)
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
assert len(optimizer.costs) == 10
assert len(optimizer.sizes) == 10
assert path == optimizer.path
assert optimizer.parallel is pool
assert optimizer._executor is pool
assert optimizer.best["flops"] == min(optimizer.costs)
assert path_info.largest_intermediate == optimizer.best["size"]
assert path_info.opt_cost == optimizer.best["flops"]
# now switch to max time algorithm
optimizer.max_repeats = int(1e6)
optimizer.max_time = 0.2
optimizer.parallel = 2
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
assert len(optimizer.costs) > 10
assert len(optimizer.sizes) > 10
assert path == optimizer.path
assert optimizer.best["flops"] == min(optimizer.costs)
assert path_info.largest_intermediate == optimizer.best["size"]
assert path_info.opt_cost == optimizer.best["flops"]
optimizer.parallel = True
assert optimizer._executor is not None
assert optimizer._executor is not pool
are_done = [f.running() or f.done() for f in optimizer._futures]
assert all(are_done)
def test_custom_path_optimizer() -> None:
np = pytest.importorskip("numpy")
class NaiveOptimizer(oe.paths.PathOptimizer):
def __call__(
self,
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
memory_limit: Optional[int] = None,
) -> PathType:
self.was_used = True
return [(0, 1)] * (len(inputs) - 1)
eq, shapes = rand_equation(5, 3, seed=42, d_max=3)
views = list(map(np.ones, shapes))
exp = oe.contract(eq, *views, optimize=False)
optimizer = NaiveOptimizer()
out = oe.contract(eq, *views, optimize=optimizer)
assert exp == out
assert optimizer.was_used
def test_custom_random_optimizer() -> None:
np = pytest.importorskip("numpy")
class NaiveRandomOptimizer(oe.path_random.RandomOptimizer):
@staticmethod
def random_path(
r: int, n: int, inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int]
) -> Any:
"""Picks a completely random contraction order."""
np.random.seed(r)
ssa_path: List[TensorShapeType] = []
remaining = set(range(n))
while len(remaining) > 1:
i, j = np.random.choice(list(remaining), size=2, replace=False)
remaining.add(n + len(ssa_path))
remaining.remove(i)
remaining.remove(j)
ssa_path.append((i, j))
cost, size = oe.path_random.ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
return ssa_path, cost, size
def setup(self, inputs: Any, output: Any, size_dict: Any) -> Any:
self.was_used = True
n = len(inputs)
trial_fn = self.random_path
trial_args = (n, inputs, output, size_dict)
return trial_fn, trial_args
eq, shapes = rand_equation(5, 3, seed=42, d_max=3)
views = list(map(np.ones, shapes))
exp = oe.contract(eq, *views, optimize=False)
optimizer = NaiveRandomOptimizer(max_repeats=16)
out = oe.contract(eq, *views, optimize=optimizer)
assert exp == out
assert optimizer.was_used
assert len(optimizer.costs) == 16
def test_optimizer_registration() -> None:
def custom_optimizer(
inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int], memory_limit: Optional[int]
) -> PathType:
return [(0, 1)] * (len(inputs) - 1)
with pytest.raises(KeyError):
oe.paths.register_path_fn("optimal", custom_optimizer)
oe.paths.register_path_fn("custom", custom_optimizer)
assert "custom" in oe.paths._PATH_OPTIONS
eq = "ab,bc,cd"
shapes = [(2, 3), (3, 4), (4, 5)]
path, _ = oe.contract_path(eq, *shapes, shapes=True, optimize="custom") # type: ignore
assert path == [(0, 1), (0, 1)]
del oe.paths._PATH_OPTIONS["custom"]
def test_path_with_assumed_shapes() -> None:
path, _ = oe.contract_path("ab,bc,cd", [[5, 3]], [[2], [4]], [[3, 2]])
assert path == [(0, 1), (0, 1)]
@@ -0,0 +1,390 @@
import itertools
import weakref
from collections import Counter
from typing import Any
import pytest
from opt_einsum import contract, contract_expression, contract_path, get_symbol, shared_intermediates
from opt_einsum.backends import to_cupy, to_torch
from opt_einsum.contract import _einsum
from opt_einsum.parser import parse_einsum_input
from opt_einsum.sharing import count_cached_ops, currently_sharing, get_sharing_cache
from opt_einsum.testing import build_views
from opt_einsum.typing import BackendType
pytest.importorskip("numpy")
try:
import numpy as np # type: ignore
numpy_if_found = "numpy"
except ImportError:
numpy_if_found = pytest.param("numpy", marks=[pytest.mark.skip(reason="NumPy not installed.")]) # type: ignore
try:
import cupy # noqa
cupy_if_found = "cupy"
except ImportError:
cupy_if_found = pytest.param("cupy", marks=[pytest.mark.skip(reason="CuPy not installed.")]) # type: ignore
try:
import torch # type: ignore # noqa
torch_if_found = "torch"
except ImportError:
torch_if_found = pytest.param("torch", marks=[pytest.mark.skip(reason="PyTorch not installed.")]) # type: ignore
backends = [numpy_if_found, torch_if_found, cupy_if_found]
equations = [
"ab,bc->ca",
"abc,bcd,dea",
"abc,def->fedcba",
"abc,bcd,df->fa",
# test 'prefer einsum' ops
"ijk,ikj",
"i,j->ij",
"ijk,k->ij",
"AB,BC->CA",
]
to_backend = {
"numpy": lambda x: x,
"torch": to_torch,
"cupy": to_cupy,
}
@pytest.mark.parametrize("eq", equations)
@pytest.mark.parametrize("backend", backends)
def test_sharing_value(eq: str, backend: BackendType) -> None:
views = build_views(eq)
shapes = [v.shape for v in views]
expr = contract_expression(eq, *shapes)
expected = expr(*views, backend=backend)
with shared_intermediates():
actual = expr(*views, backend=backend)
assert (actual == expected).all()
@pytest.mark.parametrize("backend", backends)
def test_complete_sharing(backend: BackendType) -> None:
eq = "ab,bc,cd->"
views = build_views(eq)
expr = contract_expression(eq, *(v.shape for v in views))
print("-" * 40)
print("Without sharing:")
with shared_intermediates() as cache:
expr(*views, backend=backend)
expected = count_cached_ops(cache)
print("-" * 40)
print("With sharing:")
with shared_intermediates() as cache:
expr(*views, backend=backend)
expr(*views, backend=backend)
actual = count_cached_ops(cache)
print("-" * 40)
print(f"Without sharing: {expected} expressions")
print(f"With sharing: {actual} expressions")
assert actual == expected
@pytest.mark.parametrize("backend", backends)
def test_sharing_reused_cache(backend: BackendType) -> None:
eq = "ab,bc,cd->"
views = build_views(eq)
expr = contract_expression(eq, *(v.shape for v in views))
print("-" * 40)
print("Without sharing:")
with shared_intermediates() as cache:
expr(*views, backend=backend)
expected = count_cached_ops(cache)
print("-" * 40)
print("With sharing:")
with shared_intermediates() as cache:
expr(*views, backend=backend)
with shared_intermediates(cache):
expr(*views, backend=backend)
actual = count_cached_ops(cache)
print("-" * 40)
print(f"Without sharing: {expected} expressions")
print(f"With sharing: {actual} expressions")
assert actual == expected
@pytest.mark.parametrize("backend", backends)
def test_no_sharing_separate_cache(backend: BackendType) -> None:
eq = "ab,bc,cd->"
views = build_views(eq)
expr = contract_expression(eq, *(v.shape for v in views))
print("-" * 40)
print("Without sharing:")
with shared_intermediates() as cache:
expr(*views, backend=backend)
expected = count_cached_ops(cache)
expected.update(count_cached_ops(cache)) # we expect double
print("-" * 40)
print("With sharing:")
with shared_intermediates() as cache1:
expr(*views, backend=backend)
actual = count_cached_ops(cache1)
with shared_intermediates() as cache2:
expr(*views, backend=backend)
actual.update(count_cached_ops(cache2))
print("-" * 40)
print(f"Without sharing: {expected} expressions")
print(f"With sharing: {actual} expressions")
assert actual == expected
@pytest.mark.parametrize("backend", backends)
def test_sharing_nesting(backend: BackendType) -> None:
eqs = ["ab,bc,cd->a", "ab,bc,cd->b", "ab,bc,cd->c", "ab,bc,cd->c"]
views = build_views(eqs[0])
shapes = [v.shape for v in views]
refs: Any = weakref.WeakValueDictionary()
def method1(views):
with shared_intermediates():
w = contract_expression(eqs[0], *shapes)(*views, backend=backend)
x = contract_expression(eqs[2], *shapes)(*views, backend=backend)
result = contract_expression("a,b->", w.shape, x.shape)(w, x, backend=backend)
refs["w"] = w
refs["x"] = x
del w, x
assert "w" in refs
assert "x" in refs
assert "w" not in refs, "cache leakage"
assert "x" not in refs, "cache leakage"
return result
def method2(views):
with shared_intermediates():
y = contract_expression(eqs[2], *shapes)(*views, backend=backend)
z = contract_expression(eqs[3], *shapes)(*views, backend=backend)
refs["y"] = y
refs["z"] = z
result = contract_expression("c,d->", y.shape, z.shape)(y, z, backend=backend)
result = result + method1(views) # nest method1 in method2
del y, z
assert "y" in refs
assert "z" in refs
assert "y" not in refs
assert "z" not in refs
method1(views)
method2(views)
@pytest.mark.parametrize("eq", equations)
@pytest.mark.parametrize("backend", backends)
def test_sharing_modulo_commutativity(eq: str, backend: BackendType) -> None:
ops = tuple(to_backend[backend](x) for x in build_views(eq))
inputs, output, _ = parse_einsum_input([eq] + list(ops))
inputs_list = inputs.split(",")
print("-" * 40)
print("Without sharing:")
with shared_intermediates() as cache:
_einsum(eq, *ops, backend=backend)
expected = count_cached_ops(cache)
print("-" * 40)
print("With sharing:")
with shared_intermediates() as cache:
for permuted in itertools.permutations(zip(inputs_list, ops)):
permuted_inputs = [p[0] for p in permuted]
permuted_ops = [p[1] for p in permuted]
permuted_eq = "{}->{}".format(",".join(permuted_inputs), output)
_einsum(permuted_eq, *permuted_ops, backend=backend)
actual = count_cached_ops(cache)
print("-" * 40)
print(f"Without sharing: {expected} expressions")
print(f"With sharing: {actual} expressions")
assert actual == expected
@pytest.mark.parametrize("backend", backends)
def test_partial_sharing(backend: BackendType) -> None:
eq = "ab,bc,de->"
x, y, z1 = build_views(eq) # type: ignore
z2 = 2.0 * z1 - 1.0
expr = contract_expression(eq, x.shape, y.shape, z1.shape)
print("-" * 40)
print("Without sharing:")
num_exprs_nosharing: Any = Counter()
with shared_intermediates() as cache:
expr(x, y, z1, backend=backend)
num_exprs_nosharing.update(count_cached_ops(cache))
with shared_intermediates() as cache:
expr(x, y, z2, backend=backend)
num_exprs_nosharing.update(count_cached_ops(cache))
print("-" * 40)
print("With sharing:")
with shared_intermediates() as cache:
expr(x, y, z1, backend=backend)
expr(x, y, z2, backend=backend)
num_exprs_sharing = count_cached_ops(cache)
print("-" * 40)
print(f"Without sharing: {num_exprs_nosharing} expressions")
print(f"With sharing: {num_exprs_sharing} expressions")
assert num_exprs_nosharing["einsum"] > num_exprs_sharing["einsum"]
@pytest.mark.parametrize("backend", backends)
def test_sharing_with_constants(backend: BackendType) -> None:
inputs = "ij,jk,kl"
outputs = "ijkl"
equations = [f"{inputs}->{output}" for output in outputs]
shapes = (2, 3), (3, 4), (4, 5)
constants = {0, 2}
ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
var = np.random.rand(*shapes[1])
expected = [contract_expression(eq, *shapes)(ops[0], var, ops[2]) for eq in equations]
with shared_intermediates():
actual = [contract_expression(eq, *ops, constants=constants)(var) for eq in equations]
for dim, expected_dim, actual_dim in zip(outputs, expected, actual):
assert np.allclose(expected_dim, actual_dim), f"error at {dim}"
@pytest.mark.parametrize("size", [3, 4, 5])
@pytest.mark.parametrize("backend", backends)
def test_chain(size: int, backend: BackendType) -> None:
xs = [np.random.rand(2, 2) for _ in range(size)]
shapes = [x.shape for x in xs]
alphabet = "".join(get_symbol(i) for i in range(size + 1))
names = [alphabet[i : i + 2] for i in range(size)]
inputs = ",".join(names)
with shared_intermediates():
print(inputs)
for i in range(size + 1):
target = alphabet[i]
eq = f"{inputs}->{target}"
path_info = contract_path(eq, *xs)
print(path_info[1])
expr = contract_expression(eq, *shapes)
expr(*xs, backend=backend)
print("-" * 40)
@pytest.mark.parametrize("size", [3, 4, 5, 10])
@pytest.mark.parametrize("backend", backends)
def test_chain_2(size: int, backend: BackendType) -> None:
xs = [np.random.rand(2, 2) for _ in range(size)]
shapes = [x.shape for x in xs]
alphabet = "".join(get_symbol(i) for i in range(size + 1))
names = [alphabet[i : i + 2] for i in range(size)]
inputs = ",".join(names)
with shared_intermediates():
print(inputs)
for i in range(size):
target = alphabet[i : i + 2]
eq = f"{inputs}->{target}"
path_info = contract_path(eq, *xs)
print(path_info[1])
expr = contract_expression(eq, *shapes)
expr(*xs, backend=backend)
print("-" * 40)
def _compute_cost(cache):
counts = count_cached_ops(cache)
return counts["einsum"] + counts["tensordot"]
@pytest.mark.parametrize("backend", backends)
def test_chain_2_growth(backend: BackendType) -> None:
sizes = list(range(1, 21))
costs = []
for size in sizes:
xs = [np.random.rand(2, 2) for _ in range(size)]
alphabet = "".join(get_symbol(i) for i in range(size + 1))
names = [alphabet[i : i + 2] for i in range(size)]
inputs = ",".join(names)
with shared_intermediates() as cache:
for i in range(size):
target = alphabet[i : i + 2]
eq = f"{inputs}->{target}"
expr = contract_expression(eq, *(x.shape for x in xs))
expr(*xs, backend=backend)
costs.append(_compute_cost(cache))
print(f"sizes = {repr(sizes)}")
print(f"costs = {repr(costs)}")
for size, cost in zip(sizes, costs):
print(f"{size}\t{cost}")
@pytest.mark.parametrize("size", [3, 4, 5])
@pytest.mark.parametrize("backend", backends)
def test_chain_sharing(size: int, backend: BackendType) -> None:
xs = [np.random.rand(2, 2) for _ in range(size)]
alphabet = "".join(get_symbol(i) for i in range(size + 1))
names = [alphabet[i : i + 2] for i in range(size)]
inputs = ",".join(names)
num_exprs_nosharing = 0
for i in range(size + 1):
with shared_intermediates() as cache:
target = alphabet[i]
eq = f"{inputs}->{target}"
expr = contract_expression(eq, *tuple(x.shape for x in xs))
expr(*xs, backend=backend)
num_exprs_nosharing += _compute_cost(cache)
with shared_intermediates() as cache:
print(inputs)
for i in range(size + 1):
target = alphabet[i]
eq = f"{inputs}->{target}"
path_info = contract_path(eq, *xs)
print(path_info[1])
expr = contract_expression(eq, *[x.shape for x in xs])
expr(*xs, backend=backend)
num_exprs_sharing = _compute_cost(cache)
print("-" * 40)
print(f"Without sharing: {num_exprs_nosharing} expressions")
print(f"With sharing: {num_exprs_sharing} expressions")
assert num_exprs_nosharing > num_exprs_sharing
def test_multithreaded_sharing() -> None:
from multiprocessing.pool import ThreadPool
def fn():
x, y, z = build_views("ab,bc,cd")
with shared_intermediates():
contract("ab,bc,cd->a", x, y, z)
contract("ab,bc,cd->b", x, y, z)
return len(get_sharing_cache())
expected = fn()
pool = ThreadPool(8)
fs = [pool.apply_async(fn) for _ in range(16)]
assert not currently_sharing()
assert [f.get() for f in fs] == [expected] * 16
pool.close()