diff --git a/src/spttn_cyclops/execute_kernel.cxx b/src/spttn_cyclops/execute_kernel.cxx index a98349a3..31734382 100644 --- a/src/spttn_cyclops/execute_kernel.cxx +++ b/src/spttn_cyclops/execute_kernel.cxx @@ -312,8 +312,14 @@ namespace CTF_int { double * dY = (double *)Bs[term.Y]; double * dA = (double *)Bs[term.A]; double BETA = 1.; - const char TRANS = 'N'; - CTF_BLAS::DGEMV(&TRANS, &term.M, &term.N, &term.ALPHA, dA, &term.LDA, dX, &term.INCX, &BETA, dY, &term.INCY); + CTF_BLAS::DGEMV(&term.TRANS, &term.M, &term.N, &term.ALPHA, dA, &term.LDA, dX, &term.INCX, &BETA, dY, &term.INCY); + } + break; + case xDOT: { + double * dX = (double *)Bs[term.X]; + double * dY = (double *)Bs[term.Y]; + double * dot = (double *)Bs[(int)term.ALPHA]; + *dot = CTF_BLAS::DDOT(&term.N, dX, &term.INCX, dY, &term.INCY); } break; default: { diff --git a/src/spttn_cyclops/execute_kernel.h b/src/spttn_cyclops/execute_kernel.h index 795c565c..1ac6e2ec 100644 --- a/src/spttn_cyclops/execute_kernel.h +++ b/src/spttn_cyclops/execute_kernel.h @@ -50,6 +50,7 @@ namespace CTF_int{ double BETA; int M; int N; + char TRANS; bool * dense_sp_loop; int dense_sp_loop_in_term; diff --git a/src/spttn_cyclops/prepare_kernel.h b/src/spttn_cyclops/prepare_kernel.h index ef6d2b50..a2aa7eab 100644 --- a/src/spttn_cyclops/prepare_kernel.h +++ b/src/spttn_cyclops/prepare_kernel.h @@ -145,7 +145,7 @@ namespace CTF_int { }; enum BLAS_OPERANDS {MAIN_TENSOR, PREV_TERM_BUF, CURRENT_TERM_BUF, INP_B, INTERMEDIATE_TENSOR}; - enum BREAK_REC {NOT_SET, SCALAR, SPARSE, SPARSE_xAXPY, SPARSE_xAXPY_OP_NOT_BUFFER, DENSE, DENSE_xAXPY_2D, DENSE_xAXPY_3D, DENSE_STRIDED, DENSE_3D, DENSE_3D_TO_xAXPY, xAXPY, xVEC_MUL, xGER, xGER_TO_xAXPY, xGEMV, RECURSIVE_LOOP}; + enum BREAK_REC {NOT_SET, SCALAR, SPARSE, SPARSE_xAXPY, SPARSE_xAXPY_OP_NOT_BUFFER, DENSE, DENSE_xAXPY_2D, DENSE_xAXPY_3D, DENSE_STRIDED, DENSE_3D, DENSE_3D_TO_xAXPY, xAXPY, xVEC_MUL, xGER, xDOT, xGER_TO_xAXPY, xGEMV, RECURSIVE_LOOP}; template const char* enumToStr(BREAK_REC br) { @@ -163,6 +163,7 @@ namespace CTF_int { case xAXPY: return "xAXPY"; case xVEC_MUL: return "xVEC_MUL"; case xGER: return "xGER"; + case xDOT: return "xDOT"; case xGEMV: return "xGEMV"; case RECURSIVE_LOOP: return "RECURSIVE_LOOP"; default: return "Unknown value"; @@ -432,6 +433,10 @@ namespace CTF_int { term.INCY = lda_Bs[term.Y][idx]; IASSERT(term.INCX == 1 && term.INCY == 1); term.N = len_idx[idx]; + if (term.Y == nBs - 1) { + // the output has a sparse index in this case + term.blas_kernel = SPARSE_xAXPY_OP_NOT_BUFFER; + } /* if (i != (nterms-1)) { term.X = term.blas_B_ids[1]; // an input tensor @@ -688,15 +693,29 @@ namespace CTF_int { idx_Y = term.Y >= nBs ? terms[term.Y-nBs].idx_tbuffer[0] : idx_Bs[term.Y][0]; // input int idx_X1 = term.blas_B_ids[0] >= nBs ? terms[term.blas_B_ids[0]-nBs].idx_tbuffer[0] : idx_Bs[term.blas_B_ids[0]][0]; + int idx_Y1 = term.blas_B_ids[0] >= nBs ? terms[term.blas_B_ids[0]-nBs].idx_tbuffer[1] : idx_Bs[term.blas_B_ids[0]][1]; int idx_X2 = term.blas_B_ids[1] >= nBs ? terms[term.blas_B_ids[1]-nBs].idx_tbuffer[0] : idx_Bs[term.blas_B_ids[1]][0]; - if (idx_X1 == idx_Y) { + int idx_Y2 = term.blas_B_ids[1] >= nBs ? terms[term.blas_B_ids[1]-nBs].idx_tbuffer[1] : idx_Bs[term.blas_B_ids[1]][1]; + if (idx_X1 == idx_Y || idx_Y1 == idx_Y) { // first input tensor is A term.A = term.blas_B_ids[0]; - int idx_Y1 = term.blas_B_ids[0] >= nBs ? terms[term.blas_B_ids[0]-nBs].idx_tbuffer[1] : idx_Bs[term.blas_B_ids[0]][1]; - if (idx_Y1 != idx_X2) { - term.blas_kernel = RECURSIVE_LOOP; - i--; - break; + if (idx_X1 == idx_Y) { + if (idx_Y1 != idx_X2) { + term.blas_kernel = RECURSIVE_LOOP; + i--; + break; + } + // no transpose + term.TRANS = 'N'; + } + if (idx_Y1 == idx_Y) { + if (idx_X1 != idx_X2) { + term.blas_kernel = RECURSIVE_LOOP; + i--; + break; + } + // transpose A + term.TRANS = 'T'; } term.M = len_idx[idx_X1]; term.N = len_idx[idx_Y1]; @@ -706,11 +725,23 @@ namespace CTF_int { else { // second input tensor is A term.A = term.blas_B_ids[1]; - int idx_Y2 = term.blas_B_ids[1] >= nBs ? terms[term.blas_B_ids[1]-nBs].idx_tbuffer[1] : idx_Bs[term.blas_B_ids[1]][1]; - if (idx_Y2 != idx_X1) { - term.blas_kernel = RECURSIVE_LOOP; - i--; - break; + if (idx_X2 == idx_Y) { + if (idx_Y2 != idx_X1) { + term.blas_kernel = RECURSIVE_LOOP; + i--; + break; + } + // no transpose + term.TRANS = 'N'; + } + if (idx_Y2 == idx_Y) { + if (idx_X2 != idx_X1) { + term.blas_kernel = RECURSIVE_LOOP; + i--; + break; + } + // transpose A + term.TRANS = 'T'; } term.M = len_idx[idx_X2]; term.N = len_idx[idx_Y2]; @@ -721,6 +752,16 @@ namespace CTF_int { term.INCY = lda_Bs[term.Y][idx_Y]; } break; + case xDOT: { + // dot <- x^T * y + term.ALPHA = term.blas_B_ids[2]; + term.X = term.blas_B_ids[0]; + term.Y = term.blas_B_ids[1]; + term.INCX = lda_Bs[term.X][term.blas_idx]; + term.INCY = lda_Bs[term.Y][term.blas_idx]; + term.N = len_idx[term.blas_idx]; + } + break; default: break; } @@ -750,7 +791,7 @@ namespace CTF_int { const int rank) { debug_spttn_cyclops spttn_print; - IASSERT(nidx_term[2] > 0); + // IASSERT(nidx_term[2] > 0); contraction_terms & term = terms[term_id]; bool recursive_loop = false; if (recursive_loop == true) { @@ -765,6 +806,7 @@ namespace CTF_int { if (term.sparse_idx != -1) { term.blas_idx = term.sparse_idx; term.blas_kernel = SPARSE_xAXPY; + // term.blas_kernel = RECURSIVE_LOOP; if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "SPARSE_xAXPY" << std::endl; } else { @@ -772,7 +814,7 @@ namespace CTF_int { if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "xAXPY" << std::endl; } } - else if (nidx_term[0] == 1 && nidx_term[1] == 1) { + else if (nidx_term[0] == 1 && nidx_term[1] == 1 && nidx_term[2] == 1) { int idx_X, idx_Y; idx_X = term.blas_B_ids[0] >= nBs ? terms[term.blas_B_ids[0]-nBs].idx_tbuffer[0] : idx_Bs[term.blas_B_ids[0]][0]; idx_Y = term.blas_B_ids[1] >= nBs ? terms[term.blas_B_ids[1]-nBs].idx_tbuffer[0] : idx_Bs[term.blas_B_ids[1]][0]; @@ -793,6 +835,10 @@ namespace CTF_int { term.blas_kernel = xVEC_MUL; if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "xVEC_MUL" << std::endl; } + else if (nidx_term[0] == 1 && nidx_term[1] == 1 && nidx_term[2] == 0) { + term.blas_kernel = xDOT; + if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "xDOT" << std::endl; + } else { // TODO: ttmc_o3_allm rkji rskj trks // FIXME: should have been handled in process_inner_ids @@ -810,7 +856,6 @@ namespace CTF_int { if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "RECURSIVE_LOOP" << std::endl; } else if ((nidx_term[0] == 1 && nidx_term[1] == 2 && nidx_term[2] == 1) || (nidx_term[0] == 2 && nidx_term[1] == 1 && nidx_term[2] == 1)) { - // TODO: tucker_solve TTTP term 1 a <- abc bj term.blas_kernel = xGEMV; if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "xGEMV" << std::endl; } @@ -840,7 +885,6 @@ namespace CTF_int { } else if (((nidx_term[0] == 2 && nidx_term[1] == 3) || (nidx_term[0] == 3 && nidx_term[1] == 2)) && nidx_term[2] == 1) { term.blas_kernel = DENSE_3D; - // term.blas_kernel = RECURSIVE_LOOP; if (rank == 0) spttn_print << "term_id: " << term_id << " blas_kernel: " << "DENSE_3D" << std::endl; } else { @@ -1127,22 +1171,22 @@ namespace CTF_int { break; } } + /* if (nidx_term[2] == 0) { spttn_print << "term_id: " << i << " nidx_term[0]: " << nidx_term[0] << " nidx_term[1]: " << nidx_term[1] << " nidx_term[2]: " << nidx_term[2] << " inner_idx: " << terms[i].inner_idx << " reset_idx: " << terms[i].reset_idx << std::endl; IASSERT(terms[i].inner_idx != -1); - /* for i: for a: buf2 += buf[a] * U[a,i] Z_ijk += buf2 * T_ijk the first term has inner_idx of 1, but the output is a buffer that needs to be accumulated, so reset the buffer at 'a' in this case - */ IASSERT(terms[i].reset_idx != -1); // treat this as a RECURSIVE_LOOP and not as SCALAR terms[i].blas_kernel = RECURSIVE_LOOP; spttn_print << "term_id: " << i << " blas_kernel: " << "RECURSIVE_LOOP" << std::endl; continue; } + */ select_blas_kernel(i, in_term_idx, nidx_term, num_idx, len_idx, terms, num_indices, idx_Bs, nBs, rank); } else {