diff --git a/src/spttn_cyclops/execute_kernel.cxx b/src/spttn_cyclops/execute_kernel.cxx index 5ab7ff21..51cb4985 100644 --- a/src/spttn_cyclops/execute_kernel.cxx +++ b/src/spttn_cyclops/execute_kernel.cxx @@ -177,6 +177,19 @@ namespace CTF_int { } } break; + case SPARSE_xAXPY_OP_NOT_BUFFER: { + double * dX = (double *)Bs[term.X]; + for (int64_t it = tree_pt_st; it < tree_pt_en; it++) { + double alpha = A_tree->dt[it].d; + int64_t idx_i = A_tree->idx[0][it]; + double * dY = (double *)((double *)Bs[term.Y] + lda_Bs[term.Y][idx] * idx_i); + #pragma omp simd + for (int64_t i = 0; i < term.N; i++){ + dY[i] += alpha * dX[i]; + } + } + } + break; case xAXPY: { double * dX = (double *)Bs[term.X]; double * dY = (double *)Bs[term.Y]; @@ -464,6 +477,9 @@ namespace CTF_int { for (int i = 0; i < nterms; i++) { if (terms[i].tbuffer_order == -1) continue; lda_Bs[nBs+i] = (int64_t *) CTF_int::alloc(sizeof(int64_t) * num_indices); + for (int j = 0; j < num_indices; j++) { + lda_Bs[nBs+i][j] = 0; // a case where this is queried and the index is not in the buffer; for example, when SPARSE_xAXPY is called with a buffer + } //(i,a): i is the fastest moving index lda_Bs[nBs+i][terms[i].idx_tbuffer[0]] = 1; for (int j = 1; j < terms[i].tbuffer_order; j++) { @@ -513,6 +529,9 @@ namespace CTF_int { for (int i = 0; i < nterms; i++) { if (terms[i].tbuffer_order == -1) continue; lda_Bs[nBs+i] = (int64_t *) CTF_int::alloc(sizeof(int64_t) * num_indices); + for (int j = 0; j < num_indices; j++) { + lda_Bs[nBs+i][j] = 0; // a case where this is queried and the index is not in the buffer; for example, when SPARSE_xAXPY is called with a buffer + } //(i,a): i is the fastest moving index lda_Bs[nBs+i][terms[i].idx_tbuffer[0]] = 1; for (int j = 1; j < terms[i].tbuffer_order; j++) { diff --git a/src/spttn_cyclops/prepare_kernel.h b/src/spttn_cyclops/prepare_kernel.h index b9c74b6a..4dcbb66a 100644 --- a/src/spttn_cyclops/prepare_kernel.h +++ b/src/spttn_cyclops/prepare_kernel.h @@ -144,7 +144,7 @@ namespace CTF_int { }; enum BLAS_OPERANDS {MAIN_TENSOR, PREV_TERM_BUF, CURRENT_TERM_BUF, INP_B, INTERMEDIATE_TENSOR}; - enum BREAK_REC {NOT_SET, SCALAR, SPARSE, SPARSE_xAXPY, DENSE, DENSE_xAXPY_2D, DENSE_xAXPY_3D, DENSE_STRIDED, DENSE_3D, DENSE_3D_TO_xAXPY, xAXPY, xVEC_MUL, xGER, xGER_TO_xAXPY, xGEMV, RECURSIVE_LOOP}; + enum BREAK_REC {NOT_SET, SCALAR, SPARSE, SPARSE_xAXPY, SPARSE_xAXPY_OP_NOT_BUFFER, DENSE, DENSE_xAXPY_2D, DENSE_xAXPY_3D, DENSE_STRIDED, DENSE_3D, DENSE_3D_TO_xAXPY, xAXPY, xVEC_MUL, xGER, xGER_TO_xAXPY, xGEMV, RECURSIVE_LOOP}; template const char* enumToStr(BREAK_REC br) { @@ -153,6 +153,7 @@ namespace CTF_int { case SCALAR: return "SCALAR"; case SPARSE: return "SPARSE"; case SPARSE_xAXPY: return "SPARSE_xAXPY"; + case SPARSE_xAXPY_OP_NOT_BUFFER: return "SPARSE_xAXPY_NOT_BUFFER"; case DENSE: return "DENSE"; case DENSE_xAXPY_2D: return "DENSE_xAXPY_2D"; case DENSE_xAXPY_3D: return "DENSE_xAXPY_3D"; @@ -237,12 +238,15 @@ namespace CTF_int { } if (term.X == -1) { IASSERT(term.Bs_in_term[nBs] == true); - // interchange + // interchange: takes care of the case where the sparse tensor is contracted not in the first term term.X = nBs + term.inp_buf_id; term.ALPHA = -1; } if (i == (nterms - 1)) term.Y = nBs - 1; else term.Y = nBs + i; + if (i == (nterms-1)) { + std::cout << "term.X: " << term.X << " term.ALPHA: " << term.ALPHA << " term.Y: " << term.Y << std::endl; + } } } break; @@ -378,18 +382,30 @@ namespace CTF_int { break; } } - IASSERT(i != (nterms-1)); term.ALPHA = -1; - term.X = term.blas_B_ids[1]; - int idx = term.idx_tbuffer[0]; - term.Y = i + nBs; - if (idx_Bs[term.X][0] != term.idx_tbuffer[0]) { - term.blas_kernel = RECURSIVE_LOOP; - i--; - break; + int idx = -1; + if (i != (nterms-1)) { + term.X = term.blas_B_ids[1]; // an input tensor + term.Y = nBs + i; // buffer + idx = term.idx_tbuffer[0]; + if (idx_Bs[term.X][0] != term.idx_tbuffer[0]) { + term.blas_kernel = RECURSIVE_LOOP; + i--; + break; + } + } + else { + term.blas_kernel = SPARSE_xAXPY_OP_NOT_BUFFER; + term.X = nBs + term.inp_buf_id; // an intermediate tensor (one of the other input in the pairwise contraction is the main tensor) + term.Y = nBs - 1; + idx = idx_Bs[term.Y][0]; + if (idx != terms[term.inp_buf_id].idx_tbuffer[0]) { + term.blas_kernel = RECURSIVE_LOOP; + i--; + break; + } } - // assert failure: strided access not supported; INCX == 1 and INCY == 1 - IASSERT(idx_Bs[term.X][0] == term.idx_tbuffer[0]); + IASSERT(idx != -1); term.INCX = lda_Bs[term.X][idx]; term.INCY = lda_Bs[term.Y][idx]; IASSERT(term.INCX == 1 && term.INCY == 1);