// h_sigmoidplugin.h
// Forked from wang-xinyu/tensorrtx.
#ifndef HSIGMOID_PLUGIN_H
#define HSIGMOID_PLUGIN_H
#include <NvInfer.h>
namespace nvinfer1 {
/// Plugin layer built on the nvinfer1::IPlugin interface.
///
/// Only declarations live in this header; every member function is defined
/// elsewhere. Judging by the class name this implements a hard-sigmoid
/// activation, but the kernel is not visible here -- confirm against the
/// implementation file.
class HSigmoidPlugin : public IPlugin {
public:
    /// Default constructor (defined out of line).
    HSigmoidPlugin();

    /// Constructs the plugin from a raw byte buffer.
    ///
    /// @param buffer Pointer to the bytes to read from.
    /// @param size   Length of @p buffer in bytes.
    ///
    /// NOTE(review): presumably the inverse of serialize(); verify that the
    /// layout read here matches what serialize() writes.
    HSigmoidPlugin(const void* buffer, size_t size);

    ~HSigmoidPlugin() override = default;

    /// @return Number of output tensors produced by this layer.
    int getNbOutputs() const override;

    /// Reports the dimensions of output @p index for the given input dimensions.
    ///
    /// @param index       Index of the output being queried.
    /// @param inputs      Array of input tensor dimensions.
    /// @param nbInputDims Number of entries in @p inputs.
    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;

    /// Receives the input/output dimensions and maximum batch size.
    ///
    /// @param inputDims    Array of input tensor dimensions.
    /// @param nbInputs     Number of entries in @p inputDims.
    /// @param outputDims   Array of output tensor dimensions.
    /// @param nbOutputs    Number of entries in @p outputDims.
    /// @param maxBatchSize Maximum batch size.
    void configure(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs,
                   int maxBatchSize) override;

    /// @return Status code; by IPlugin convention 0 means success (confirm
    ///         against the TensorRT version in use).
    int initialize() override;

    /// Counterpart of initialize().
    void terminate() override;

    /// @param maxBatchSize Maximum batch size.
    /// @return Scratch workspace size in bytes required by enqueue().
    size_t getWorkspaceSize(int maxBatchSize) const override;

    /// Runs the layer.
    ///
    /// @param batchSize Number of samples in this invocation.
    /// @param inputs    Array of input buffer pointers.
    /// @param outputs   Array of output buffer pointers.
    /// @param workspace Scratch buffer.
    /// @param stream    CUDA stream passed in by the caller.
    /// @return Status code; by IPlugin convention 0 means success (confirm).
    int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace,
                cudaStream_t stream) override;

    /// @return Number of bytes serialize() needs in its destination buffer.
    size_t getSerializationSize() override;

    /// Writes the plugin state into @p buffer.
    ///
    /// @param buffer Destination buffer.
    void serialize(void* buffer) override;

private:
    // Default-initialized so the member never holds an indeterminate value,
    // even if an out-of-line constructor does not assign it. Any constructor
    // that does assign it still takes precedence.
    // NOTE(review): presumably the element count of the input tensor -- confirm
    // against configure()/getOutputDimensions() in the implementation.
    int input_size_{0};
};
/// Plugin factory built on the nvinfer1::IPluginFactory interface.
///
/// Only the declaration lives in this header; which concrete plugin types the
/// factory produces is decided by the out-of-line definition (presumably
/// HSigmoidPlugin -- confirm against the implementation file).
class PluginFactory : public IPluginFactory {
public:
    /// Creates a plugin instance from serialized data.
    ///
    /// @param layerName    Name of the layer the plugin is created for.
    /// @param serialData   Pointer to the serialized plugin bytes.
    /// @param serialLength Length of @p serialData in bytes.
    /// @return Pointer to the created plugin. Ownership and lifetime are not
    ///         expressed by the type -- check the definition before deleting.
    IPlugin* createPlugin(
        const char* layerName,
        const void* serialData,
        size_t serialLength) override;
};
}
#endif