Skip to content

Commit

Permalink
buf dim as function parameter; add optimize_for_blas call when pickin…
Browse files Browse the repository at this point in the history
…g indices for a single term
  • Loading branch information
raghavendrak committed Jul 9, 2024
1 parent ad3b903 commit bfbbb88
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 45 deletions.
61 changes: 48 additions & 13 deletions examples/spttn_tucker_solve_kernels.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ bool execute_spttn_kernel(int n, int ur, int vr, int wr,
UCxx.norm2(norm);
int64_t sz = T.get_tot_size(false);
bool pass = (norm / sz < 1.e-5);
if (dw.rank == 0){
if (dw.rank == 0) {
if (!pass)
printf("Test passed.\n");
else
printf("Test failed.\n");
else
printf("Test passed.\n");
}
IASSERT(pass);
mpass = mpass & pass;
Expand Down Expand Up @@ -247,11 +247,11 @@ bool execute_spttn_kernel(int n, int ur, int vr, int wr,
UCxx.norm2(norm);
int64_t sz = T.get_tot_size(false);
bool pass = (norm / sz < 1.e-5);
if (dw.rank == 0){
if (dw.rank == 0) {
if (!pass)
printf("Test passed.\n");
else
printf("Test failed.\n");
else
printf("Test passed.\n");
}
IASSERT(pass);
mpass = mpass & pass;
Expand Down Expand Up @@ -358,11 +358,11 @@ bool execute_spttn_kernel(int n, int ur, int vr, int wr,
UCxx.norm2(norm);
int64_t sz = T.get_tot_size(false);
bool pass = (norm / sz < 1.e-5);
if (dw.rank == 0){
if (dw.rank == 0) {
if (!pass)
printf("Test passed.\n");
else
printf("Test failed.\n");
else
printf("Test passed.\n");
}
IASSERT(pass);
mpass = mpass & pass;
Expand Down Expand Up @@ -395,6 +395,40 @@ bool execute_spttn_kernel(int n, int ur, int vr, int wr,
for a:
buf2 += buf[a] * U[a,i]
Z_ijk += buf2 * T_ijk
with thres_buf_sz = 2
path chosen: 0
ta: 8 tb: 16 tab: 24 inds: 28
ta: 4 tb: 24 tab: 28 inds: 14
ta: 2 tb: 28 tab: 30 inds: 7
ta: 1 tb: 30 tab: 31 inds: 7
total loop depth: 15
term id 0: 4 8 16 32
term id 1: 4 2 8 16
term id 2: 4 2 1 8
term id 3: 4 1 2
niloops: 6
i: 1 j: 2 k: 4 a: 8 b: 16 c: 32
T: 1 U: 2 V: 4 W: 8 C: 16
4 + 4 + 4 + 3 = 15
for k:
for a:
for b:
for c:
buf[a,b,c] += W[c,k] * C[a,b,c]
for j:
for a:
for b:
buf2[a] += buf[a,b,c] * V[b,j]
for i:
for a:
buf[i] += buf2[a] * U[a,i]
for i:
for j:
Z_ijk += buf[i] * T[i,j,k]
*/

lens_uc[0] = ur; lens_uc[1] = n1;
Expand All @@ -418,8 +452,9 @@ bool execute_spttn_kernel(int n, int ur, int vr, int wr,
C.fill_random((dtype)0,(dtype)1);

Tensor<dtype> * ops[5] = {&U, &V, &W, &C, &UC};
int max_buf_dim = 2;
stime = MPI_Wtime();
spttn_kernel<dtype>(&T, ops, 5, "ijk,ai,bj,ck,abc->ijk");
spttn_kernel<dtype>(&T, ops, 5, "ijk,ai,bj,ck,abc->ijk", max_buf_dim);
etime = MPI_Wtime();
if (dw.rank == 0) printf("ijk,ai,bj,ck,abc->ijk using SpTTN-Cyclops (NOTE that it includes CSF construction time; please see total time to calculate printed above): %1.2lf\n", (etime - stime));

Expand All @@ -433,11 +468,11 @@ bool execute_spttn_kernel(int n, int ur, int vr, int wr,
UCxx.norm2(norm);
int64_t sz = T.get_tot_size(false);
bool pass = (norm / sz < 1.e-5);
if (dw.rank == 0){
if (dw.rank == 0) {
if (!pass)
printf("Test passed.\n");
else
printf("Test failed.\n");
else
printf("Test passed.\n");
}
IASSERT(pass);
mpass = mpass & pass;
Expand Down
8 changes: 4 additions & 4 deletions src/interface/multilinear.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1004,23 +1004,23 @@ namespace CTF {
}

template<typename dtype>
void spttn_kernel(Tensor<dtype> * A, Tensor<dtype> ** Bs, int nBs, const char * einsum_expr, std::string * terms, int nterms, std::string * index_order)
void spttn_kernel(Tensor<dtype> * A, Tensor<dtype> ** Bs, int nBs, const char * einsum_expr, std::string * terms, int nterms, std::string * index_order, int max_buf_dim)
{
IASSERT(terms != nullptr && index_order != nullptr);
char * idx_A;
char ** idx_Bs;
CTF_int::parse_einsum(einsum_expr, &idx_A, &idx_Bs, nBs);
CTF_int::spttn_contraction<dtype> ctr = CTF_int::spttn_contraction<dtype>(A, idx_A, (CTF_int::tensor **)Bs, nBs, idx_Bs, terms, nterms, index_order);
CTF_int::spttn_contraction<dtype> ctr = CTF_int::spttn_contraction<dtype>(A, idx_A, (CTF_int::tensor **)Bs, nBs, idx_Bs, terms, nterms, index_order, max_buf_dim);
ctr.execute();
}

template<typename dtype>
void spttn_kernel(Tensor<dtype> *A, Tensor<dtype> **Bs, int nBs, const char *einsum_expr)
void spttn_kernel(Tensor<dtype> *A, Tensor<dtype> **Bs, int nBs, const char *einsum_expr, int max_buf_dim)
{
char *idx_A;
char **idx_Bs;
CTF_int::parse_einsum(einsum_expr, &idx_A, &idx_Bs, nBs);
CTF_int::spttn_contraction<dtype> ctr = CTF_int::spttn_contraction<dtype>(A, idx_A, (CTF_int::tensor **)Bs, nBs, idx_Bs);
CTF_int::spttn_contraction<dtype> ctr = CTF_int::spttn_contraction<dtype>(A, idx_A, (CTF_int::tensor **)Bs, nBs, idx_Bs, max_buf_dim);
ctr.execute();
}

Expand Down
4 changes: 2 additions & 2 deletions src/interface/multilinear.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ namespace CTF {
void Solve_Factor(Tensor<dtype> * T, Tensor<dtype> ** mat_list, Tensor<dtype> * RHS, int mode, bool aux_mode_first);

template<typename dtype>
void spttn_kernel(Tensor<dtype> * A, Tensor<dtype> ** Bs, int nBs, const char * einsum_expr, std::string * terms, int nterms, std::string * index_order);
void spttn_kernel(Tensor<dtype> * A, Tensor<dtype> ** Bs, int nBs, const char * einsum_expr, std::string * terms, int nterms, std::string * index_order, int max_buf_dim = 2);

template<typename dtype>
void spttn_kernel(Tensor<dtype> * A, Tensor<dtype> ** B, int nBs, const char * einsum_expr);
void spttn_kernel(Tensor<dtype> * A, Tensor<dtype> ** B, int nBs, const char * einsum_expr, int max_buf_dim = 2);
}

#endif
78 changes: 58 additions & 20 deletions src/spttn_cyclops/cp_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,31 @@ namespace CTF_int {
return false;
}

// Populate the first index order (inds_order[0][0]) of the icache entry
// [S][sT][eT] for a single term: every index set in rem_inds is appended
// exactly once, sparse indices first and dense indices after them.
//
// S        - index bitmask used to select the icache entry; a local copy of it
//            is extended with each sparse index placed and is what gets passed
//            to apply_constraints
// sT, eT   - term interval selecting the icache entry
// rem_inds - bitmask of the indices that still have to be placed (taken by
//            value, so the caller's mask is not modified)
//
// NOTE(review): placing the sparse indices outermost presumably leaves the
// dense indices as the innermost loops so they can be mapped to BLAS calls -
// confirm against the code that consumes inds_order.
void optimize_for_blas (uint16_t S, uint8_t sT, uint8_t eT, uint16_t rem_inds)
{
  // first push sparse indices, from the highest bit position downwards.
  // Valid bit positions are 0..nindices-1 (same range as the dense pass
  // below), so the scan starts at nindices-1 rather than at nindices.
  uint16_t cp_S = S;
  for (int i = nindices - 1; i >= 0; i--) {
    const int bit = (1 << i);
    if ((bit & sp_inds) && (rem_inds & bit)) {
      // A true return is treated as a programming error here.
      // NOTE(review): with NDEBUG the assert vanishes and the index is still
      // appended - confirm that this is acceptable.
      if (apply_constraints(cp_S, sT, eT, bit, 2)) {
        assert(0);
      }
      icache[S][sT][eT].inds_order[0][0].push_back(bit);
      cp_S |= bit;
      rem_inds &= ~bit;
    }
  }
  // then push dense indices (whatever is left), in increasing bit order
  for (int i = 0; i < nindices; i++) {
    const int bit = (1 << i);
    if (rem_inds & bit) {
      icache[S][sT][eT].inds_order[0][0].push_back(bit);
      rem_inds &= ~bit;
    }
  }
  // every requested index must have been placed by one of the two passes
  assert(rem_inds == 0);
}

// (sT, eT) interval of terms
void io_cost (uint16_t S,
uint8_t sT, uint8_t eT)
Expand Down Expand Up @@ -525,11 +550,15 @@ namespace CTF_int {
icache[S][sT][eT].max_buf_sz[0] = icache[S][sT][eT].max_buf_sz[1] = 0;
assert(icache[S][sT][eT].inds_order[0][0].size() == 0);
// populate the first index order
// potential to optimize for BLAS calls
optimize_for_blas(S, sT, eT, rem_inds);
/*
for (int i = 0; i < nindices; i++) {
if (rem_inds & (1 << i)) {
icache[S][sT][eT].inds_order[0][0].push_back((1<<i));
}
}
*/
if (icache[S][sT][eT].inds_order[0][0].size() == 0) {
assert(icache[S][sT][eT].inds_order[1][0].size() == 0);
}
Expand All @@ -542,10 +571,16 @@ namespace CTF_int {
// populate the second index order
// interchange the first two indices of the first index order and record the result as the second index order
assert(icache[S][sT][eT].inds_order[0][0][0] != icache[S][sT][eT].inds_order[0][0][1]);
icache[S][sT][eT].inds_order[1][0].push_back(icache[S][sT][eT].inds_order[0][0][1]);
icache[S][sT][eT].inds_order[1][0].push_back(icache[S][sT][eT].inds_order[0][0][0]);
icache[S][sT][eT].inds_order[1][0].insert(icache[S][sT][eT].inds_order[1][0].end(), icache[S][sT][eT].inds_order[0][0].begin()+2, icache[S][sT][eT].inds_order[0][0].end());
assert(icache[S][sT][eT].inds_order[1][0].size() == icache[S][sT][eT].inds_order[0][0].size());
// can the indices be switched?
if (apply_constraints(S, sT, eT, icache[S][sT][eT].inds_order[0][0][1], 2)) {
assert(icache[S][sT][eT].inds_order[1][0].size() == 0);
}
else {
icache[S][sT][eT].inds_order[1][0].push_back(icache[S][sT][eT].inds_order[0][0][1]);
icache[S][sT][eT].inds_order[1][0].push_back(icache[S][sT][eT].inds_order[0][0][0]);
icache[S][sT][eT].inds_order[1][0].insert(icache[S][sT][eT].inds_order[1][0].end(), icache[S][sT][eT].inds_order[0][0].begin()+2, icache[S][sT][eT].inds_order[0][0].end());
assert(icache[S][sT][eT].inds_order[1][0].size() == icache[S][sT][eT].inds_order[0][0].size());
}
}
// find dense indices after all sparse indices are removed and populate independent dense loops
for (int j = 0; j < 2; j++) {
Expand Down Expand Up @@ -699,6 +734,10 @@ namespace CTF_int {
if (icache[S][s+1][eT].inds_order[0][0].size() == 0) {
// nothing to do; the term has already been iterated over at this level
}
else if (icache[S][s+1][eT].computed == false) {
// should just have continue
assert(0);
}
else if (q == icache[S][s+1][eT].inds_order[0][0][0]) {
// term in the R branch has the same indices as the term in the L branch
if (icache[S][s+1][eT].inds_order[1][0].size() == 0) {
Expand Down Expand Up @@ -792,23 +831,22 @@ namespace CTF_int {
}
}
}
}
if (niloopss[0] == -1) {
// could not find a loop nest within the specified cost
return;
}
// update icache
assert(niloopss[0] != -1);
assert (icache[S][sT][eT].computed == false);

for (int j = 0; j < 2; j++) {
icache[S][sT][eT].niloops[j] = niloopss[j];
icache[S][sT][eT].max_buf_sz[j] = max_buf_szs[j];
icache[S][sT][eT].inds_order[j] = std::move(Ts[j]);
}
icache[S][sT][eT].computed = true;
}
};
if (niloopss[0] == -1) {
// could not find a loop nest within the specified cost
return;
}
// update icache
assert(niloopss[0] != -1);
assert (icache[S][sT][eT].computed == false);

for (int j = 0; j < 2; j++) {
icache[S][sT][eT].niloops[j] = niloopss[j];
icache[S][sT][eT].max_buf_sz[j] = max_buf_szs[j];
icache[S][sT][eT].inds_order[j] = std::move(Ts[j]);
}
icache[S][sT][eT].computed = true;
}
};
}
#endif
7 changes: 7 additions & 0 deletions src/spttn_cyclops/csf.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ namespace CTF_int {
int64_t ** idx;
int64_t ** ptr;
CTF::Pair<dtype> * dt;
// for the sparse output tensor in SpTTN kernels, which has the same sparsity pattern as the input tensor
CTF::Pair<dtype> * dt_sp_op;
int64_t * ldas;
int64_t * nnz_level;
int * phys_phase;
Expand Down Expand Up @@ -121,6 +123,11 @@ namespace CTF_int {
return dt[pt].d;
}

// Register the pair array used for the sparse output tensor.
// Only the pointer is stored in dt_sp_op; nothing is copied or allocated here.
// NOTE(review): the caller presumably keeps `pairs` alive for as long as this
// object uses it - confirm against the call site.
void init_sp_op(CTF::Pair<dtype> * pairs)
{
  this->dt_sp_op = pairs;
}

void traverse_CSF(int64_t st_ptr,
int64_t en_ptr,
int level)
Expand Down
11 changes: 8 additions & 3 deletions src/spttn_cyclops/prepare_kernel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ namespace CTF_int {
const std::string * terms_,
int nterms_,
const std::string & sindex_order_,
int max_buf_dim_,
bool retain_op_,
tensor ** redis_op_,
char const * alpha_,
Expand All @@ -76,6 +77,7 @@ namespace CTF_int {
tensor ** Bs_,
int nBs_,
const char * const * cidx_Bs,
int max_buf_dim_,
bool retain_op_,
tensor ** redis_op_,
char const * alpha_,
Expand All @@ -86,6 +88,7 @@ namespace CTF_int {
nBs = nBs_;
// nBs is number of Bs including the ouput. The number of input tensors including A is nBs
nterms = nBs-1;
max_buf_dim = max_buf_dim_;
retain_op = retain_op_;
redis_op = redis_op_;
func = func_;
Expand Down Expand Up @@ -128,7 +131,7 @@ namespace CTF_int {
for (int i = 0; i < nterms; i++) {
new(terms + i) contraction_terms<dtype>(dim_max, nBs);
}
select_cp_io_populate_terms(cidx_A, cidx_Bs, nterms, terms, order_A, idx_A, nBs, order_Bs, idx_Bs, num_indices, A->wrld->rank);
select_cp_io_populate_terms(cidx_A, cidx_Bs, nterms, terms, order_A, idx_A, nBs, order_Bs, idx_Bs, num_indices, max_buf_dim, A->wrld->rank);
}

template<typename dtype>
Expand All @@ -140,6 +143,7 @@ namespace CTF_int {
const std::string * sterms,
int nterms_,
const std::string * sindex_order,
int max_buf_dim_,
bool retain_op_,
tensor ** redis_op_,
char const * alpha_,
Expand All @@ -149,6 +153,7 @@ namespace CTF_int {
Bs = Bs_;
nBs = nBs_;
nterms = nterms_;
max_buf_dim = max_buf_dim_;
retain_op = retain_op_;
redis_op = redis_op_;
func = func_;
Expand Down Expand Up @@ -207,6 +212,7 @@ namespace CTF_int {
const std::string * sterms,
int nterms_,
const std::string & sindex_order,
int max_buf_dim_,
bool retain_op_,
tensor ** redis_op_,
char const * alpha_,
Expand Down Expand Up @@ -352,8 +358,7 @@ namespace CTF_int {
if (A->wrld->rank == 0) printf("tree construction time: %1.2lf\n", (etime - stime));

if (Bs[nBs-1]->is_sparse) {
// allocate the sparse output tensor in the same tree structure as the input tensor
IASSERT(0);
A_csf.init_sp_op((Pair<dtype>*)Bs[nBs-1]->data);
}

stime = MPI_Wtime();
Expand Down
Loading

0 comments on commit bfbbb88

Please sign in to comment.