Skip to content

Commit

Permalink
added support for apple accelerate
Browse files Browse the repository at this point in the history
  • Loading branch information
HasKha committed Apr 12, 2024
1 parent b1fba68 commit 8421450
Show file tree
Hide file tree
Showing 13 changed files with 123 additions and 41 deletions.
30 changes: 25 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,20 @@ option(NASOQ_USE_OPENMP "Use OpenMP" ON)
option(NASOQ_WITH_MATLAB "Build NASOQ Matlab interface" OFF)

set(NASOQ_BLAS_BACKEND "MKL" CACHE STRING "BLAS implementation for NASOQ to use")
set_property(CACHE NASOQ_BLAS_BACKEND PROPERTY STRINGS MKL OpenBLAS)
set_property(CACHE NASOQ_BLAS_BACKEND PROPERTY STRINGS MKL OpenBLAS Accelerate)

if(${NASOQ_BLAS_BACKEND} STREQUAL "MKL")
set(NASOQ_USE_BLAS_MKL ON)
set(NASOQ_USE_BLAS_OpenBLAS OFF)
set(NASOQ_USE_BLAS_MKL ON)
set(NASOQ_USE_BLAS_OpenBLAS OFF)
set(NASOQ_USE_BLAS_Accelerate OFF)
elseif(${NASOQ_BLAS_BACKEND} STREQUAL "OpenBLAS")
set(NASOQ_USE_BLAS_MKL OFF)
set(NASOQ_USE_BLAS_OpenBLAS ON)
set(NASOQ_USE_BLAS_MKL OFF)
set(NASOQ_USE_BLAS_OpenBLAS ON)
set(NASOQ_USE_BLAS_Accelerate OFF)
elseif(${NASOQ_BLAS_BACKEND} STREQUAL "Accelerate")
set(NASOQ_USE_BLAS_MKL OFF)
set(NASOQ_USE_BLAS_OpenBLAS OFF)
set(NASOQ_USE_BLAS_Accelerate ON)
else()
message(FATAL_ERROR "unrecognized value for `NASOQ_BLAS_BACKEND` option: '${NASOQ_BLAS_BACKEND}'")
endif()
Expand Down Expand Up @@ -80,6 +86,11 @@ elseif(NASOQ_USE_BLAS_OpenBLAS)
if(openblas_WITHOUT_LAPACK AND NOT NASOQ_USE_CLAPACK)
message(FATAL_ERROR "cannot build LAPACKE for use with OpenBLAS (maybe you don't have a Fortran compiler?) try setting `NASOQ_USE_CLAPACK` to `ON` to use a C-language alternative.")
endif()
elseif(NASOQ_USE_BLAS_Accelerate)
include(accelerate)
if (NASOQ_USE_CLAPACK)
message(FATAL_ERROR "Accelerate already provides LAPACK. Please set `NASOQ_USE_CLAPACK` to `OFF`.")
endif()
endif()

if(NASOQ_USE_CLAPACK)
Expand Down Expand Up @@ -184,6 +195,15 @@ if(NASOQ_USE_BLAS_MKL)
elseif(NASOQ_USE_BLAS_OpenBLAS)
target_link_libraries(nasoq PRIVATE OpenBLAS::OpenBLAS)
target_compile_definitions(nasoq PRIVATE "OPENBLAS")
elseif(NASOQ_USE_BLAS_Accelerate)
target_link_libraries(nasoq PRIVATE BLAS::BLAS)
target_link_libraries(nasoq PRIVATE LAPACK::LAPACK)
target_compile_definitions(nasoq PRIVATE "ACCELERATE")
target_compile_definitions(nasoq PUBLIC "ACCELERATE_NEW_LAPACK")
target_sources(nasoq PRIVATE
"src/clapacke/clapacke_dlapmt.cpp"
"src/clapacke/clapacke_dsytrf.cpp"
)
endif()

if(NASOQ_WITH_EIGEN)
Expand Down
12 changes: 12 additions & 0 deletions cmake/third_party/accelerate.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Locates Apple's Accelerate framework and exposes it through the standard
# BLAS::BLAS and LAPACK::LAPACK imported targets created by CMake's
# FindBLAS / FindLAPACK modules.
#
# Configuration stops with a fatal error unless the build targets an Apple
# platform on the arm64 architecture.

# A BLAS target already exists, so there is nothing left to do.
if(TARGET BLAS::BLAS)
  return()
endif()

# Accelerate ships only with Apple operating systems; fail early with a clear
# message instead of letting find_package(BLAS) fail later on other platforms.
if(NOT APPLE)
  message(FATAL_ERROR "Accelerate is only supported on Apple platforms")
endif()

if("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "arm64" OR "${CMAKE_OSX_ARCHITECTURES}" MATCHES "arm64")
  # BLA_VENDOR=Apple restricts FindBLAS/FindLAPACK to the Accelerate framework.
  set(BLA_VENDOR Apple)
  find_package(BLAS REQUIRED)
  find_package(LAPACK REQUIRED)
else()
  message(FATAL_ERROR "Accelerate is only supported on Apple silicon (arm64)")
endif()
10 changes: 9 additions & 1 deletion include/nasoq/clapacke/clapacke.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@ namespace clapacke {
#define LAPACK_COL_MAJOR 102
#endif

#if defined(NASOQ_USE_CLAPACK)
using clapack_int = long int;
using clapack_logical = long int;
#elif defined(ACCELERATE)
using clapack_int = int;
using clapack_logical = int;
#endif


/**
Expand All @@ -37,7 +42,7 @@ int LAPACKE_dsytrf(
clapack_int lda,
clapack_int* ipiv
);

#if defined(NASOQ_USE_CLAPACK)
int LAPACKE_dsytrf(
int matrix_layout,
char uplo,
Expand All @@ -46,6 +51,7 @@ int LAPACKE_dsytrf(
int lda,
int* ipiv
);
#endif


/**
Expand All @@ -69,6 +75,7 @@ int LAPACKE_dlapmt(
clapack_int* k
);

#if defined(NASOQ_USE_CLAPACK)
int LAPACKE_dlapmt(
int matrix_layout,
int forwrd,
Expand All @@ -78,6 +85,7 @@ int LAPACKE_dlapmt(
int ldx,
int* k
);
#endif

} // namespace clapacke
} // namespace nasoq
18 changes: 16 additions & 2 deletions include/nasoq/common/Sym_BLAS.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@
#endif
#include "openblas/cblas.h"
// #endif
#else
#elif defined(MKL_BLAS)
#include "mkl.h"
#include <mkl_blas.h>
#include <mkl_lapacke.h>
#elif defined(ACCELERATE)
#include <Accelerate/Accelerate.h>
#include "nasoq/clapacke/clapacke.h"
#else
#error "unknown BLAS backend"
#endif
namespace nasoq {
# define VEC_SCAL(n, a, x, u){ \
Expand All @@ -37,18 +42,27 @@ namespace nasoq {
}

#ifdef OPENBLAS
// Note: those are not actually used
#define SYM_DGEMM dgemm_
#define SYM_DTRSM dtrsm_
#define SYM_DGEMV dgemv_
#define SYM_DSCAL dscal_
#define SET_BLAS_THREAD(t) (openblas_set_num_threads(t))
#else

#elif defined(MKL_BLAS)
// Note: those could be replaced in-line instead
#define SYM_DGEMM dgemm
#define SYM_DTRSM dtrsm
#define SYM_DGEMV dgemv
#define SYM_DSCAL dscal

#define SET_BLAS_THREAD(t) (MKL_Domain_Set_Num_Threads(t, MKL_DOMAIN_BLAS))

#elif defined(ACCELERATE)
#define SET_BLAS_THREAD(t) ((void)0)
#else
#error "unknown BLAS backend"

#endif


Expand Down
3 changes: 3 additions & 0 deletions src/QP/linear_solver_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ namespace nasoq {
#elif defined(OPENBLAS)
num_thread = openblas_get_num_procs();
openblas_set_num_threads(1);
#elif defined(ACCELERATE)
// No thread count is set here for Accelerate: it is controlled through the
// VECLIB_MAXIMUM_THREADS environment variable, which must be set before the program starts
#else
#error couldn't determine BLAS implementation
#endif
Expand Down
17 changes: 16 additions & 1 deletion src/clapacke/clapacke_dlapmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

#include <vector>

#if defined(NASOQ_USE_CLAPACK)
extern "C" {
#include "f2c.h"
#include "clapack.h"
}
#elif defined(ACCELERATE)
#include <Accelerate/Accelerate.h>
#endif


namespace nasoq {
Expand All @@ -21,7 +25,6 @@ inline const T& maximum(const T& a, const T& b) {
return a >= b ? a : b;
}


/**
* DLAPMT rearranges the columns of the M by N matrix X as specified
* by the permutation K(1),K(2),...,K(N) of the integers 1,...,N.
Expand Down Expand Up @@ -66,18 +69,28 @@ int LAPACKE_dlapmt(
if(ldx < n) return -6;
std::vector<double> x_t(ldx_t * maximum<clapack_int>(1,n));
transpose_into(x_t.data(), ldx_t, x, ldx, m, n, matrix_layout);
#if defined(NASOQ_USE_CLAPACK)
clapack_int info = dlapmt_(&forwrd, &m, &n, x_t.data(), &ldx_t, k);
if(info < 0) return info-1;
#else
dlapmt_(&forwrd, &m, &n, x_t.data(), &ldx_t, k);
#endif
transpose_into(x, ldx, x_t.data(), ldx_t, m, n, LAPACK_COL_MAJOR);
} else if(LAPACK_COL_MAJOR == matrix_layout) {
#if defined(NASOQ_USE_CLAPACK)
clapack_int info = dlapmt_(&forwrd, &m, &n, x, &ldx, k);
if(info < 0) return info-1;
#else
dlapmt_(&forwrd, &m, &n, x, &ldx, k);
#endif
} else {
return -1; // argument 1 has an illegal value
}
return 0;
}

#if defined(NASOQ_USE_CLAPACK)

int LAPACKE_dlapmt(
int matrix_layout,
int forwrd,
Expand Down Expand Up @@ -105,5 +118,7 @@ int LAPACKE_dlapmt(
return info;
}

#endif

} // namespace clapacke
} // namespace nasoq
8 changes: 8 additions & 0 deletions src/clapacke/clapacke_dsytrf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

#include <vector>

#if defined(NASOQ_USE_CLAPACK)
extern "C" {
#include "f2c.h"
#include "clapack.h"
}
#elif defined(ACCELERATE)
#include <Accelerate/Accelerate.h>
#endif


namespace nasoq {
Expand Down Expand Up @@ -119,6 +123,8 @@ int LAPACKE_dsytrf(
return 0;
}

#if defined(NASOQ_USE_CLAPACK)

int LAPACKE_dsytrf(
int matrix_layout,
char uplo,
Expand All @@ -144,5 +150,7 @@ int LAPACKE_dsytrf(
return info;
}

#endif

} // namespace clapacke
} // namespace nasoq
14 changes: 8 additions & 6 deletions src/common/Sym_BLAS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ namespace nasoq {
//#ifdef OPENBLAS
//cblas_dscal(tmp_dim, sca_tmp, tmp1, iun);
//#else
#ifdef OPENBLAS
#if defined(OPENBLAS) || defined(ACCELERATE)
cblas_dscal(tmp_dim,sca_tmp,tmp1,iun);
#else
SYM_DSCAL(&tmp_dim, &sca_tmp, tmp1, &iun);
Expand All @@ -69,10 +69,12 @@ cblas_dscal(tmp_dim,sca_tmp,tmp1,iun);
double *tmp1_stride = tmp1 + stride;
/* std::cout<<dimx<<":"<<diag<<":"<<*tmp1<<":"<<iun<<":"<<
*tmp1_stride<<":"<<stride<<" : \n";*/
#ifdef OPENBLAS
#if defined(OPENBLAS)
blasint st = stride;
cblas_dsyr(CblasColMajor, CblasLower, dimx, diag, tmp1, iun, tmp1_stride, st); // ?syr Performs a rank-1 update of a symmetric matrix.
// dsyr_("L", &dimx, &diag, tmp1, &iun, tmp1_stride, &st); // ?syr Performs a rank-1 update of a symmetric matrix.
#elif defined(ACCELERATE)
cblas_dsyr(CblasColMajor, CblasLower, dimx, diag, tmp1, iun, tmp1_stride, stride); // ?syr Performs a rank-1 update of a symmetric matrix.
#else
dsyr("L", &dimx, &diag, tmp1, &iun, tmp1_stride, &stride); // ?syr Performs a rank-1 update of a symmetric matrix.
#endif
Expand Down Expand Up @@ -499,7 +501,7 @@ cblas_dscal(tmp_dim,sca_tmp,tmp1,iun);
if (D[i + lda_d] == 0) { // simple scaling
assert(D[i] != 0);
double tmp = 1.0 / D[i];
#ifdef OPENBLAS
#if defined(OPENBLAS) || defined(ACCELERATE)
cblas_dscal(n_rhs, tmp, rhs + i * lda, iun);
#else
SYM_DSCAL(&n_rhs, &tmp, rhs + i * lda, &iun);
Expand All @@ -516,7 +518,7 @@ cblas_dscal(tmp_dim,sca_tmp,tmp1,iun);
rhs[i * lda + j] = x1 * D[i + 1] - x2 * subdiag;
rhs[(i + 1) * lda + j] = x2 * D[i] - x1 * subdiag;
}
#ifdef OPENBLAS
#if defined(OPENBLAS) || defined(ACCELERATE)
cblas_dscal(n_rhs, one_over_det, rhs + i * lda, iun);
cblas_dscal(n_rhs, one_over_det, rhs + (i + 1) * lda, iun);
#else
Expand All @@ -541,7 +543,7 @@ cblas_dscal(tmp_dim,sca_tmp,tmp1,iun);
if (D[i + lda_d] == 0) { // simple scaling
assert(D[i] != 0);
double tmp = 1.0 / D[i];
#ifdef OPENBLAS
#if defined(OPENBLAS) || defined(ACCELERATE)
cblas_dscal(n_rhs, tmp, rhs + i * lda, iun);
#else
SYM_DSCAL(&n_rhs, &tmp, rhs + i * lda, &iun);
Expand All @@ -558,7 +560,7 @@ cblas_dscal(tmp_dim,sca_tmp,tmp1,iun);
rhs[i * lda + j] = x1 * D[i + 1] - x2 * subdiag;
rhs[(i + 1) * lda + j] = x2 * D[i] - x1 * subdiag;
}
#ifdef OPENBLAS
#if defined(OPENBLAS) || defined(ACCELERATE)
cblas_dscal(n_rhs, one_over_det, rhs + i * lda, iun);
cblas_dscal(n_rhs, one_over_det, rhs + (i + 1) * lda, iun);
#else
Expand Down
4 changes: 2 additions & 2 deletions src/ldl/Serial_blocked_ldl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ namespace nasoq {
//#ifdef OPENBLAS
//#endif

#ifdef OPENBLAS
#if defined(OPENBLAS) || defined(ACCELERATE)
cblas_dgemm(CblasColMajor,CblasNoTrans,CblasConjTrans, ndrow3, ndrow1, supWdts, 1.0, srcL, nSNRCur,
src, nSNRCur, 0.0, contribs+ndrow1, nSupRs);
#else
Expand Down Expand Up @@ -153,7 +153,7 @@ namespace nasoq {
*(++stCol) = tmp * *(++curCol);
}
}
#ifdef OPENBLAS
#if defined(OPENBLAS) || defined(ACCELERATE)
cblas_dtrsm(CblasColMajor, CblasRight, CblasLower, CblasConjTrans, CblasNonUnit, rowNo, supWdt, 1.0,
trn_diag, supWdt, &cur[supWdt], nSupR);
#else
Expand Down
6 changes: 3 additions & 3 deletions src/ldl/Serial_blocked_ldl_02_2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace nasoq {
bool ldl_left_sn_02_v2(int n, int *c, int *r, double *values, size_t *lC, int *lR, size_t *Li_ptr, double *lValues,
double *D, int *blockSet, int supNo, double *timing, int *aTree, int *cT, int *rT, int *col2Sup,
int super_max, int col_max, int &nbpivot, int *perm_piv, int *atree_sm, double threshold) {
#if defined(OPENBLAS) && defined(NASOQ_USE_CLAPACK)
#if (defined(OPENBLAS) && defined(NASOQ_USE_CLAPACK)) || defined(ACCELERATE)
using nasoq::clapacke::LAPACKE_dlapmt;
using nasoq::clapacke::LAPACKE_dsytrf;
#endif
Expand Down Expand Up @@ -102,7 +102,7 @@ namespace nasoq {
src = &lValues[lC[cSN] + lb];//first element of src supernode starting from row lb
double *srcL = &lValues[lC[cSN] + ub + 1];
blocked_2by2_mult(supWdts, nSupRs, &D[cSN], src, trn_diag, nSNRCur, n);
#ifdef OPENBLAS
#if defined (OPENBLAS) || defined(ACCELERATE)
cblas_dgemm(CblasColMajor,CblasNoTrans,CblasConjTrans, nSupRs, ndrow1, supWdts, 1.0, trn_diag, nSupRs,
src, nSNRCur, 0.0, contribs, nSupRs);
#else
Expand Down Expand Up @@ -150,7 +150,7 @@ namespace nasoq {
D[curCol + l] = cur[l + l * nSupR];
cur[l + l * nSupR] = 1.0;
}
#ifdef OPENBLAS
#if defined(OPENBLAS) || defined(ACCELERATE)
cblas_dtrsm(CblasColMajor, CblasRight, CblasLower, CblasConjTrans, CblasUnit, rowNo, supWdt, 1.0,
cur, nSupR, &cur[supWdt], nSupR);
#else
Expand Down
6 changes: 3 additions & 3 deletions src/ldl/Serial_update_ldl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace nasoq {
double *D, int *blockSet, int supNo, double *timing, int *aTree, int *cT, int *rT,
int *col2Sup, std::vector<int> mod_indices, int super_max, int col_max, int &nbpivot,
int *perm_piv, int *atree_sm, int *ws_int, double *ws_dbl, double threshold) {
#if defined(OPENBLAS) && defined(NASOQ_USE_CLAPACK)
#if (defined(OPENBLAS) && defined(NASOQ_USE_CLAPACK)) || defined(ACCELERATE)
using nasoq::clapacke::LAPACKE_dlapmt;
using nasoq::clapacke::LAPACKE_dsytrf;
#endif
Expand Down Expand Up @@ -151,7 +151,7 @@ namespace nasoq {
src = &lValues[lC[cSN] + lb];//first element of src supernode starting from row lb
double *srcL = &lValues[lC[cSN] + ub + 1];
blocked_2by2_mult(supWdts, nSupRs, &D[cSN], src, trn_diag, nSNRCur, n);
#ifdef OPENBLAS
#if defined(OPENBLAS) || defined(ACCELERATE)
cblas_dgemm(CblasColMajor,CblasNoTrans,CblasConjTrans, nSupRs, ndrow1, supWdts, 1.0, trn_diag, nSupRs,
src, nSNRCur, 0.0, contribs, nSupRs);
#else
Expand Down Expand Up @@ -211,7 +211,7 @@ namespace nasoq {
// std::cout<<"\n";
// /////

#ifdef OPENBLAS
#if defined(OPENBLAS) || defined(ACCELERATE)
cblas_dtrsm(CblasColMajor, CblasRight, CblasLower, CblasConjTrans, CblasUnit, rowNo, supWdt, 1.0,
cur, nSupR, &cur[supWdt], nSupR);
#else
Expand Down
Loading

0 comments on commit 8421450

Please sign in to comment.