// yololayer.cu
// Forked from wang-xinyu/tensorrtx.
#include <assert.h>
#include <string.h>
#include <iostream>
#include <vector>
#include "yololayer.h"
#include "cuda_utils.h"
namespace Tn
{
    // Copies the object representation of `val` into `buffer` and advances
    // `buffer` by sizeof(T).
    // The bytes are moved with memcpy rather than stored through a casted T*,
    // so the write is well defined even when `buffer` is not aligned for T.
    // T is expected to be trivially copyable (the file only uses it with
    // scalar members).
    template<typename T>
    void write(char*& buffer, const T& val)
    {
        memcpy(buffer, &val, sizeof(T));
        buffer += sizeof(T);
    }

    // Fills `val` from the next sizeof(T) bytes at `buffer` and advances
    // `buffer` past them. Uses memcpy for the same alignment reason as write().
    template<typename T>
    void read(const char*& buffer, T& val)
    {
        memcpy(&val, buffer, sizeof(T));
        buffer += sizeof(T);
    }
}
using namespace Yolo;
namespace nvinfer1
{
// Builds a plugin from explicit parameters and uploads every kernel's anchor
// table to the device.
//   classCount          -> mClassCount
//   netWidth, netHeight -> mYoloV5NetWidth, mYoloV5NetHeight
//   maxOut              -> mMaxOutObject
//   vYoloKernel         -> copied into mYoloKernel; one device anchor buffer is
//                          created per entry and its pointer stored in mAnchor.
// NOTE(review): mThreadCount is not assigned here; presumably it has a default
// in the class declaration - confirm in yololayer.h.
YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector<Yolo::YoloKernel>& vYoloKernel)
{
    mYoloKernel = vYoloKernel;
    mKernelCount = vYoloKernel.size();
    mClassCount = classCount;
    mMaxOutObject = maxOut;
    mYoloV5NetWidth = netWidth;
    mYoloV5NetHeight = netHeight;

    // Host-side table of device pointers, one slot per kernel.
    CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));

    // Each anchor table holds CHECK_COUNT (w, h) pairs of floats.
    const size_t anchorBytes = sizeof(float) * CHECK_COUNT * 2;
    for (int k = 0; k < mKernelCount; ++k)
    {
        CUDA_CHECK(cudaMalloc(&mAnchor[k], anchorBytes));
        CUDA_CHECK(cudaMemcpy(mAnchor[k], mYoloKernel[k].anchors, anchorBytes, cudaMemcpyHostToDevice));
    }
}
// Frees each per-kernel device anchor buffer, then the host table that held
// their pointers.
YoloLayerPlugin::~YoloLayerPlugin()
{
    int k = 0;
    while (k < mKernelCount)
    {
        CUDA_CHECK(cudaFree(mAnchor[k]));
        ++k;
    }
    CUDA_CHECK(cudaFreeHost(mAnchor));
}
// Rebuilds a plugin from the byte stream produced by serialize().
//   data   - start of the serialized blob
//   length - expected blob size; checked with assert after parsing
// Fields are read in the same order serialize() writes them, followed by
// mKernelCount raw YoloKernel records. The anchor tables are then uploaded to
// the device exactly as in the parameter constructor.
YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length)
{
    using namespace Tn;
    const char* const begin = reinterpret_cast<const char*>(data);
    const char* cursor = begin;

    // Fixed-size header.
    read(cursor, mClassCount);
    read(cursor, mThreadCount);
    read(cursor, mKernelCount);
    read(cursor, mYoloV5NetWidth);
    read(cursor, mYoloV5NetHeight);
    read(cursor, mMaxOutObject);

    // Variable-size tail: the kernel descriptors, copied as raw bytes.
    mYoloKernel.resize(mKernelCount);
    const auto kernelBytes = mKernelCount * sizeof(YoloKernel);
    memcpy(mYoloKernel.data(), cursor, kernelBytes);
    cursor += kernelBytes;

    // Host-side table of device pointers, one slot per kernel.
    CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
    const size_t anchorBytes = sizeof(float) * CHECK_COUNT * 2;
    for (int k = 0; k < mKernelCount; ++k)
    {
        CUDA_CHECK(cudaMalloc(&mAnchor[k], anchorBytes));
        CUDA_CHECK(cudaMemcpy(mAnchor[k], mYoloKernel[k].anchors, anchorBytes, cudaMemcpyHostToDevice));
    }

    // The whole blob must have been consumed.
    assert(cursor == begin + length);
}
// Writes the plugin state into `buffer`: six scalar members followed by the
// raw bytes of every YoloKernel. The field order must stay in sync with the
// deserializing constructor. The number of bytes written is asserted to equal
// getSerializationSize().
void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT
{
    using namespace Tn;
    char* const begin = static_cast<char*>(buffer);
    char* cursor = begin;

    // Fixed-size header.
    write(cursor, mClassCount);
    write(cursor, mThreadCount);
    write(cursor, mKernelCount);
    write(cursor, mYoloV5NetWidth);
    write(cursor, mYoloV5NetHeight);
    write(cursor, mMaxOutObject);

    // Kernel descriptors, copied as raw bytes.
    const auto kernelBytes = mKernelCount * sizeof(YoloKernel);
    memcpy(cursor, mYoloKernel.data(), kernelBytes);
    cursor += kernelBytes;

    assert(cursor == begin + getSerializationSize());
}
// Returns the number of bytes serialize() writes: the six scalar members plus
// one YoloKernel record per entry of mYoloKernel.
size_t YoloLayerPlugin::getSerializationSize() const TRT_NOEXCEPT
{
    const size_t headerBytes = sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount)
                             + sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight) + sizeof(mMaxOutObject);
    const size_t kernelBytes = sizeof(Yolo::YoloKernel) * mYoloKernel.size();
    return headerBytes + kernelBytes;
}
// Performs no work and always returns 0; device resources are set up in the
// constructors instead.
int YoloLayerPlugin::initialize() TRT_NOEXCEPT
{
    return 0;
}
// Reports the output shape as (N + 1, 1, 1), where N is the number of floats
// needed to hold mMaxOutObject Detection records and the extra element is the
// leading detection counter. All arguments are ignored.
Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT
{
    const int detectionFloats = mMaxOutObject * sizeof(Detection) / sizeof(float);
    return Dims3(detectionFloats + 1, 1, 1);
}
// Stores the given namespace pointer in mPluginNamespace.
// NOTE(review): only the pointer is assigned here; if mPluginNamespace is a
// raw const char*, the caller's string must outlive this plugin - confirm
// against the member's declaration.
void YoloLayerPlugin::setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT
{
    mPluginNamespace = pluginNamespace;
}
// Returns the namespace previously stored by setPluginNamespace().
const char* YoloLayerPlugin::getPluginNamespace() const TRT_NOEXCEPT
{
    return mPluginNamespace;
}
// Output element type: always DataType::kFLOAT, regardless of the requested
// index or the input types.
DataType YoloLayerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT
{
    return DataType::kFLOAT;
}
// Reports whether an output tensor is broadcast across the batch.
// Always false; the arguments are ignored.
bool YoloLayerPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT
{
    return false;
}
// Reports whether the plugin can consume an input that is broadcast across
// the batch without replication. Always false; inputIndex is ignored.
bool YoloLayerPlugin::canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT
{
    return false;
}
// Intentionally a no-op: the plugin does not adapt to the tensor descriptions
// passed in; all four arguments are ignored.
void YoloLayerPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT
{
}
// Attach the plugin object to an execution context and grant the plugin access to some context resources.
// Intentionally a no-op: none of the offered handles (cuDNN, cuBLAS, GPU allocator) are used by this plugin.
void YoloLayerPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT
{
}
// Detach the plugin object from its execution context.
// Intentionally a no-op, mirroring attachToContext(), which acquires nothing.
void YoloLayerPlugin::detachFromContext() TRT_NOEXCEPT {}
// Plugin type identifier. The same string is returned by
// YoloPluginCreator::getPluginName().
const char* YoloLayerPlugin::getPluginType() const TRT_NOEXCEPT
{
    return "YoloLayer_TRT";
}
// Plugin version string. The same string is returned by
// YoloPluginCreator::getPluginVersion().
const char* YoloLayerPlugin::getPluginVersion() const TRT_NOEXCEPT
{
    return "1";
}
// Deletes this plugin object. The object must have been allocated with `new`
// (as clone(), createPlugin() and deserializePlugin() do) and must not be
// used after this call.
void YoloLayerPlugin::destroy() TRT_NOEXCEPT
{
    delete this;
}
// Creates a heap-allocated copy of this plugin via the parameter constructor
// (so the copy gets its own device anchor buffers) and gives it the same
// namespace pointer.
// NOTE(review): mThreadCount is not forwarded to the copy - confirm that the
// class declaration provides a suitable default.
IPluginV2IOExt* YoloLayerPlugin::clone() const TRT_NOEXCEPT
{
    YoloLayerPlugin* copy = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, mYoloKernel);
    copy->setPluginNamespace(mPluginNamespace);
    return copy;
}
// Logistic sigmoid: 1 / (1 + e^(-data)), evaluated in single precision.
__device__ float Logist(float data)
{
    const float denominator = 1.0f + expf(-data);
    return 1.0f / denominator;
}
// Decodes one YOLO output scale into Detection records.
//
// Launch layout: 1D grid and block; each thread with a global index below
// `noElements` handles one (batch item, grid cell) pair and loops over the
// CHECK_COUNT anchors of that cell.
//
// Input layout, as indexed below: for batch item b, anchor k, channel c and
// cell i the value lives at
//   input[b * (5 + classes) * W * H * CHECK_COUNT + k * (5 + classes) * W * H + c * W * H + i]
// with W = yoloWidth and H = yoloHeight. Channels 0..3 are the box terms,
// channel 4 the objectness logit and channels 5.. the class logits.
//
// Output layout: each batch item owns `outputElem` floats. The first float is a
// detection counter bumped with atomicAdd; Detection records follow it. The
// counter is incremented before the capacity check, so it can end up larger
// than `maxoutobject`; only the first `maxoutobject` records are written.
//
//   anchors - CHECK_COUNT (width, height) pairs for this scale
//   netwidth / netheight - used to scale cell coordinates to network input size
__global__ void CalDetection(const float *input, float *output, int noElements,
const int netwidth, const int netheight, int maxoutobject, int yoloWidth, int yoloHeight, const float anchors[CHECK_COUNT * 2], int classes, int outputElem)
{
    const int globalIdx = threadIdx.x + blockDim.x * blockIdx.x;
    if (globalIdx >= noElements) return;

    // Split the flat index into (batch item, cell within the feature map).
    const int total_grid = yoloWidth * yoloHeight;
    const int bnIdx = globalIdx / total_grid;
    const int cellIdx = globalIdx - total_grid * bnIdx;
    const int row = cellIdx / yoloWidth;
    const int col = cellIdx % yoloWidth;

    const int info_len_i = 5 + classes;
    const float* batchInput = input + bnIdx * (info_len_i * total_grid * CHECK_COUNT);
    float* res_count = output + bnIdx * outputElem;

    for (int k = 0; k < CHECK_COUNT; ++k) {
        // Base of this anchor's channels for the current cell; channel c is at
        // anchorInput[c * total_grid].
        const float* anchorInput = batchInput + cellIdx + k * info_len_i * total_grid;

        const float box_prob = Logist(anchorInput[4 * total_grid]);
        if (box_prob < IGNORE_THRESH) continue;

        // Arg-max over the class scores.
        int class_id = 0;
        float max_cls_prob = 0.0f;
        for (int c = 0; c < classes; ++c) {
            const float p = Logist(anchorInput[(5 + c) * total_grid]);
            if (p > max_cls_prob) {
                max_cls_prob = p;
                class_id = c;
            }
        }

        // Reserve an output slot; stop this thread once the buffer is full.
        const int count = (int)atomicAdd(res_count, 1);
        if (count >= maxoutobject) return;
        char* slot = (char*)res_count + sizeof(float) + count * sizeof(Detection);
        Detection* det = (Detection*)(slot);

        // Box decoding, mirroring the PyTorch reference:
        //   y = x[i].sigmoid()
        //   y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
        //   y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]              # wh
        // v5: https://github.com/ultralytics/yolov5/issues/471
        det->bbox[0] = (col - 0.5f + 2.0f * Logist(anchorInput[0 * total_grid])) * netwidth / yoloWidth;
        det->bbox[1] = (row - 0.5f + 2.0f * Logist(anchorInput[1 * total_grid])) * netheight / yoloHeight;

        const float wTerm = 2.0f * Logist(anchorInput[2 * total_grid]);
        det->bbox[2] = wTerm * wTerm * anchors[2 * k];
        const float hTerm = 2.0f * Logist(anchorInput[3 * total_grid]);
        det->bbox[3] = hTerm * hTerm * anchors[2 * k + 1];

        det->conf = box_prob * max_cls_prob;
        det->class_id = class_id;
    }
}
void YoloLayerPlugin::forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize)
{
int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float);
for (int idx = 0; idx < batchSize; ++idx) {
CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream));
}
int numElem = 0;
for (unsigned int i = 0; i < mYoloKernel.size(); ++i) {
const auto& yolo = mYoloKernel[i];
numElem = yolo.width * yolo.height * batchSize;
if (numElem < mThreadCount) mThreadCount = numElem;
//printf("Net: %d %d \n", mYoloV5NetWidth, mYoloV5NetHeight);
CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> >
(inputs[i], output, numElem, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, yolo.width, yolo.height, (float*)mAnchor[i], mClassCount, outputElem);
}
}
// Entry point for execution: forwards the inputs and the first output buffer,
// both treated as float data, to forwardGpu() on `stream`. `workspace` is
// unused. Always returns 0.
int YoloLayerPlugin::enqueue(int batchSize, const void* const* inputs, void* TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT
{
    const float* const* floatInputs = reinterpret_cast<const float* const*>(inputs);
    float* floatOutput = static_cast<float*>(outputs[0]);
    forwardGpu(floatInputs, floatOutput, stream, batchSize);
    return 0;
}
// Storage for the creator's static members: the (value-initialized) field
// collection returned by getFieldNames(), and the attribute list it points into.
PluginFieldCollection YoloPluginCreator::mFC{};
std::vector<PluginField> YoloPluginCreator::mPluginAttributes;
// Empties the shared attribute list and points the static field collection at
// it, so the collection advertises zero fields.
YoloPluginCreator::YoloPluginCreator()
{
    mPluginAttributes.clear();
    mFC.fields = mPluginAttributes.data();
    mFC.nbFields = mPluginAttributes.size();
}
// Name under which this creator is known. The same string is returned by
// YoloLayerPlugin::getPluginType().
const char* YoloPluginCreator::getPluginName() const TRT_NOEXCEPT
{
    return "YoloLayer_TRT";
}
// Version string of the plugin this creator produces. The same string is
// returned by YoloLayerPlugin::getPluginVersion().
const char* YoloPluginCreator::getPluginVersion() const TRT_NOEXCEPT
{
    return "1";
}
// Returns the address of the static field collection set up by the constructor.
const PluginFieldCollection* YoloPluginCreator::getFieldNames() TRT_NOEXCEPT
{
    return &mFC;
}
// Creates a YoloLayerPlugin from a two-entry field collection.
//   name - unused
//   fc->fields[0], named "netinfo": its data is read as four ints,
//       { class count, input width, input height, max output object count }
//   fc->fields[1], named "kernels": its data is read as `length` contiguous
//       Yolo::YoloKernel records
// Returns the new plugin with this creator's namespace applied.
//
// Debug builds still assert on a malformed collection. When asserts are
// compiled out, the function returns nullptr instead of dereferencing missing
// fields or sizing the kernel vector from a negative length.
IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT
{
    const bool hasFields = fc != nullptr && fc->fields != nullptr && fc->nbFields == 2;
    assert(hasFields);
    if (!hasFields) return nullptr;

    const PluginField& netinfoField = fc->fields[0];
    const PluginField& kernelField = fc->fields[1];

    const bool namesMatch = netinfoField.name != nullptr && kernelField.name != nullptr
        && strcmp(netinfoField.name, "netinfo") == 0
        && strcmp(kernelField.name, "kernels") == 0;
    assert(namesMatch);
    if (!namesMatch) return nullptr;

    const bool hasData = netinfoField.data != nullptr
        && kernelField.length >= 0
        && (kernelField.length == 0 || kernelField.data != nullptr);
    assert(hasData);
    if (!hasData) return nullptr;

    // The field data is const; read it without casting the qualifier away.
    const int* p_netinfo = static_cast<const int*>(netinfoField.data);
    const int class_count = p_netinfo[0];
    const int input_w = p_netinfo[1];
    const int input_h = p_netinfo[2];
    const int max_output_object_count = p_netinfo[3];

    std::vector<Yolo::YoloKernel> kernels(kernelField.length);
    if (!kernels.empty()) {
        // data() avoids indexing element 0 of a possibly empty vector.
        memcpy(kernels.data(), kernelField.data, kernels.size() * sizeof(Yolo::YoloKernel));
    }

    YoloLayerPlugin* obj = new YoloLayerPlugin(class_count, input_w, input_h, max_output_object_count, kernels);
    obj->setPluginNamespace(mNamespace.c_str());
    return obj;
}
// Reconstructs a plugin from a serialized blob and applies this creator's
// namespace to it. `name` is unused.
// The returned object is heap-allocated; it is expected to be released through
// YoloLayerPlugin::destroy() when the owning network is torn down.
IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT
{
    YoloLayerPlugin* plugin = new YoloLayerPlugin(serialData, serialLength);
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}
}