Skip to content

Commit

Permalink
template blas calls
Browse files Browse the repository at this point in the history
  • Loading branch information
raghavendrak committed Jul 23, 2024
1 parent 93d853e commit b6f8d5b
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 25 deletions.
68 changes: 68 additions & 0 deletions src/shared/blas_symbs.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,74 @@ namespace CTF_BLAS {
INST_GEMM(std::complex<double>,Z)
#undef INST_GEMM

/**
 * \brief primary gemv template, used only when dtype has no specialization
 *
 * The specializations stamped out by INST_GEMV right after this definition
 * cover float, double, std::complex<float> and std::complex<double>; any
 * other dtype lands here. No computation is performed: an error message is
 * printed and both the project ASSERT and the standard assert are tripped.
 *
 * All arguments are accepted only so the signature matches the
 * specializations; none of them is read.
 */
template <typename dtype>
void gemv(const char *  TRANS,
          const int *   M,     const int *   N,
          const dtype * ALPHA,
          dtype *       A,     const int *   LDA,
          dtype *       X,     const int *   INCX,
          const dtype * BETA,
          dtype *       Y,     const int *   INCY){
  // report first, then fail through both assertion mechanisms
  printf("CTF ERROR GEMV not available for this type.\n");
  ASSERT(0);
  assert(0);
}
// INST_GEMV(dtype, prefix) defines the explicit specialization gemv<dtype>
// as a thin forwarder: all eleven arguments are handed, unchanged and in the
// same order, to the symbol formed by pasting `prefix` onto GEMV
// (SGEMV / DGEMV / CGEMV / ZGEMV).
#define INST_GEMV(dtype,prefix)                                           \
  template <>                                                             \
  void gemv<dtype>(const char *  trans,                                   \
                   const int *   m,                                       \
                   const int *   n,                                       \
                   const dtype * alpha,                                   \
                   dtype *       A,                                       \
                   const int *   lda,                                     \
                   dtype *       x,                                       \
                   const int *   incx,                                    \
                   const dtype * beta,                                    \
                   dtype *       y,                                       \
                   const int *   incy){                                   \
    prefix ## GEMV(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);   \
  }
// one specialization per element type handled by the xGEMV symbols
INST_GEMV(float,S)
INST_GEMV(double,D)
INST_GEMV(std::complex<float>,C)
INST_GEMV(std::complex<double>,Z)
#undef INST_GEMV

/**
 * \brief primary ger template, used only when dtype has no specialization
 *
 * INST_GER (defined right after this template) provides the float and
 * double specializations; every other dtype resolves to this body, which
 * prints an error message and trips assert. None of the arguments is read.
 */
template <typename dtype>
void ger(const int *   m,     const int * n,
         const dtype * alpha,
         const dtype * x,     const int * incx,
         const dtype * y,     const int * incy,
         dtype *       A,     const int * lda) {
  // report the unsupported type, then fail
  printf("CTF ERROR: GER not available for this type.\n");
  assert(0);
}

// INST_GER(dtype, prefix) defines the explicit specialization ger<dtype> as a
// thin forwarder: the nine arguments are passed, unchanged and in order, to
// the symbol formed by pasting `prefix` onto GER (SGER / DGER).
#define INST_GER(dtype, prefix)                                   \
  template <>                                                     \
  void ger<dtype>(const int *   m,                                \
                  const int *   n,                                \
                  const dtype * alpha,                            \
                  const dtype * x,                                \
                  const int *   incx,                             \
                  const dtype * y,                                \
                  const int *   incy,                             \
                  dtype *       A,                                \
                  const int *   lda) {                            \
    prefix##GER(m, n, alpha, x, incx, y, incy, A, lda);           \
  }
// only the real single/double precision variants are instantiated here
INST_GER(float, S)
INST_GER(double, D)
#undef INST_GER

template <typename dtype>
void syr(const char * UPLO ,
const int * N ,
Expand Down
131 changes: 108 additions & 23 deletions src/shared/blas_symbs.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
#define DGEMM dgemm_
#define CGEMM cgemm_
#define ZGEMM zgemm_
#define SGEMV sgemv_
#define DGEMV dgemv_
#define CGEMV cgemv_
#define ZGEMV zgemv_
#define SGEMM_BATCH sgemm_batch_
#define DGEMM_BATCH dgemm_batch_
#define CGEMM_BATCH cgemm_batch_
Expand All @@ -31,14 +35,18 @@
#define SCOPY scopy_
#define DCOPY dcopy_
#define ZCOPY zcopy_
#define SGER sger_
#define DGER dger_
#define DGEMV dgemv_
#else
#define DDOT ddot
#define SGEMM sgemm
#define DGEMM dgemm
#define CGEMM cgemm
#define ZGEMM zgemm
#define SGEMV sgemv
#define DGEMV dgemv
#define CGEMV cgemv
#define ZGEMV zgemv
#define SGEMM_BATCH sgemm_batch
#define DGEMM_BATCH dgemm_batch
#define CGEMM_BATCH cgemm_batch
Expand All @@ -62,8 +70,8 @@
#define SCOPY scopy
#define DCOPY dcopy
#define ZCOPY zcopy
#define SGER sger
#define DGER dger
#define DGEMV dgemv
#endif


Expand Down Expand Up @@ -136,6 +144,7 @@ namespace CTF_BLAS {
const int *);



extern "C"
void SAXPY(const int * n,
float * dA,
Expand Down Expand Up @@ -194,6 +203,71 @@ namespace CTF_BLAS {
std::complex<double> * dY,
const int * incY);

/**
 * \brief C-linkage declaration of the float GEMV symbol
 *
 * SGEMV is remapped by the #define table at the top of this header to the
 * library's actual symbol name. Parameter names mirror the gemv<dtype>
 * wrapper declared in this header; everything is passed by pointer and only
 * Y is writable.
 */
extern "C"
void SGEMV(const char *  TRANS,
           const int *   M,
           const int *   N,
           const float * ALPHA,
           const float * A,
           const int *   LDA,
           const float * X,
           const int *   INCX,
           const float * BETA,
           float *       Y,
           const int *   INCY);

/**
 * \brief C-linkage declaration of the double GEMV symbol
 *
 * DGEMV is remapped by the #define table at the top of this header to the
 * library's actual symbol name. Parameter names mirror the gemv<dtype>
 * wrapper declared in this header; everything is passed by pointer and only
 * Y is writable.
 */
extern "C"
void DGEMV(const char *   TRANS,
           const int *    M,
           const int *    N,
           const double * ALPHA,
           const double * A,
           const int *    LDA,
           const double * X,
           const int *    INCX,
           const double * BETA,
           double *       Y,
           const int *    INCY);

/**
 * \brief C-linkage declaration of the std::complex<float> GEMV symbol
 *
 * CGEMV is remapped by the #define table at the top of this header to the
 * library's actual symbol name. Parameter names mirror the gemv<dtype>
 * wrapper declared in this header; everything is passed by pointer and only
 * Y is writable.
 */
extern "C"
void CGEMV(const char *                TRANS,
           const int *                 M,
           const int *                 N,
           const std::complex<float> * ALPHA,
           const std::complex<float> * A,
           const int *                 LDA,
           const std::complex<float> * X,
           const int *                 INCX,
           const std::complex<float> * BETA,
           std::complex<float> *       Y,
           const int *                 INCY);

/**
 * \brief C-linkage declaration of the std::complex<double> GEMV symbol
 *
 * ZGEMV is remapped by the #define table at the top of this header to the
 * library's actual symbol name. Parameter names mirror the gemv<dtype>
 * wrapper declared in this header; everything is passed by pointer and only
 * Y is writable.
 */
extern "C"
void ZGEMV(const char *                 TRANS,
           const int *                  M,
           const int *                  N,
           const std::complex<double> * ALPHA,
           const std::complex<double> * A,
           const int *                  LDA,
           const std::complex<double> * X,
           const int *                  INCX,
           const std::complex<double> * BETA,
           std::complex<double> *       Y,
           const int *                  INCY);

/**
 * \brief type-generic front end for the xGEMV symbols declared above
 *
 * Declaration only; the definition and the per-type specializations live in
 * the matching .cxx file. Unlike the extern "C" declarations, A and X are
 * taken as pointers to non-const dtype here.
 *
 * \param[in] TRANS,M,N    passed through to the underlying routine
 * \param[in] ALPHA,BETA   pointers to scalars of type dtype
 * \param[in] A,LDA        matrix operand and its leading dimension
 * \param[in] X,INCX       vector operand and its stride
 * \param[in,out] Y,INCY   vector operand that receives the result, and its
 *                         stride (per the reference BLAS GEMV interface)
 */
template <typename dtype>
void gemv(const char *  TRANS,
          const int *   M,     const int * N,
          const dtype * ALPHA,
          dtype *       A,     const int * LDA,
          dtype *       X,     const int * INCX,
          const dtype * BETA,
          dtype *       Y,     const int * INCY);

extern "C"
void SSYR(const char * UPLO ,
const int * N ,
Expand Down Expand Up @@ -344,27 +418,38 @@ namespace CTF_BLAS {
const int * incX);

// C-linkage declaration of the double GER symbol, with named parameters:
// dimensions m and n, scalar ALPHA, vectors dX / dY with strides incX / incY,
// and the writable matrix dA with leading dimension LDA.
// NOTE(review): this view also contains an unnamed-parameter DGER
// declaration with the same signature further down — presumably this named
// form is the one being replaced; confirm against the actual header.
extern "C"
void DGER(const int * m,
const int * n,
const double * ALPHA,
const double * dX,
const int * incX,
const double * dY,
const int * incY,
double * dA,
const int * LDA);
// C-linkage declaration of the double GEMV symbol, with named parameters:
// TRANS flag, dimensions M and N, scalars ALPHA / BETA, read-only matrix dA
// with leading dimension LDA, read-only vector dX with stride incX, and the
// writable vector dY with stride incY.
// NOTE(review): an unnamed-parameter DGEMV declaration with the same
// signature appears earlier in this view — presumably this named form is the
// one being replaced; confirm against the actual header.
extern "C"
void DGEMV(const char * TRANS,
const int * M,
const int * N,
const double * ALPHA,
const double * dA,
const int * LDA,
const double * dX,
const int * incX,
const double * BETA,
double * dY,
const int * incY);
// Declaration of the float GER symbol; SGER is remapped by the #define table
// at the top of this header to the library's actual symbol name. Parameter
// names mirror the ger<dtype> wrapper declared in this header; only A is
// writable.
// NOTE(review): the extern "C" specifier this declaration relies on is not
// on the line directly above it in this view — confirm it is still covered
// by one in the actual header.
void SGER(const int *   m,
          const int *   n,
          const float * alpha,
          const float * x,
          const int *   incX,
          const float * y,
          const int *   incY,
          float *       A,
          const int *   LDA);

/**
 * \brief C-linkage declaration of the double GER symbol
 *
 * DGER is remapped by the #define table at the top of this header to the
 * library's actual symbol name. Parameter names mirror the ger<dtype>
 * wrapper declared in this header; only A is writable.
 */
extern "C"
void DGER(const int *    m,
          const int *    n,
          const double * alpha,
          const double * x,
          const int *    incX,
          const double * y,
          const int *    incY,
          double *       A,
          const int *    LDA);

/**
 * \brief type-generic front end for the xGER symbols declared above
 *
 * Declaration only; the definition and the per-type specializations live in
 * the matching .cxx file.
 *
 * \param[in] m,n        passed through to the underlying routine
 * \param[in] alpha      pointer to a scalar of type dtype
 * \param[in] x,incX     read-only vector operand and its stride
 * \param[in] y,incY     read-only vector operand and its stride
 * \param[in,out] A,LDA  writable matrix operand and its leading dimension
 */
template <typename dtype>
void ger(const int *   m,     const int * n,
         const dtype * alpha,
         const dtype * x,     const int * incX,
         const dtype * y,     const int * incY,
         dtype *       A,     const int * LDA);


#ifdef USE_BATCH_GEMM
extern "C"
Expand Down
4 changes: 2 additions & 2 deletions src/spttn_cyclops/execute_kernel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ namespace CTF_int {
case xGER: {
double * dX = (double *)Bs[term.X];
double * dY = (double *)Bs[term.Y];
CTF_BLAS::DGER(&term.M, &term.N, &term.ALPHA, dX, &term.INCX, dY, &term.INCY, (double *)Bs[term.A], &term.LDA);
CTF_BLAS::ger<double>(&term.M, &term.N, &term.ALPHA, dX, &term.INCX, dY, &term.INCY, (double *)Bs[term.A], &term.LDA);
}
break;
case DENSE_xAXPY_3D:
Expand Down Expand Up @@ -312,7 +312,7 @@ namespace CTF_int {
double * dY = (double *)Bs[term.Y];
double * dA = (double *)Bs[term.A];
double BETA = 1.;
CTF_BLAS::DGEMV(&term.TRANS, &term.M, &term.N, &term.ALPHA, dA, &term.LDA, dX, &term.INCX, &BETA, dY, &term.INCY);
CTF_BLAS::gemv<double>(&term.TRANS, &term.M, &term.N, &term.ALPHA, dA, &term.LDA, dX, &term.INCX, &BETA, dY, &term.INCY);
}
break;
case xDOT: {
Expand Down

0 comments on commit b6f8d5b

Please sign in to comment.