-
Notifications
You must be signed in to change notification settings - Fork 11
/
gauss_seidel_solver.metal
110 lines (67 loc) · 3.68 KB
/
gauss_seidel_solver.metal
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include <metal_stdlib>
using namespace metal;
// Atomically adds `val` to a float accumulator stored bit-for-bit in a 32-bit
// device atomic_uint (Metal atomics here operate on integer words, so the
// float's IEEE-754 bits are kept in a uint).
//
// Implemented as a compare-and-swap loop:
//   1. read the current bits,
//   2. reinterpret them as float with as_type<> and add `val`,
//   3. publish the new bits only if the word still holds what was read.
// If another thread updated the word first, the CAS fails, refreshes
// `expected` with the latest bits, and the addition is retried.
//
// Unlike an exchange-with-placeholder scheme, the stored word always holds a
// complete partial sum, never a temporary sentinel. as_type<> is Metal's
// bit-reinterpretation operator and replaces pointer type punning
// (float*/uint* aliasing), which C++ aliasing rules leave undefined.
//
// Parameters:
//   atom_var - device word holding the float bits of the accumulator.
//   val      - amount to add.
// All operations use memory_order_relaxed: the update is atomic, but no
// ordering with respect to other memory accesses is implied.
void atomic_add_float( device atomic_uint* atom_var, const float val )
{
    uint expected = atomic_load_explicit( atom_var, memory_order_relaxed );
    uint desired;
    do {
        desired = as_type<uint>( as_type<float>( expected ) + val );
    } while ( !atomic_compare_exchange_weak_explicit( atom_var, &expected, desired,
                                                      memory_order_relaxed,
                                                      memory_order_relaxed ) );
}
// Scalar parameters for the solve_raw_major kernel, bound at buffer(6).
// The layout (a single 32-bit signed int) must match the bytes the host writes.
struct gauss_seidel_solver_constants
{
    // Size of the system: the kernel treats A as dim x dim (row stride dim)
    // and indexes b, Dinv, xin and xout over [0, dim).
    int dim;
};
// Performs one Gauss-Seidel sweep over a dense dim x dim system A*x = b.
//
// For each row r, strictly in order:
//   xout[r] = Dinv[r] * ( b[r] - sum_{c<r} A[r*dim+c] * xout[c]
//                              - sum_{c>r} A[r*dim+c] * xin[c] )
// Entries already updated in this sweep (c < r) come from xout, the rest
// from xin, and the diagonal term (c == r) is skipped. Dinv presumably holds
// the reciprocals of A's diagonal entries — confirm on the host side.
// For every row, (xout[r] - xin[r])^2 is atomically added to the float
// accumulator stored as raw bits in x_error (presumably zeroed by the host
// before dispatch — TODO confirm).
//
// A is indexed as A[row * dim + col], i.e. row-major. The name "raw_major"
// looks like a typo for "row_major" but is kept: the host looks the function
// up by name.
//
// Launch contract: rows are processed sequentially, and each row's columns
// are split across the threads of ONE threadgroup. Only threadgroup 0 works.
// Any other threadgroup returns immediately; otherwise it would repeat the
// whole sweep, race on xout, and add to x_error several times.
//
// Buffers: 0 A, 1 Dinv, 2 b, 3 xin, 4 xout (written), 5 x_error (float bits,
// atomic accumulator), 6 constants (dim).
kernel void solve_raw_major (
    device const float* A [[ buffer(0) ]],
    device const float* Dinv [[ buffer(1) ]],
    device const float* b [[ buffer(2) ]],
    device const float* xin [[ buffer(3) ]],
    device float* xout [[ buffer(4) ]],
    device atomic_uint* x_error [[ buffer(5) ]],
    device const gauss_seidel_solver_constants& constants [[ buffer(6) ]],
    const uint thread_position_in_threadgroup [[ thread_position_in_threadgroup ]],
    const uint threadgroup_position_in_grid [[ threadgroup_position_in_grid ]],
    const uint thread_position_in_grid [[ thread_position_in_grid ]],
    const uint threads_per_threadgroup [[ threads_per_threadgroup ]],
    const uint thread_index_in_simdgroup [[ thread_index_in_simdgroup ]],
    const uint simdgroup_index_in_threadgroup [[ simdgroup_index_in_threadgroup ]],
    const uint simdgroups_per_threadgroup [[ simdgroups_per_threadgroup ]],
    const uint threads_per_simdgroup [[ threads_per_simdgroup ]]
) {
    // Capacity of the per-simdgroup partial-sum cache. A threadgroup holds at
    // most 1024 threads on macOS, so it can never have more simdgroups.
    const int THREADS_PER_THREADGROUP = 1024; // macos
    threadgroup float sum_cache[ THREADS_PER_THREADGROUP ];

    // Whole threadgroups other than the first exit together, so no thread
    // skips a barrier that the rest of its threadgroup executes.
    if ( threadgroup_position_in_grid != 0 ) {
        return;
    }

    const int dim = constants.dim;

    for ( int row = 0; row < dim; row++ ) {
        // Per-thread partial sum of this row's off-diagonal products.
        float sum = 0.0f;
        for ( int col = thread_position_in_threadgroup ; col < dim ; col += threads_per_threadgroup ) {
            if ( col < row ) {
                // Already updated earlier in this sweep.
                sum += ( A[ row * dim + col ] * xout[ col ] );
            }
            else if ( col > row ) {
                // Not updated yet: use the previous iterate.
                sum += ( A[ row * dim + col ] * xin[ col ] );
            }
        }

        // Stage 1: reduce within each simdgroup; lane 0 publishes its result.
        const float simd_partial = simd_sum( sum );
        if ( thread_index_in_simdgroup == 0 ) {
            sum_cache[ simdgroup_index_in_threadgroup ] = simd_partial;
        }
        threadgroup_barrier( mem_flags::mem_threadgroup );

        // Stage 2: the first simdgroup reduces the per-simdgroup partials.
        // The strided loop keeps every partial even if there are more
        // simdgroups than lanes (possible with narrow SIMD widths).
        if ( simdgroup_index_in_threadgroup == 0 ) {
            float local_sum = 0.0f;
            for ( uint i = thread_index_in_simdgroup; i < simdgroups_per_threadgroup; i += threads_per_simdgroup ) {
                local_sum += sum_cache[ i ];
            }
            const float off_diagonal = simd_sum( local_sum );
            if ( thread_position_in_threadgroup == 0 ) {
                const float x_new = ( b[ row ] - off_diagonal ) * Dinv[ row ];
                xout[ row ] = x_new;
                const float delta = x_new - xin[ row ];
                atomic_add_float( x_error, delta * delta );
            }
        }

        // Required before the next row, for two reasons:
        //  (1) the next row's stage 1 overwrites sum_cache, which simdgroup 0
        //      may still be reading;
        //  (2) the next row reads xout[row] (col < row + 1) from device memory,
        //      so thread 0's write must be visible to the whole threadgroup.
        threadgroup_barrier( mem_flags::mem_threadgroup | mem_flags::mem_device );
    }
}