diff --git a/src_python/ctf/multilinear.pyx b/src_python/ctf/multilinear.pyx index 4e9b4b38..5251791e 100644 --- a/src_python/ctf/multilinear.pyx +++ b/src_python/ctf/multilinear.pyx @@ -193,6 +193,12 @@ def spttn_kernel(tensor A, tsrs_list, nBs, einsum_expr): t = tsrs_list[i] tsrs[i] = t.dt B = tensor(copy=A) + # python einsum expression to C einsum expression + inp_tsrs, op_tsr = (einsum_expr.split("->")[0].split(","), einsum_expr.split("->")[1]) + mapping = {inp_tsrs[0][i]: inp_tsrs[0][-i-1] for i in range(len(inp_tsrs[0]))} + tr_einsum_expr = ''.join([inp_tsrs[0], ","]) + ','.join(inp_tsrs[i][::-1].translate(str.maketrans(mapping)) for i in range(1, len(inp_tsrs))) + "->" + op_tsr[::-1].translate(str.maketrans(mapping)) + einsum_expr = tr_einsum_expr.encode() + if A.dtype == np.float64: spttn_kernel_[double](B.dt,tsrs,nBs,einsum_expr) else: diff --git a/test/python/test_spttn_kernel.py b/test/python/test_spttn_kernel.py index 2ec9da7c..8c84b374 100644 --- a/test/python/test_spttn_kernel.py +++ b/test/python/test_spttn_kernel.py @@ -16,7 +16,7 @@ def test_ttmc(self): Test spttn kernel for order 3 TTMc and all-mode order 3 TTMc """ - einsum_expr = "ijk,ri,sj->rsk" + einsum_expr = "ijk,kr,js->isr" lens = [10,10,10] R = 4 A = ctf.tensor(lens,sp=True) @@ -27,13 +27,13 @@ def test_ttmc(self): tsrs.append(ctf.random.random(fac_lens)) op_lens = [lens[2],R,R] tsrs.append(ctf.zeros(op_lens)) - ctf.spttn_kernel(A, tsrs, 3, einsum_expr.encode()) + ctf.spttn_kernel(A, tsrs, 3, einsum_expr) ctr = A.i("ijk")*tsrs[0].i("kr")*tsrs[1].i("js") ans = ctf.zeros(op_lens) ans.i("isr") << ctr self.assertTrue(allclose(ans, tsrs[2])) - einsum_expr = "ijk,ri,sj,tk->rst" + einsum_expr = "ijk,kr,js,it->tsr" A = ctf.tensor(lens,sp=True) A.fill_sp_random(-1.,1.,.5) tsrs = [] @@ -42,7 +42,7 @@ def test_ttmc(self): tsrs.append(ctf.random.random(fac_lens)) op_lens = [R,R,R] tsrs.append(ctf.zeros(op_lens)) - ctf.spttn_kernel(A, tsrs, 4, einsum_expr.encode()) + ctf.spttn_kernel(A, tsrs, 4, einsum_expr) ctr = A.i("ijk")*tsrs[0].i("kr")*tsrs[1].i("js")*tsrs[2].i("it") ans = ctf.zeros(op_lens) ans.i("tsr") << ctr