-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathmxnet.h
154 lines (149 loc) · 5.86 KB
/
mxnet.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
/*!
* Copyright (c) 2015 by Contributors
* \file c_predict_api.h
* \brief C predict API of mxnet, contains a minimum API to run prediction.
* This file is self-contained and does not depend on any other MXNet headers.
*/
#ifndef MXNET_C_PREDICT_API_H_
#define MXNET_C_PREDICT_API_H_
/*
 * size_t is used by the byte-size parameters of the prototypes declared in
 * this header, so its definition is pulled in here rather than relying on the
 * includer to have included a suitable header first.
 */
#include <stddef.h>

/*! \brief Linkage specifier: extern "C" under a C++ compiler, empty in C. */
#if defined(__cplusplus)
#  define MXNET_EXTERN_C extern "C"
#else
#  define MXNET_EXTERN_C
#endif

/*!
 * \brief Prefix applied to every exported declaration.
 *  On _WIN32 it expands to the linkage specifier plus __declspec(dllexport)
 *  when MXNET_EXPORTS is defined and __declspec(dllimport) otherwise; on
 *  other platforms it is the linkage specifier alone.
 */
#if defined(_WIN32)
#  if defined(MXNET_EXPORTS)
#    define MXNET_DLL MXNET_EXTERN_C __declspec(dllexport)
#  else
#    define MXNET_DLL MXNET_EXTERN_C __declspec(dllimport)
#  endif
#else
#  define MXNET_DLL MXNET_EXTERN_C
#endif
/*!
 * \brief Unsigned integer type of this API; the declarations in this header
 *  use it for counts, indices, sizes and shape entries.
 */
typedef unsigned int mx_uint;
/*!
 * \brief Floating point type of this API; the declarations in this header
 *  use it for input and output data.
 */
typedef float mx_float;
/*!
 * \brief Untyped handle referring to a Predictor.
 */
typedef void *PredictorHandle;
/*!
 * \brief Untyped handle referring to an NDArray list.
 */
typedef void *NDListHandle;
/*!
 * \brief Get the last error that happened.
 *
 * Declared with (void) so that C compilers see a real prototype; an empty
 * parameter list would leave the arguments unspecified and unchecked in C.
 *
 * \return The last error that happened at the predictor.
 */
MXNET_DLL const char *MXGetLastError(void);
/*!
 * \brief Create a predictor.
 *
 * \param symbol_json_str    JSON string of the symbol.
 * \param param_bytes        In-memory raw bytes of the parameter ndarray file.
 * \param param_size         Size of the parameter ndarray file.
 * \param dev_type           Device type: 1 for cpu, 2 for gpu.
 * \param dev_id             Device id of the predictor.
 * \param num_input_nodes    Number of input nodes to the net.
 *                           For a feedforward net, this is 1.
 * \param input_keys         Names of the input arguments.
 *                           For a feedforward net, this is {"data"}.
 * \param input_shape_indptr Index pointer of the shapes of each input node.
 *                           The array has num_input_nodes + 1 entries.
 *                           For a feedforward net that takes 4 dimensional
 *                           input, this is {0, 4}.
 * \param input_shape_data   Flattened shape data of each input node.
 *                           For a feedforward net that takes 4 dimensional
 *                           input, this is the shape data.
 * \param out                Receives the created predictor handle.
 * \return 0 on success, -1 on failure.
 */
MXNET_DLL int MXPredCreate(const char *symbol_json_str,
                           const char *param_bytes,
                           size_t param_size,
                           int dev_type,
                           int dev_id,
                           mx_uint num_input_nodes,
                           const char **input_keys,
                           const mx_uint *input_shape_indptr,
                           const mx_uint *input_shape_data,
                           PredictorHandle *out);
/*!
 * \brief Query the shape of an output node.
 *
 * The returned shape_data and shape_ndim are only valid until the next call
 * to an MXPred function.
 *
 * \param handle     Handle of the predictor.
 * \param index      Index of the output node; set to 0 if there is only one
 *                   output.
 * \param shape_data Receives a pointer to the shape data.
 * \param shape_ndim Receives the shape dimension.
 * \return 0 on success, -1 on failure.
 */
MXNET_DLL int MXPredGetOutputShape(PredictorHandle handle,
                                   mx_uint index,
                                   mx_uint **shape_data,
                                   mx_uint *shape_ndim);
/*!
 * \brief Set the input data of a predictor.
 *
 * \param handle Handle of the predictor.
 * \param key    Name of the input node to set.
 *               For a feedforward net, this is "data".
 * \param data   Pointer to the data to be set, with the shape specified in
 *               MXPredCreate.
 * \param size   Size of the data array, used for a safety check.
 * \return 0 on success, -1 on failure.
 */
MXNET_DLL int MXPredSetInput(PredictorHandle handle,
                             const char *key,
                             const mx_float *data,
                             mx_uint size);
/*!
 * \brief Run a forward pass to get the output.
 * \param handle The handle of the predictor.
 * \return 0 on success, -1 on failure.
 */
MXNET_DLL int MXPredForward(PredictorHandle handle);
/*!
 * \brief Retrieve the output value of a prediction.
 *
 * \param handle Handle of the predictor.
 * \param index  Index of the output node; set to 0 if there is only one
 *               output.
 * \param data   User-allocated buffer that holds the output.
 * \param size   Size of the data array, used for a safety check.
 * \return 0 on success, -1 on failure.
 */
MXNET_DLL int MXPredGetOutput(PredictorHandle handle,
                              mx_uint index,
                              mx_float *data,
                              mx_uint size);
/*!
 * \brief Free a predictor handle.
 * \param handle The handle of the predictor to free.
 * \return 0 on success, -1 on failure.
 */
MXNET_DLL int MXPredFree(PredictorHandle handle);
/*!
 * \brief Create an NDArray list by loading it from the contents of an
 *  ndarray file.
 *
 * This can be used to load a mean image file.
 *
 * \param nd_file_bytes Byte contents of the nd file to be loaded.
 * \param nd_file_size  Size of the nd file to be loaded.
 * \param out           Receives the output NDListHandle.
 * \param out_length    Receives the length of the list.
 * \return 0 on success, -1 on failure.
 */
MXNET_DLL int MXNDListCreate(const char *nd_file_bytes,
                             size_t nd_file_size,
                             NDListHandle *out,
                             mx_uint *out_length);
/*!
 * \brief Fetch one element from an NDArray list.
 *
 * \param handle    Handle to the NDArray list.
 * \param index     Index of the element in the list.
 * \param out_key   Receives the key of the item.
 * \param out_data  Receives the data region of the item.
 * \param out_shape Receives the shape of the item.
 * \param out_ndim  Receives the number of dimensions in the shape.
 * \return 0 on success, -1 on failure.
 */
MXNET_DLL int MXNDListGet(NDListHandle handle,
                          mx_uint index,
                          const char **out_key,
                          const mx_float **out_data,
                          const mx_uint **out_shape,
                          mx_uint *out_ndim);
/*!
 * \brief Free an NDArray list handle.
 * \param handle The handle of the NDArray list to free.
 * \return 0 on success, -1 on failure.
 */
MXNET_DLL int MXNDListFree(NDListHandle handle);
#endif // MXNET_C_PREDICT_API_H_