-
Notifications
You must be signed in to change notification settings - Fork 11
/
test_case_cholesky_metal.h
60 lines (43 loc) · 1.69 KB
/
test_case_cholesky_metal.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#ifndef __TEST_CASE_CHOLESKY_METAL_H__
#define __TEST_CASE_CHOLESKY_METAL_H__
#include <cstddef>
#include <cstring>
#include <type_traits>

#include "cholesky_metal_cpp.h"
#include "test_case_cholesky.h"
/// Cholesky test case backed by the Metal implementation in CholeskyMetalCpp.
///
/// The computation is performed by `m_metal` (see run()). compareTruth()
/// copies the Metal results back into the base-class buffers `m_L` and `m_x`,
/// then defers to TestCaseCholesky::compareTruth() for the actual check.
///
/// @tparam T             Element type; only `float` is accepted (static_assert).
/// @tparam IS_COL_MAJOR  Storage order of the base-class L buffer; only `true`
///                       is accepted (static_assert).
///
/// NOTE(review): if TestCaseCholesky declares setInitialStates/compareTruth/run
/// as virtual, these overrides could be marked `override` — confirm in the base.
template<class T, bool IS_COL_MAJOR>
class TestCaseCholesky_metal : public TestCaseCholesky<T, IS_COL_MAJOR> {

    // Selects how compareTruth() reads L back; also forwarded to m_metal and
    // setMetal() in the constructor. Declared before m_metal, so it is
    // initialized first.
    const bool       m_use_mps;

    // The Metal implementation under test.
    CholeskyMetalCpp m_metal;

public:
    /// @param dim      Matrix dimension, forwarded to the base class and to m_metal.
    /// @param use_mps  Selects the MPS variant (forwarded to m_metal and setMetal()).
    TestCaseCholesky_metal( const int dim, const bool use_mps )
        : TestCaseCholesky<T, IS_COL_MAJOR>( dim )
        , m_use_mps( use_mps )
        , m_metal( dim, use_mps )
    {
        static_assert( std::is_same< float, T >::value,
                       "TestCaseCholesky_metal supports only T = float" );
        static_assert( IS_COL_MAJOR,
                       "TestCaseCholesky_metal supports only column-major storage" );

        // NOTE(review): the meaning of the two trailing 1s is defined by
        // setMetal() in the base class — confirm there.
        this->setMetal( use_mps ? MPS : DEFAULT, 1, 1 );
    }

    virtual ~TestCaseCholesky_metal() = default;

    /// Seeds the Metal implementation and then the base-class reference state
    /// with the same initial L and b.
    virtual void setInitialStates( T* L, T* b ) {
        m_metal.setInitialStates( L, b );
        TestCaseCholesky<T, IS_COL_MAJOR>::setInitialStates( L, b );
    }

    /// Copies the Metal results (L and x) into the base-class buffers and
    /// compares them against the supplied baselines via the base class.
    ///
    /// @param L_baseline  Expected L, passed through to the base compareTruth().
    /// @param x_baseline  Expected x, passed through to the base compareTruth().
    virtual void compareTruth( const T* const L_baseline, const T* const x_baseline ) {
        // Sizes and offsets are computed in size_t: products such as
        // dim * dim overflow int once dim exceeds ~46340.
        const std::size_t n = static_cast<std::size_t>( this->m_dim );
        const float* const p = m_metal.getRawPointerL();

        if ( m_use_mps ) {
            // The MPS result is read as a dense buffer with element (i, j) at
            // offset n * i + j; only the lower triangle (j <= i) is scattered
            // into m_L using the base-class layout.
            for ( int i = 0; i < this->m_dim; ++i ) {
                for ( int j = 0; j <= i; ++j ) {
                    this->m_L[ lower_mat_index<IS_COL_MAJOR>( i, j, this->m_dim ) ] = p[ n * i + j ];
                }
            }
        }
        else {
            // NOTE(review): assumes the default Metal path returns L in the
            // same packed lower-triangular layout as m_L — confirm.
            // A packed lower triangle holds n * (n + 1) / 2 elements (the
            // product is always even, so the division is exact).
            std::memcpy( this->m_L, p, n * ( n + 1 ) / 2 * sizeof(float) );
        }

        std::memcpy( this->m_x, m_metal.getRawPointerX(), n * sizeof(float) );

        TestCaseCholesky<T, IS_COL_MAJOR>::compareTruth( L_baseline, x_baseline );
    }

    /// Runs the Metal computation; its results are read back in compareTruth().
    void run() {
        m_metal.performComputation();
    }
};
#endif /*__TEST_CASE_CHOLESKY_METAL_H__*/