-
Notifications
You must be signed in to change notification settings - Fork 14
/
hadamard.py
140 lines (108 loc) · 4.94 KB
/
hadamard.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#!/usr/bin/env python
import collections
import cupy
import os
import torch
##########################################################
def cuda_int32(intIn:int):
    """Convert ``intIn`` into a 32-bit integer scalar (``cupy.int32``).

    Presumably used to pass ``int`` arguments to raw CUDA kernels; no caller
    in this view relies on it directly.
    """
    intConverted = cupy.int32(intIn)
    return intConverted
# end
def cuda_float32(fltIn:float):
    """Convert ``fltIn`` into a single-precision scalar (``cupy.float32``).

    Presumably used to pass ``float`` arguments to raw CUDA kernels; no caller
    in this view relies on it directly.
    """
    fltConverted = cupy.float32(fltIn)
    return fltConverted
# end
@cupy.memoize(for_each_device=True)
def cuda_launch(strFunction:str, strKernel:str):
    """Wrap the CUDA source ``strKernel`` as a ``cupy.RawKernel`` named ``strFunction``.

    Results are memoized by ``cupy.memoize(for_each_device=True)``, so repeated
    calls with the same arguments on the same device reuse one kernel object.

    Args:
        strFunction: name of the ``extern "C"`` entry point inside ``strKernel``.
        strKernel: CUDA C++ source code of the kernel.

    Returns:
        A ``cupy.RawKernel`` that can be launched with ``grid=``, ``block=``,
        ``args=`` and ``stream=`` keywords.
    """
    if 'CUDA_HOME' not in os.environ:
        strCudaPath = cupy.cuda.get_cuda_path()

        # get_cuda_path() returns None when no CUDA installation is found, and
        # os.environ only accepts str values; assigning None would raise TypeError
        # and make every kernel launch fail.
        if strCudaPath is not None:
            os.environ['CUDA_HOME'] = strCudaPath
        # end
    # end

    return cupy.RawKernel(strKernel, strFunction)
# end
# Threads per block for every hadamard_* launch; must not exceed the kernels' __launch_bounds__(512).
_HADAMARD_THREADS = 512

# cupy.RawKernel takes a stream object exposing `.ptr`; wrapping PyTorch's raw stream
# handle this way orders the launches with PyTorch's work on its current stream.
_HadamardStream = collections.namedtuple('Stream', 'ptr')

_HADAMARD_OUT_KERNEL = '''
    extern "C" __global__ void __launch_bounds__(512) hadamard_out(
        const int n,
        const float* __restrict__ tenOne,
        const float* __restrict__ tenTwo,
        float* __restrict__ tenOut
    ) {
        // 64-bit index: with n close to INT_MAX the last block would otherwise wrap
        // to a negative int, pass the bounds check, and write out of bounds.
        const long long intIndex = ((long long) blockIdx.x * blockDim.x) + threadIdx.x;

        if (intIndex >= n) {
            return;
        }

        tenOut[intIndex] = tenOne[intIndex] * tenTwo[intIndex];
    }
'''

_HADAMARD_GRAD_KERNEL = '''
    extern "C" __global__ void __launch_bounds__(512) hadamard_grad(
        const int n,
        const float* __restrict__ tenOne,
        const float* __restrict__ tenTwo,
        const float* __restrict__ tenOutgrad,
        float* __restrict__ tenOnegrad,
        float* __restrict__ tenTwograd
    ) {
        const long long intIndex = ((long long) blockIdx.x * blockDim.x) + threadIdx.x;

        if (intIndex >= n) {
            return;
        }

        const float fltOutgrad = tenOutgrad[intIndex];

        tenOnegrad[intIndex] = tenTwo[intIndex] * fltOutgrad;
        tenTwograd[intIndex] = tenOne[intIndex] * fltOutgrad;
    }
'''


def _hadamard_launch(strFunction, strKernel, intElements, objPointers):
    # Launch one thread per element of a hadamard_* kernel on PyTorch's current CUDA stream.
    if intElements == 0:
        return  # a zero-sized grid is an invalid launch configuration, and there is nothing to compute
    # end

    if intElements > 2147483647:
        raise ValueError(f'hadamard_func supports at most 2**31 - 1 elements, got {intElements}')
    # end

    cuda_launch(strFunction, strKernel)(
        grid=tuple([(intElements + _HADAMARD_THREADS - 1) // _HADAMARD_THREADS, 1, 1]),
        block=tuple([_HADAMARD_THREADS, 1, 1]),
        args=[cuda_int32(intElements)] + list(objPointers),
        stream=_HadamardStream(torch.cuda.current_stream().cuda_stream),
    )
# end


class hadamard_func(torch.autograd.Function):
    """Element-wise (Hadamard) product of two equally shaped float32 CUDA tensors.

    The forward and backward passes run custom CuPy kernels. Call it as
    ``hadamard_func.apply(tenOne, tenTwo)``.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, tenOne, tenTwo):
        """Return ``tenOne * tenTwo`` computed element-wise by a CUDA kernel.

        Args:
            self: the autograd context (named ``self`` by this file's convention).
            tenOne: CUDA float32 tensor of any shape. When autocast is enabled,
                the ``custom_fwd`` decorator casts floating inputs to float32 first.
            tenTwo: CUDA float32 tensor with exactly the same shape as ``tenOne``.

        Returns:
            A new contiguous float32 tensor with the common shape.

        Raises:
            ValueError: if the shapes differ, or there are more than 2**31 - 1 elements.
            AssertionError: if either input is not a CUDA tensor.
            TypeError: if either input is not float32. Nothing casts them outside
                autocast, and the kernel would reinterpret their memory as float.
        """
        # Checked first: the kernel indexes both inputs with tenOne's element count,
        # so a smaller tenTwo would be read out of bounds.
        if tenOne.shape != tenTwo.shape:
            raise ValueError(f'hadamard_func expects equal shapes, got {tuple(tenOne.shape)} and {tuple(tenTwo.shape)}')
        # end

        tenOne = tenOne.contiguous()
        tenTwo = tenTwo.contiguous()

        assert(tenOne.is_cuda == True)
        assert(tenTwo.is_cuda == True)

        if not (tenOne.is_cuda and tenTwo.is_cuda):
            raise NotImplementedError()  # only reachable when asserts are stripped (python -O)
        # end

        if tenOne.dtype != torch.float32 or tenTwo.dtype != torch.float32:
            raise TypeError(f'hadamard_func expects float32 tensors, got {tenOne.dtype} and {tenTwo.dtype}')
        # end

        # Every element is written by the kernel, so no zero-initialisation is needed.
        tenOut = tenOne.new_empty(tenOne.shape)

        _hadamard_launch('hadamard_out', _HADAMARD_OUT_KERNEL, tenOut.nelement(), [
            tenOne.data_ptr(), tenTwo.data_ptr(), tenOut.data_ptr()
        ])

        self.save_for_backward(tenOne, tenTwo)

        return tenOut
    # end

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(self, tenOutgrad):
        """Return the input gradients ``(tenTwo * tenOutgrad, tenOne * tenOutgrad)``.

        ``custom_bwd`` pairs with the ``custom_fwd`` on ``forward``, so backward
        runs with the autocast state that forward used. Both gradients are
        produced by a single kernel launch.
        """
        tenOne, tenTwo = self.saved_tensors

        # The kernel reads float32; .float() returns the tensor unchanged when it already is float32.
        tenOutgrad = tenOutgrad.float().contiguous()

        assert(tenOutgrad.is_cuda == True)

        tenOnegrad = tenOne.new_empty(tenOne.shape)
        tenTwograd = tenTwo.new_empty(tenTwo.shape)

        _hadamard_launch('hadamard_grad', _HADAMARD_GRAD_KERNEL, tenOnegrad.nelement(), [
            tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()
        ])

        return tenOnegrad, tenTwograd
    # end
# end