-
Notifications
You must be signed in to change notification settings - Fork 0
/
example-14c.cpp
172 lines (116 loc) · 3.7 KB
/
example-14c.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
/*
Please note that the optimization techniques in this series are not described in the official documentation.
All the tests were run on an AMD Radeon Instinct MI25 with AMD ROCm.
*/
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <iostream>
#include <vector>
#include "hip/hip_runtime.h"
#define HIP_ASSERT(x) (assert((x)==hipSuccess))
#define N 128
#define C 1024
#define H 28
#define W 28
#define MASK_QUARTER 0x3f
#define MASK_SHIFT 6
#define MASK_256 0xff
#define NUM ( N * C * H * W )
// Per-column normalization over the batch dimension.
//
// Data layout: bufA / bufB hold N rows of `chw` floats; element (b, col) is at
// index b * chw + col. For each processed column the kernel computes the mean and
// the biased variance of its N values and writes
//     bufB = gamma * (bufA - mean) * rsqrt(var + 1e-6) + beta.
//
// Launch requirements:
//   * blockDim.x == 256 (shared arrays and the reduction below assume it).
//     The block is split into four quarters of 64 threads: thread t handles
//     column 64 * blockIdx.x + (t & 63), and quarter q = t >> 6 covers batch rows
//     [q * N/4, (q + 1) * N/4). The four partial sums are combined via shared memory.
//   * Each block processes 64 columns; there is no bounds check, so
//     gridDim.x * 64 must not exceed chw (gridDim.x == chw / 64 covers all columns).
//   * `n` is not read: the batch size is the compile-time N, which also sizes the
//     per-thread cache. Callers in this file pass n == N.
//   * `lp` is unused here; callers do not pass it, so it is presumably supplied by
//     the legacy hipLaunchKernel macro.
__global__ void
test_kernel(hipLaunchParm lp,
            float* __restrict__ bufA, float* __restrict__ bufB, int n, int chw, float gamma, float beta)
{
    const int lane = hipThreadIdx_x & MASK_QUARTER;     // column within this block's 64
    const int quarter = hipThreadIdx_x >> MASK_SHIFT;   // 0..3: which quarter of the batch rows
    // Every quarter (including quarter 1) must start at its own block of N/4 rows.
    // Previously the offset was only applied for threads >= 128, so quarter 1
    // duplicated quarter 0's rows and rows N/4..N/2-1 were never read or written.
    const int x = hipBlockDim_x / 4 * hipBlockIdx_x + lane + (chw * N / 4) * quarter;

    float fmean_sum = 0;
    float fsquare_sum = 0;
    float inputData[N / 4];   // this thread's N/4 inputs, kept for the output pass
    __shared__ float ldsData0[256];
    __shared__ float ldsData1[256];

    // Load this thread's rows (stride chw) and accumulate sum and sum of squares.
    for (int i = 0; i < N / 4; i++) {
        inputData[i] = bufA[x + i * chw];
        fmean_sum += inputData[i];
        fsquare_sum += inputData[i] * inputData[i];
    }

    // Combine the four quarters' partial sums: threads 0..63 add in quarters 1..3.
    ldsData0[hipThreadIdx_x] = fmean_sum;
    ldsData1[hipThreadIdx_x] = fsquare_sum;
    __syncthreads();
    if (hipThreadIdx_x <= MASK_QUARTER) {
        ldsData0[hipThreadIdx_x] += ldsData0[(hipThreadIdx_x + 64)] + ldsData0[(hipThreadIdx_x + 128)] + ldsData0[(hipThreadIdx_x + 192)];
        ldsData1[hipThreadIdx_x] += ldsData1[(hipThreadIdx_x + 64)] + ldsData1[(hipThreadIdx_x + 128)] + ldsData1[(hipThreadIdx_x + 192)];
    }
    __syncthreads();
    // Every thread reads the full-column totals for its lane.
    fmean_sum = ldsData0[lane];
    fsquare_sum = ldsData1[lane];

    const float fmean = fmean_sum / N;
    // E[x^2] - E[x]^2 can round slightly below zero; clamp so rsqrtf never sees a negative.
    const float fvar = fmaxf(fsquare_sum / N - fmean * fmean, 0.0f);
    const float epsilon = 1e-6f;
    const float frstd = rsqrtf(fvar + epsilon);

    for (int i = 0; i < N / 4; i++) {
        bufB[x + i * chw] = gamma * (inputData[i] - fmean) * frstd + beta;
    }
}
using namespace std;
int main() {
float* hostA;
float* hostB;
float* deviceA;
float* deviceB;
hipDeviceProp_t devProp;
hipGetDeviceProperties(&devProp, 0);
cout << " System minor " << devProp.minor << endl;
cout << " System major " << devProp.major << endl;
cout << " agent prop name " << devProp.name << endl;
cout << "hip Device prop succeeded " << endl;
hipEvent_t start, stop;
hipEventCreate(&start);
hipEventCreate(&stop);
float eventMs = 1.0f;
int i;
int errors;
hostA = (float*)malloc(NUM * sizeof(float));
hostB = (float*)malloc(NUM * sizeof(float));
float* p;
p = hostA;
for (int i = 0; i < NUM; i++)
{
p[i] = float(sinf(i));
}
HIP_ASSERT(hipMalloc((void**)& deviceA, NUM * sizeof(float)));
HIP_ASSERT(hipMalloc((void**)& deviceB, NUM * sizeof(float)));
HIP_ASSERT(hipMemcpy(deviceA, hostA, NUM * sizeof(float), hipMemcpyHostToDevice));
printf("A\n");
hipLaunchKernel(test_kernel,
dim3(1, 1, 1),
dim3(256, 1, 1),
0, 0,
deviceA, deviceB, 128, 128, 1.0, 0.0);
printf("B\n");
{
hipEventRecord(start, NULL);
hipLaunchKernel(test_kernel,
dim3(C*H*W / 256 *4, 1,1),
dim3(256, 1, 1),
0, 0,
deviceA, deviceB, N, C*H*W, 1.0f, 1.0f);
hipEventRecord(stop, NULL);
hipEventSynchronize(stop);
hipEventElapsedTime(&eventMs, start, stop);
printf("elapsed time:%f\n", eventMs);
double bandwidth = (double)N * (double)C * (double)H * (double)W / (eventMs / 1000.0)/1000/1000/1000;
printf("Estimated Bandwidth %d GPixels/s\n", (int)bandwidth);
}
HIP_ASSERT(hipFree(deviceA));
HIP_ASSERT(hipFree(deviceB));
free(hostA);
free(hostB);
return errors;
}