Skip to content

Commit

Permalink
einsum from python to C interface
Browse files Browse the repository at this point in the history
  • Loading branch information
raghavendrak committed Jul 11, 2024
1 parent 81996d2 commit 1892ba7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
6 changes: 6 additions & 0 deletions src_python/ctf/multilinear.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ def spttn_kernel(tensor A, tsrs_list, nBs, einsum_expr):
t = tsrs_list[i]
tsrs[i] = <Tensor[double]*>t.dt
B = tensor(copy=A)
# python einsum expression to C einsum expression
inp_tsrs, op_tsr = (einsum_expr.split("->")[0].split(","), einsum_expr.split("->")[1])
mapping = {inp_tsrs[0][i]: inp_tsrs[0][-i-1] for i in range(len(inp_tsrs[0]))}
tr_einsum_expr = ''.join([inp_tsrs[0], ","]) + ','.join(inp_tsrs[i][::-1].translate(str.maketrans(mapping)) for i in range(1, len(inp_tsrs))) + "->" + op_tsr[::-1].translate(str.maketrans(mapping))
einsum_expr = tr_einsum_expr.encode()

if A.dtype == np.float64:
spttn_kernel_[double](<Tensor[double]*>B.dt,tsrs,nBs,<char *>einsum_expr)
else:
Expand Down
8 changes: 4 additions & 4 deletions test/python/test_spttn_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_ttmc(self):
Test spttn kernel for order 3 TTMc
and all-mode order 3 TTMc
"""
einsum_expr = "ijk,ri,sj->rsk"
einsum_expr = "ijk,kr,js->isr"
lens = [10,10,10]
R = 4
A = ctf.tensor(lens,sp=True)
Expand All @@ -27,13 +27,13 @@ def test_ttmc(self):
tsrs.append(ctf.random.random(fac_lens))
op_lens = [lens[2],R,R]
tsrs.append(ctf.zeros(op_lens))
ctf.spttn_kernel(A, tsrs, 3, einsum_expr.encode())
ctf.spttn_kernel(A, tsrs, 3, einsum_expr)
ctr = A.i("ijk")*tsrs[0].i("kr")*tsrs[1].i("js")
ans = ctf.zeros(op_lens)
ans.i("isr") << ctr
self.assertTrue(allclose(ans, tsrs[2]))

einsum_expr = "ijk,ri,sj,tk->rst"
einsum_expr = "ijk,kr,js,it->tsr"
A = ctf.tensor(lens,sp=True)
A.fill_sp_random(-1.,1.,.5)
tsrs = []
Expand All @@ -42,7 +42,7 @@ def test_ttmc(self):
tsrs.append(ctf.random.random(fac_lens))
op_lens = [R,R,R]
tsrs.append(ctf.zeros(op_lens))
ctf.spttn_kernel(A, tsrs, 4, einsum_expr.encode())
ctf.spttn_kernel(A, tsrs, 4, einsum_expr)
ctr = A.i("ijk")*tsrs[0].i("kr")*tsrs[1].i("js")*tsrs[2].i("it")
ans = ctf.zeros(op_lens)
ans.i("tsr") << ctr
Expand Down

0 comments on commit 1892ba7

Please sign in to comment.