Skip to content

Commit

Permalink
Add single-precision support (CPU MEX)
Browse files Browse the repository at this point in the history
This adds single-precision support for CPU MEX files and updates the
MEX unit tests (MoDT.validateMex) accordingly.
  • Loading branch information
kqshan committed May 1, 2017
1 parent 9af7fc7 commit 103bf8d
Show file tree
Hide file tree
Showing 5 changed files with 310 additions and 83 deletions.
12 changes: 12 additions & 0 deletions @MoDT/bandPosSolve.m
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@
% The MEX version directly calls the LAPACK routine for solving a banded
% positive definite matrix, whereas the MATLAB code involves an inefficient step
% of converting this to a compressed sparse column (CSC) representation.
%
% The MEX version also supports single-precision arithmetic, whereas the MATLAB
% code performs everything in double-precision, though it will cast x as single
% if A_bands or b is single-precision.

[p,N] = size(A_bands);

% MATLAB doesn't support single-precision sparse yet (as of R2017a)
A_bands = double(A_bands); b = double(b);

% Call sparse() to construct A
q = (p - 1) / 2; % Number of superdiagonals
i = bsxfun(@plus, (-q:q)', 1:N);
Expand All @@ -24,4 +31,9 @@
% Solve A*x = b
x = A \ b;

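% Cast to single if A or b was single
% NOTE(review): A_bands and b were overwritten with double() copies above, so
% this isa() test appears to be always false and x is never cast; the input
% class should be recorded before the conversion to double.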
if isa(A_bands,'single') || isa(b,'single')
x = single(x);
end

end
102 changes: 82 additions & 20 deletions @MoDT/mex_src/bandPosSolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
* A_bands [p x N] banded storage of a [N x N] positive definite matrix
* b [N x m] input vector
*
* This calls LAPACK dpbsv, and requires all inputs to be double-precision.
* This calls LAPACK dpbsv or spbsv, depending on the input datatype. If either
* A_bands or b is single-precision, then the result will be single-precision.
* This assumes that A is symmetric, and therefore uses only the first
* (p+1)/2 rows of A_bands.
*
Expand All @@ -18,6 +19,67 @@

#include "mex.h"
#include "lapack.h"
#include <algorithm>

/* Precision-dispatching front end for the LAPACK ?PBSV driver
 *
 * Each overload forwards all of its arguments, unchanged, to the LAPACK
 * routine matching the element type of ab and b: dpbsv for double and spbsv
 * for float. Templated code can therefore call pbsv() without naming the
 * precision explicitly.
 */
void pbsv(char *uplo, ptrdiff_t *n, ptrdiff_t *kd, ptrdiff_t *nrhs,
          double *ab, ptrdiff_t *ldab,
          double *b, ptrdiff_t *ldb,
          ptrdiff_t *info)
{
    // Double-precision data goes to DPBSV
    dpbsv(uplo, n, kd, nrhs, ab, ldab, b, ldb, info);
}

void pbsv(char *uplo, ptrdiff_t *n, ptrdiff_t *kd, ptrdiff_t *nrhs,
          float *ab, ptrdiff_t *ldab,
          float *b, ptrdiff_t *ldb,
          ptrdiff_t *info)
{
    // Single-precision data goes to SPBSV
    spbsv(uplo, n, kd, nrhs, ab, ldab, b, ldb, info);
}

/* Templated banded positive-definite solve for both single and double
 *
 * Solves A*X = B in place by calling LAPACK ?PBSV through the pbsv()
 * overloads (DPBSV when T is double, SPBSV when T is float), using the upper
 * triangle of A (uplo = 'U'). On return mx_X holds the solution and mx_A holds
 * whatever factorization ?PBSV leaves in the banded storage.
 *
 * Inputs:
 *    N       Size of A and #rows of X
 *    p       #diagonals in A, i.e. 2*#superdiagonals + 1; also passed to
 *            LAPACK as the leading dimension of the banded storage
 *    m       #cols of X
 * Inputs that are modified by this function:
 *    mx_A    mxArray containing [N x N] symm. pos. def. matrix in banded storage
 *    mx_X    mxArray containing [N x m] matrix
 *
 * The data of mx_A and mx_X is reinterpreted as T without any class check, so
 * the caller must pass arrays whose element type really is T.
 *
 * Raises MoDT:bandPosSolve:LAPACKError if ?PBSV reports a nonzero info. The
 * message distinguishes a matrix that is not positive definite (info > 0)
 * from an illegal argument (info < 0).
 */
template <typename T>
void posBandSolve(ptrdiff_t N, ptrdiff_t p, ptrdiff_t m, mxArray *mx_A, mxArray *mx_X)
{
    // Variables used by the LAPACK routine
    char uplo = 'U';                // Use the upper triangle of A
    ptrdiff_t nSupDiag = (p-1)/2;   // Number of superdiagonals in A
    ptrdiff_t info = 0;             // Status output from ?pbsv
    // Extract data pointers
    T *A = static_cast<T*>(mxGetData(mx_A));
    T *X = static_cast<T*>(mxGetData(mx_X));
    // Call DPBSV or SPBSV as necessary (ldab = p, ldb = N)
    pbsv(&uplo, &N, &nSupDiag, &m, A, &p, X, &N, &info);
    // Check that it succeeded, and say why if it did not
    if (info > 0)
        mexErrMsgIdAndTxt("MoDT:bandPosSolve:LAPACKError",
            "LAPACK routine ?pbsv() exited with error: "
            "leading minor of order %ld is not positive definite",
            static_cast<long>(info));
    if (info < 0)
        mexErrMsgIdAndTxt("MoDT:bandPosSolve:LAPACKError",
            "LAPACK routine ?pbsv() exited with error: "
            "argument %ld had an illegal value",
            static_cast<long>(-info));
}

/* Create a copy of an mxArray that is converted to single precision
 *
 * Input:
 *    mx_X    Source matrix; must be of class double or single
 * Returns:
 *    A newly created mxArray of class single with the same [M x N] size as
 *    mx_X. A double source is converted element by element (only the data
 *    reached through mxGetPr is copied); a single source is duplicated as-is.
 *    The caller owns the returned array.
 *
 * Raises MoDT:bandPosSolve:TypeError if mx_X is neither double nor single.
 */
mxArray* copyToSingle(mxArray const *mx_X)
{
    // Initialized so that no path can hand back an indeterminate pointer
    mxArray *mx_X_copy = nullptr;
    if (mxIsDouble(mx_X)) {
        // Cast double to single
        double const *X = mxGetPr(mx_X);
        ptrdiff_t M = mxGetM(mx_X);
        ptrdiff_t N = mxGetN(mx_X);
        // Uninitialized is fine: every element is overwritten by the copy below
        mx_X_copy = mxCreateUninitNumericMatrix( M, N, mxSINGLE_CLASS, mxREAL );
        std::copy( X, X+(M*N), static_cast<float*>(mxGetData(mx_X_copy)) );
    } else if (mxIsSingle(mx_X)) {
        // Source is already single-precision
        mx_X_copy = mxDuplicateArray(mx_X);
    } else {
        // Throw an error because this is weird
        mexErrMsgIdAndTxt("MoDT:bandPosSolve:TypeError","Unsupported datatype");
    }
    // The original fell off the end of this non-void function without
    // returning the copy, which is undefined behavior
    return mx_X_copy;
}

/* Main entry point into this mex file
* Inputs and outputs are arrays of mxArray pointers
Expand All @@ -32,32 +94,32 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
mxArray const *mx_AB = prhs[0];
ptrdiff_t p = mxGetM(mx_AB);
ptrdiff_t N = mxGetN(mx_AB);
if (!mxIsDouble(mx_AB) || (N==0))
mexErrMsgIdAndTxt(errId, "A_bands must be a [p x N] array of doubles");
bool AB_is_dbl = mxIsDouble(mx_AB);
if (!(AB_is_dbl || mxIsSingle(mx_AB)) || (N==0))
mexErrMsgIdAndTxt(errId, "A_bands must be a [p x N] array of real numbers");
if ((p%2==0) || (p>=2*N))
mexErrMsgIdAndTxt(errId, "A_bands is an invalid # of rows");
// b = input 1
mxArray const *mx_b = prhs[1];
ptrdiff_t m = mxGetN(mx_b);
if (!mxIsDouble(mx_b) || (mxGetM(mx_b)!=N) || (m==0))
mexErrMsgIdAndTxt(errId, "b must be an [N x m] array of doubles");
bool b_is_dbl = mxIsDouble(mx_b);
if (!(b_is_dbl || mxIsSingle(mx_b)) || (mxGetM(mx_b)!=N) || (m==0))
mexErrMsgIdAndTxt(errId, "b must be an [N x m] array of real numbers");

// dpbsv overwrites A_bands and B, so we need to duplicate them
mxArray *mx_AB_copy = mxDuplicateArray(mx_AB);
mxArray *mx_b_copy = mxDuplicateArray(mx_b);

// Variables used by the LAPACK routine
char uplo = 'U'; // Use the upper triangle of A
ptrdiff_t nSupDiag = (p-1)/2; // Number of superdiagonals in [bands]
ptrdiff_t info = 0; // Status output for dpbsv

// Call dpbsv to compute x = A \ b
double *AB = mxGetPr(mx_AB_copy);
double *b = mxGetPr(mx_b_copy);
dpbsv(&uplo, &N, &nSupDiag, &m, AB, &p, b, &N, &info);
// Check that it succeeded
if (info != 0)
mexErrMsgIdAndTxt(errId, "LAPACK routine dpbsv() exited with error");
mxArray *mx_AB_copy;
mxArray *mx_b_copy;
if (AB_is_dbl && b_is_dbl) {
// Both are double, perform the operation in double-precision
mx_AB_copy = mxDuplicateArray(mx_AB);
mx_b_copy = mxDuplicateArray(mx_b);
posBandSolve<double>(N, p, m, mx_AB_copy, mx_b_copy);
} else {
// Perform the operation in single-precision
mx_AB_copy = copyToSingle(mx_AB);
mx_b_copy = copyToSingle(mx_b);
posBandSolve<float>(N, p, m, mx_AB_copy, mx_b_copy);
}

// Cleanup
mxDestroyArray(mx_AB_copy);
Expand Down
171 changes: 118 additions & 53 deletions @MoDT/mex_src/sumFrames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* wzuY(:,:,t) = Y(:,n1:n2) * wzu(n1:n2,:);
* end
*
* This requires all the input arguments to be double-precision.
* Y and wzu must both be the same datatype (either single- or double-precision)
*
* Kevin Shan, 2016-06-06
*============================================================================*/
Expand All @@ -27,6 +27,100 @@
#include "mex.h"
#include "blas.h"

/* Precision-dispatching front end for the BLAS ?GEMV routine
 *
 * Each overload passes its arguments straight through to the BLAS routine
 * matching the element type: dgemv for double and sgemv for float. This lets
 * templated code call gemv() for either precision.
 */
void gemv(char *trans, ptrdiff_t *m, ptrdiff_t *n,
          double *alpha, double const *a, ptrdiff_t *lda,
          double const *x, ptrdiff_t *incx,
          double *beta, double *y, ptrdiff_t *incy)
{
    // Double-precision data goes to DGEMV
    dgemv(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
}

void gemv(char *trans, ptrdiff_t *m, ptrdiff_t *n,
          float *alpha, float const *a, ptrdiff_t *lda,
          float const *x, ptrdiff_t *incx,
          float *beta, float *y, ptrdiff_t *incy)
{
    // Single-precision data goes to SGEMV
    sgemv(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
/* Precision-dispatching front end for the BLAS ?GEMM routine
 *
 * Each overload passes its arguments straight through to the BLAS routine
 * matching the element type: dgemm for double and sgemm for float.
 */
void gemm(char *transa, char *transb,
          ptrdiff_t *m, ptrdiff_t *n, ptrdiff_t *k,
          double *alpha,
          double const *a, ptrdiff_t *lda,
          double const *b, ptrdiff_t *ldb,
          double *beta,
          double *c, ptrdiff_t *ldc)
{
    // Double-precision data goes to DGEMM
    dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

void gemm(char *transa, char *transb,
          ptrdiff_t *m, ptrdiff_t *n, ptrdiff_t *k,
          float *alpha,
          float const *a, ptrdiff_t *lda,
          float const *b, ptrdiff_t *ldb,
          float *beta,
          float *c, ptrdiff_t *ldc)
{
    // Single-precision data goes to SGEMM
    sgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

/* Typed data access for single-precision mxArrays (float analogue of mxGetPr)
 *
 * Returns the pointer obtained from mxGetData, viewed as float. The class of
 * the array is not checked here, so the caller must only use this on arrays
 * that actually hold single-precision data. The const overload yields a
 * pointer to const data.
 */
float* mxGetFloatPr( mxArray *mx_A )
{
    void *raw = mxGetData(mx_A);
    return static_cast<float*>(raw);
}

float const* mxGetFloatPr( mxArray const *mx_A )
{
    void const *raw = mxGetData(mx_A);
    return static_cast<float const*>(raw);
}


/* Main routine for computing the weighted sum over data frames
 *
 * For each time frame t, with n1 = fsLim0[t] and n2 = fsLim0[t+T] (0-indexed,
 * inclusive), this computes
 *     wzuY(:,:,t)  = Y(:,n1:n2) * wzu(n1:n2,:)      (via GEMM)
 *     sumwzu(:,t)  = sum of wzu(n1:n2,:) over rows  (via GEMV with a ones vector)
 * in the precision selected by numeric_t, through the gemm()/gemv() overloads.
 *
 * Inputs:
 *    D       Number of feature space dimensions
 *    N       Number of spikes
 *    K       Number of clusters
 *    T       Number of time frames
 *    Y       [D x N] data matrix
 *    wzu     [N x K] weights for each cluster x spike
 *    fsLim0  [T x 2] [first,last] data index (0-indexed) in each frame, stored
 *            column-major; must hold at least 2*T entries (not checked here)
 * Outputs:
 *    wzuY    [D x K x T] weighted sums for each frame
 *    sumwzu  [K x T] sums of weights in each frame
 *
 * Frames that contain no spikes (last = first-1) are skipped and their slices
 * of wzuY and sumwzu are not written, so the caller should pass buffers that
 * are already zero-filled.
 *
 * Raises an error via mexErrMsgIdAndTxt if any frame limit is out of range or
 * if consecutive frames do not tile the spikes contiguously.
 */
template <typename numeric_t>
void computeFrameSums(ptrdiff_t D, ptrdiff_t N, ptrdiff_t K, ptrdiff_t T,
                      numeric_t const *Y, numeric_t const *wzu,
                      std::vector<ptrdiff_t> const &fsLim0,
                      numeric_t *wzuY, numeric_t *sumwzu)
{
    // Validate the fsLim indices and get the max # spikes per frame
    // NOTE(review): this ID reads "sumFramesGpu" although this is the CPU
    // sumFrames MEX file; kept as-is because callers/tests may match on it.
    char const * const errId = "MoDT:sumFramesGpu:InvalidInput";
    ptrdiff_t maxCount = 0;
    ptrdiff_t last_n2 = 0;
    // Loop index has the same type as T to avoid int overflow/width mismatch
    for (ptrdiff_t t=0; t<T; t++) {
        ptrdiff_t const n1 = fsLim0[t];
        ptrdiff_t const n2 = fsLim0[t+T];
        // Check that the indices are valid; n2 == n1-1 denotes an empty frame,
        // which is why n1 == N and n2 == -1 are tolerated
        if ((n1<0) || (n2<-1) || (n1>N) || (n2>=N) || (n2-n1 < -1))
            mexErrMsgIdAndTxt(errId, "Invalid frame spike limits");
        // Each frame must start right after the previous one ended
        if ((t>0) && (n1 != last_n2+1))
            mexErrMsgIdAndTxt(errId, "Non-consecutive frame spike limits");
        last_n2 = n2;
        // Track the largest spike count over all frames
        maxCount = std::max(maxCount, n2-n1+1);
    }

    // Vector of ones so we can sum the weights using GEMV
    std::vector<numeric_t> ones(maxCount, static_cast<numeric_t>(1));

    // Variables used by the BLAS routines (non-const: BLAS takes them by pointer)
    char trans_N = 'N';     // Do not transpose matrix
    char trans_T = 'T';     // Transpose matrix
    numeric_t alpha = 1;    // Scaling on Y*wzu' and wzu*ones
    numeric_t beta = 0;     // Scaling on wzuY and sum_wzu
    ptrdiff_t incr = 1;     // Vector increment of one

    // For loop over time frames
    for (ptrdiff_t t=0; t<T; t++) {
        // Get the first spike index and spike count for this time frame
        ptrdiff_t const n1 = fsLim0[t];
        ptrdiff_t M = fsLim0[t+T] - n1 + 1;
        // Empty frame: nothing to accumulate, outputs for it stay untouched
        if (M <= 0) continue;
        // Call GEMM for wzuY = alpha*Y*wzu + beta*wzuY
        gemm(&trans_N, &trans_N, &D, &K, &M,
             &alpha, Y+(n1*D), &D, wzu+n1, &N,
             &beta, wzuY+(t*D*K), &D);
        // Call GEMV for sum_wzu = alpha*wzu'*ones + beta*sum_wzu
        gemv(&trans_T, &M, &K, &alpha, wzu+n1, &N,
             ones.data(), &incr, &beta, sumwzu+(t*K), &incr);
    }
}



/* Main entry point into this mex file
* Inputs and outputs are arrays of mxArray pointers
*/
Expand All @@ -40,70 +134,41 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
mxArray const *mx_Y = prhs[0];
ptrdiff_t D = mxGetM(mx_Y);
ptrdiff_t N = mxGetN(mx_Y);
if (!mxIsDouble(mx_Y) || (D==0) || (N==0))
mexErrMsgIdAndTxt(errId, "Y must be a [D x N] array of doubles");
double *Y = mxGetPr(mx_Y);
mxClassID numericType = mxGetClassID(mx_Y);
if (!((numericType==mxDOUBLE_CLASS) || (numericType==mxSINGLE_CLASS)) || (D==0) || (N==0))
mexErrMsgIdAndTxt(errId, "Y must be a [D x N] array of real numbers");
// wzu = input 1
mxArray const *mx_wzu = prhs[1];
ptrdiff_t K = mxGetN(mx_wzu);
if (!mxIsDouble(mx_wzu) || (K==0) || (mxGetM(mx_wzu)!=N))
mexErrMsgIdAndTxt(errId, "wzu must be an [N x K] array of doubles");
double *wzu = mxGetPr(mx_wzu);
if ((mxGetClassID(mx_wzu)!=numericType) || (K==0) || (mxGetM(mx_wzu)!=N))
mexErrMsgIdAndTxt(errId, "wzu must be an [N x K] array of the same type as Y");
// f_spklim = input 2
mxArray const *mx_fsLim = prhs[2];
ptrdiff_t T = mxGetM(mx_fsLim);
if (!mxIsDouble(mx_fsLim) || (T==0) || (mxGetN(mx_fsLim)!=2))
mexErrMsgIdAndTxt(errId, "f_spklim must be a [T x 2] array of doubles");
double const *fsLim = mxGetPr(mx_fsLim);

// Validate the indices and determine the max spike count per frame
ptrdiff_t maxCount = 0;
for (ptrdiff_t t=0; t<T; t++) {
// Get the first and last index of the frame
ptrdiff_t n1 = ((ptrdiff_t) fsLim[t]) - 1;
ptrdiff_t n2 = ((ptrdiff_t) fsLim[t+T]) - 1;
// Check that the indices are valid
if ((n1<0) || (n2<-1) || (n1>N) || (n2>=N) || (n2-n1<-1))
mexErrMsgIdAndTxt(errId, "Invalid frame spike limits");
if ((t>0) & (n1 != fsLim[t-1+T]))
mexErrMsgIdAndTxt(errId, "Non-consecutive frame spike limits");
// Get the difference
ptrdiff_t count = n2 - n1 + 1;
if (count > maxCount)
maxCount = count;
}

// Copy the fsLim indices to a vector of 0-indexed integers
std::vector<ptrdiff_t> fsLim0(T*2);
std::transform(fsLim, fsLim+2*T, fsLim0.begin(),
[](double matlabIdx){ return static_cast<ptrdiff_t>(matlabIdx)-1; });

// Allocate memory for the outputs (initially filled with zeroes)
size_t dims[] = {(size_t)D, (size_t)K, (size_t)T};
mxArray *mx_wzuY = mxCreateNumericArray(3, dims, mxDOUBLE_CLASS, mxREAL);
double *wzuY = mxGetPr(mx_wzuY);
mxArray *mx_sumwzu = mxCreateDoubleMatrix(K, T, mxREAL);
double *sumwzu = mxGetPr(mx_sumwzu);
std::vector<size_t> dims_wzuY = {(size_t) D, (size_t) K, (size_t) T};
mxArray *mx_wzuY = mxCreateNumericArray(3, dims_wzuY.data(), numericType, mxREAL);
mxArray *mx_sumwzu = mxCreateNumericMatrix(K, T, numericType, mxREAL);

// Vector of ones so we can sum the weights using dgemv
std::vector<double> ones(maxCount);
std::fill(ones.begin(), ones.end(), 1.0);

// Variables used by the BLAS routines
char trans_N = 'N'; // Do no transpose matrix
char trans_T = 'T'; // Transpose matrix
double alpha = 1; // Scaling on Y*wzu' and wzu*ones
double beta = 0; // Scaling on wzuY and sum_wzu
ptrdiff_t incr = 1; // Vector increment of one

// For loop over time frames
for (ptrdiff_t t=0; t<T; t++) {
// Get the first spike index and spike count for this time frame
ptrdiff_t n1 = ((ptrdiff_t) fsLim[t]) - 1;
ptrdiff_t M = ((ptrdiff_t) fsLim[t+T]) - n1;
if (M <= 0) continue;
// Call dgemm for wzuY = alpha*Y*wzu + beta*wzuY
dgemm(&trans_N, &trans_N, &D, &K, &M,
&alpha, Y+(n1*D), &D, wzu+n1, &N,
&beta, wzuY+(t*D*K), &D);
// Call dgemv for sum_wzu = alpha*wzu'*ones + beta*sum_wzu
dgemv(&trans_T, &M, &K, &alpha, wzu+n1, &N,
ones.data(), &incr, &beta, sumwzu+(t*K), &incr);
// Sum across time frames
switch (numericType) {
case mxDOUBLE_CLASS:
computeFrameSums(D, N, K, T, mxGetPr(mx_Y), mxGetPr(mx_wzu),
fsLim0, mxGetPr(mx_wzuY), mxGetPr(mx_sumwzu) );
break;
case mxSINGLE_CLASS:
computeFrameSums(D, N, K, T, mxGetFloatPr(mx_Y), mxGetFloatPr(mx_wzu),
fsLim0, mxGetFloatPr(mx_wzuY), mxGetFloatPr(mx_sumwzu) );
break;
}

// Output 0 = wzuY
Expand Down
5 changes: 3 additions & 2 deletions @MoDT/sumFrames.m
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
K = size(wzu,2);
T = size(f_spklim,1);
% Allocate memory
wzuY = zeros(D, K, T);
sum_wzu = zeros(K, T);
datatype = class(Y);
wzuY = zeros(D, K, T, datatype);
sum_wzu = zeros(K, T, datatype);

% For loop over time frames
for t = 1:T
Expand Down
Loading

0 comments on commit 103bf8d

Please sign in to comment.