Skip to content

Commit

Permalink
handle a case where SPARSE_xAXPY writes to an input tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
raghavendrak committed Jul 12, 2024
1 parent 1892ba7 commit 0a39bfc
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
19 changes: 19 additions & 0 deletions src/spttn_cyclops/execute_kernel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,19 @@ namespace CTF_int {
}
}
break;
case SPARSE_xAXPY_OP_NOT_BUFFER: {
double * dX = (double *)Bs[term.X];
for (int64_t it = tree_pt_st; it < tree_pt_en; it++) {
double alpha = A_tree->dt[it].d;
int64_t idx_i = A_tree->idx[0][it];
double * dY = (double *)((double *)Bs[term.Y] + lda_Bs[term.Y][idx] * idx_i);
#pragma omp simd
for (int64_t i = 0; i < term.N; i++){
dY[i] += alpha * dX[i];
}
}
}
break;
case xAXPY: {
double * dX = (double *)Bs[term.X];
double * dY = (double *)Bs[term.Y];
Expand Down Expand Up @@ -464,6 +477,9 @@ namespace CTF_int {
for (int i = 0; i < nterms; i++) {
if (terms[i].tbuffer_order == -1) continue;
lda_Bs[nBs+i] = (int64_t *) CTF_int::alloc(sizeof(int64_t) * num_indices);
for (int j = 0; j < num_indices; j++) {
lda_Bs[nBs+i][j] = 0; // a case where this is queried and the index is not in the buffer; for example, when SPARSE_xAXPY is called with a buffer
}
//(i,a): i is the fastest moving index
lda_Bs[nBs+i][terms[i].idx_tbuffer[0]] = 1;
for (int j = 1; j < terms[i].tbuffer_order; j++) {
Expand Down Expand Up @@ -513,6 +529,9 @@ namespace CTF_int {
for (int i = 0; i < nterms; i++) {
if (terms[i].tbuffer_order == -1) continue;
lda_Bs[nBs+i] = (int64_t *) CTF_int::alloc(sizeof(int64_t) * num_indices);
for (int j = 0; j < num_indices; j++) {
lda_Bs[nBs+i][j] = 0; // a case where this is queried and the index is not in the buffer; for example, when SPARSE_xAXPY is called with a buffer
}
//(i,a): i is the fastest moving index
lda_Bs[nBs+i][terms[i].idx_tbuffer[0]] = 1;
for (int j = 1; j < terms[i].tbuffer_order; j++) {
Expand Down
40 changes: 28 additions & 12 deletions src/spttn_cyclops/prepare_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ namespace CTF_int {
};

enum BLAS_OPERANDS {MAIN_TENSOR, PREV_TERM_BUF, CURRENT_TERM_BUF, INP_B, INTERMEDIATE_TENSOR};
enum BREAK_REC {NOT_SET, SCALAR, SPARSE, SPARSE_xAXPY, DENSE, DENSE_xAXPY_2D, DENSE_xAXPY_3D, DENSE_STRIDED, DENSE_3D, DENSE_3D_TO_xAXPY, xAXPY, xVEC_MUL, xGER, xGER_TO_xAXPY, xGEMV, RECURSIVE_LOOP};
enum BREAK_REC {NOT_SET, SCALAR, SPARSE, SPARSE_xAXPY, SPARSE_xAXPY_OP_NOT_BUFFER, DENSE, DENSE_xAXPY_2D, DENSE_xAXPY_3D, DENSE_STRIDED, DENSE_3D, DENSE_3D_TO_xAXPY, xAXPY, xVEC_MUL, xGER, xGER_TO_xAXPY, xGEMV, RECURSIVE_LOOP};

template <typename dtype>
const char* enumToStr(BREAK_REC br) {
Expand All @@ -153,6 +153,7 @@ namespace CTF_int {
case SCALAR: return "SCALAR";
case SPARSE: return "SPARSE";
case SPARSE_xAXPY: return "SPARSE_xAXPY";
case SPARSE_xAXPY_OP_NOT_BUFFER: return "SPARSE_xAXPY_NOT_BUFFER";
case DENSE: return "DENSE";
case DENSE_xAXPY_2D: return "DENSE_xAXPY_2D";
case DENSE_xAXPY_3D: return "DENSE_xAXPY_3D";
Expand Down Expand Up @@ -237,12 +238,15 @@ namespace CTF_int {
}
if (term.X == -1) {
IASSERT(term.Bs_in_term[nBs] == true);
// interchange
// interchange: takes care of the case where the sparse tensor is contracted not in the first term
term.X = nBs + term.inp_buf_id;
term.ALPHA = -1;
}
if (i == (nterms - 1)) term.Y = nBs - 1;
else term.Y = nBs + i;
if (i == (nterms-1)) {
std::cout << "term.X: " << term.X << " term.ALPHA: " << term.ALPHA << " term.Y: " << term.Y << std::endl;
}
}
}
break;
Expand Down Expand Up @@ -378,18 +382,30 @@ namespace CTF_int {
break;
}
}
IASSERT(i != (nterms-1));
term.ALPHA = -1;
term.X = term.blas_B_ids[1];
int idx = term.idx_tbuffer[0];
term.Y = i + nBs;
if (idx_Bs[term.X][0] != term.idx_tbuffer[0]) {
term.blas_kernel = RECURSIVE_LOOP;
i--;
break;
int idx = -1;
if (i != (nterms-1)) {
term.X = term.blas_B_ids[1]; // an input tensor
term.Y = nBs + i; // buffer
idx = term.idx_tbuffer[0];
if (idx_Bs[term.X][0] != term.idx_tbuffer[0]) {
term.blas_kernel = RECURSIVE_LOOP;
i--;
break;
}
}
else {
term.blas_kernel = SPARSE_xAXPY_OP_NOT_BUFFER;
term.X = nBs + term.inp_buf_id; // an intermediate tensor (one of the other input in the pairwise contraction is the main tensor)
term.Y = nBs - 1;
idx = idx_Bs[term.Y][0];
if (idx != terms[term.inp_buf_id].idx_tbuffer[0]) {
term.blas_kernel = RECURSIVE_LOOP;
i--;
break;
}
}
// assert failure: strided access not supported; INCX == 1 and INCY == 1
IASSERT(idx_Bs[term.X][0] == term.idx_tbuffer[0]);
IASSERT(idx != -1);
term.INCX = lda_Bs[term.X][idx];
term.INCY = lda_Bs[term.Y][idx];
IASSERT(term.INCX == 1 && term.INCY == 1);
Expand Down

0 comments on commit 0a39bfc

Please sign in to comment.