Skip to content

Commit

Permalink
using ddot and prepare xgemv fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
raghavendrak committed Jul 23, 2024
1 parent e5e0db8 commit 93d853e
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 20 deletions.
10 changes: 8 additions & 2 deletions src/spttn_cyclops/execute_kernel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,14 @@ namespace CTF_int {
double * dY = (double *)Bs[term.Y];
double * dA = (double *)Bs[term.A];
double BETA = 1.;
const char TRANS = 'N';
CTF_BLAS::DGEMV(&TRANS, &term.M, &term.N, &term.ALPHA, dA, &term.LDA, dX, &term.INCX, &BETA, dY, &term.INCY);
CTF_BLAS::DGEMV(&term.TRANS, &term.M, &term.N, &term.ALPHA, dA, &term.LDA, dX, &term.INCX, &BETA, dY, &term.INCY);
}
break;
case xDOT: {
double * dX = (double *)Bs[term.X];
double * dY = (double *)Bs[term.Y];
double * dot = (double *)Bs[(int)term.ALPHA];
*dot = CTF_BLAS::DDOT(&term.N, dX, &term.INCX, dY, &term.INCY);
}
break;
default: {
Expand Down
1 change: 1 addition & 0 deletions src/spttn_cyclops/execute_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ namespace CTF_int{
double BETA;
int M;
int N;
char TRANS;

bool * dense_sp_loop;
int dense_sp_loop_in_term;
Expand Down
80 changes: 62 additions & 18 deletions src/spttn_cyclops/prepare_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ namespace CTF_int {
};

enum BLAS_OPERANDS {MAIN_TENSOR, PREV_TERM_BUF, CURRENT_TERM_BUF, INP_B, INTERMEDIATE_TENSOR};
enum BREAK_REC {NOT_SET, SCALAR, SPARSE, SPARSE_xAXPY, SPARSE_xAXPY_OP_NOT_BUFFER, DENSE, DENSE_xAXPY_2D, DENSE_xAXPY_3D, DENSE_STRIDED, DENSE_3D, DENSE_3D_TO_xAXPY, xAXPY, xVEC_MUL, xGER, xGER_TO_xAXPY, xGEMV, RECURSIVE_LOOP};
enum BREAK_REC {NOT_SET, SCALAR, SPARSE, SPARSE_xAXPY, SPARSE_xAXPY_OP_NOT_BUFFER, DENSE, DENSE_xAXPY_2D, DENSE_xAXPY_3D, DENSE_STRIDED, DENSE_3D, DENSE_3D_TO_xAXPY, xAXPY, xVEC_MUL, xGER, xDOT, xGER_TO_xAXPY, xGEMV, RECURSIVE_LOOP};

template <typename dtype>
const char* enumToStr(BREAK_REC br) {
Expand All @@ -163,6 +163,7 @@ namespace CTF_int {
case xAXPY: return "xAXPY";
case xVEC_MUL: return "xVEC_MUL";
case xGER: return "xGER";
case xDOT: return "xDOT";
case xGEMV: return "xGEMV";
case RECURSIVE_LOOP: return "RECURSIVE_LOOP";
default: return "Unknown value";
Expand Down Expand Up @@ -432,6 +433,10 @@ namespace CTF_int {
term.INCY = lda_Bs[term.Y][idx];
IASSERT(term.INCX == 1 && term.INCY == 1);
term.N = len_idx[idx];
if (term.Y == nBs - 1) {
// the output has a sparse index in this case
term.blas_kernel = SPARSE_xAXPY_OP_NOT_BUFFER;
}
/*
if (i != (nterms-1)) {
term.X = term.blas_B_ids[1]; // an input tensor
Expand Down Expand Up @@ -688,15 +693,29 @@ namespace CTF_int {
idx_Y = term.Y >= nBs ? terms[term.Y-nBs].idx_tbuffer[0] : idx_Bs[term.Y][0];
// input
int idx_X1 = term.blas_B_ids[0] >= nBs ? terms[term.blas_B_ids[0]-nBs].idx_tbuffer[0] : idx_Bs[term.blas_B_ids[0]][0];
int idx_Y1 = term.blas_B_ids[0] >= nBs ? terms[term.blas_B_ids[0]-nBs].idx_tbuffer[1] : idx_Bs[term.blas_B_ids[0]][1];
int idx_X2 = term.blas_B_ids[1] >= nBs ? terms[term.blas_B_ids[1]-nBs].idx_tbuffer[0] : idx_Bs[term.blas_B_ids[1]][0];
if (idx_X1 == idx_Y) {
int idx_Y2 = term.blas_B_ids[1] >= nBs ? terms[term.blas_B_ids[1]-nBs].idx_tbuffer[1] : idx_Bs[term.blas_B_ids[1]][1];
if (idx_X1 == idx_Y || idx_Y1 == idx_Y) {
// first input tensor is A
term.A = term.blas_B_ids[0];
int idx_Y1 = term.blas_B_ids[0] >= nBs ? terms[term.blas_B_ids[0]-nBs].idx_tbuffer[1] : idx_Bs[term.blas_B_ids[0]][1];
if (idx_Y1 != idx_X2) {
term.blas_kernel = RECURSIVE_LOOP;
i--;
break;
if (idx_X1 == idx_Y) {
if (idx_Y1 != idx_X2) {
term.blas_kernel = RECURSIVE_LOOP;
i--;
break;
}
// no transpose
term.TRANS = 'N';
}
if (idx_Y1 == idx_Y) {
if (idx_X1 != idx_X2) {
term.blas_kernel = RECURSIVE_LOOP;
i--;
break;
}
// transpose A
term.TRANS = 'T';
}
term.M = len_idx[idx_X1];
term.N = len_idx[idx_Y1];
Expand All @@ -706,11 +725,23 @@ namespace CTF_int {
else {
// second input tensor is A
term.A = term.blas_B_ids[1];
int idx_Y2 = term.blas_B_ids[1] >= nBs ? terms[term.blas_B_ids[1]-nBs].idx_tbuffer[1] : idx_Bs[term.blas_B_ids[1]][1];
if (idx_Y2 != idx_X1) {
term.blas_kernel = RECURSIVE_LOOP;
i--;
break;
if (idx_X2 == idx_Y) {
if (idx_Y2 != idx_X1) {
term.blas_kernel = RECURSIVE_LOOP;
i--;
break;
}
// no transpose
term.TRANS = 'N';
}
if (idx_Y2 == idx_Y) {
if (idx_X2 != idx_X1) {
term.blas_kernel = RECURSIVE_LOOP;
i--;
break;
}
// transpose A
term.TRANS = 'T';
}
term.M = len_idx[idx_X2];
term.N = len_idx[idx_Y2];
Expand All @@ -721,6 +752,16 @@ namespace CTF_int {
term.INCY = lda_Bs[term.Y][idx_Y];
}
break;
case xDOT: {
// dot <- x^T * y
term.ALPHA = term.blas_B_ids[2];
term.X = term.blas_B_ids[0];
term.Y = term.blas_B_ids[1];
term.INCX = lda_Bs[term.X][term.blas_idx];
term.INCY = lda_Bs[term.Y][term.blas_idx];
term.N = len_idx[term.blas_idx];
}
break;
default:
break;
}
Expand Down Expand Up @@ -750,7 +791,7 @@ namespace CTF_int {
const int rank)
{
debug_spttn_cyclops spttn_print;
IASSERT(nidx_term[2] > 0);
// IASSERT(nidx_term[2] > 0);
contraction_terms<dtype> & term = terms[term_id];
bool recursive_loop = false;
if (recursive_loop == true) {
Expand All @@ -765,14 +806,15 @@ namespace CTF_int {
if (term.sparse_idx != -1) {
term.blas_idx = term.sparse_idx;
term.blas_kernel = SPARSE_xAXPY;
// term.blas_kernel = RECURSIVE_LOOP;
if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "SPARSE_xAXPY" << std::endl;
}
else {
term.blas_kernel = xAXPY;
if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "xAXPY" << std::endl;
}
}
else if (nidx_term[0] == 1 && nidx_term[1] == 1) {
else if (nidx_term[0] == 1 && nidx_term[1] == 1 && nidx_term[2] == 1) {
int idx_X, idx_Y;
idx_X = term.blas_B_ids[0] >= nBs ? terms[term.blas_B_ids[0]-nBs].idx_tbuffer[0] : idx_Bs[term.blas_B_ids[0]][0];
idx_Y = term.blas_B_ids[1] >= nBs ? terms[term.blas_B_ids[1]-nBs].idx_tbuffer[0] : idx_Bs[term.blas_B_ids[1]][0];
Expand All @@ -793,6 +835,10 @@ namespace CTF_int {
term.blas_kernel = xVEC_MUL;
if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "xVEC_MUL" << std::endl;
}
else if (nidx_term[0] == 1 && nidx_term[1] == 1 && nidx_term[2] == 0) {
term.blas_kernel = xDOT;
if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "xDOT" << std::endl;
}
else {
// TODO: ttmc_o3_allm rkji rskj trks
// FIXME: should have been handled in process_inner_ids
Expand All @@ -810,7 +856,6 @@ namespace CTF_int {
if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "RECURSIVE_LOOP" << std::endl;
}
else if ((nidx_term[0] == 1 && nidx_term[1] == 2 && nidx_term[2] == 1) || (nidx_term[0] == 2 && nidx_term[1] == 1 && nidx_term[2] == 1)) {
// TODO: tucker_solve TTTP term 1 a <- abc bj
term.blas_kernel = xGEMV;
if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "xGEMV" << std::endl;
}
Expand Down Expand Up @@ -840,7 +885,6 @@ namespace CTF_int {
}
else if (((nidx_term[0] == 2 && nidx_term[1] == 3) || (nidx_term[0] == 3 && nidx_term[1] == 2)) && nidx_term[2] == 1) {
term.blas_kernel = DENSE_3D;
// term.blas_kernel = RECURSIVE_LOOP;
if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "DENSE_3D" << std::endl;
}
else {
Expand Down Expand Up @@ -1127,22 +1171,22 @@ namespace CTF_int {
break;
}
}
/*
if (nidx_term[2] == 0) {
spttn_print << "term_id: " << i << " nidx_term[0]: " << nidx_term[0] << " nidx_term[1]: " << nidx_term[1] << " nidx_term[2]: " << nidx_term[2] << " inner_idx: " << terms[i].inner_idx << " reset_idx: " << terms[i].reset_idx << std::endl;
IASSERT(terms[i].inner_idx != -1);
/*
for i:
for a:
buf2 += buf[a] * U[a,i]
Z_ijk += buf2 * T_ijk
the first term has inner_idx of 1, but the output is a buffer that needs to be accumulated, so reset the buffer at 'a' in this case
*/
IASSERT(terms[i].reset_idx != -1);
// treat this as a RECURSIVE_LOOP and not as SCALAR
terms[i].blas_kernel = RECURSIVE_LOOP;
spttn_print << "term_id: " << i << " blas_kernel: " << "RECURSIVE_LOOP" << std::endl;
continue;
}
*/
select_blas_kernel<dtype>(i, in_term_idx, nidx_term, num_idx, len_idx, terms, num_indices, idx_Bs, nBs, rank);
}
else {
Expand Down

0 comments on commit 93d853e

Please sign in to comment.