-
Notifications
You must be signed in to change notification settings - Fork 5
/
approx_kfn.go
228 lines (190 loc) · 7.67 KB
/
approx_kfn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
package mlpack
/*
#cgo CFLAGS: -I./capi -Wall
#cgo LDFLAGS: -L. -lmlpack_go_approx_kfn
#include <capi/approx_kfn.h>
#include <stdlib.h>
*/
import "C"
import "gonum.org/v1/gonum/mat"
// ApproxKfnOptionalParam holds the optional inputs accepted by ApproxKfn.
// Obtain an instance pre-populated with the default values from
// ApproxKfnOptions and then override only the fields you need.
type ApproxKfnOptionalParam struct {
// Algorithm selects the search strategy: "ds" (DrusillaSelect) or "qdafn".
// Default "ds".
Algorithm string
// CalculateError, if set, calculates the average distance error for the
// first furthest neighbor only.
CalculateError bool
// ExactDistances optionally holds exact distances to the furthest neighbors;
// it can be used to avoid explicit calculation when CalculateError is set.
ExactDistances *mat.Dense
// InputModel is a previously built model that may be supplied instead of a
// reference set.
InputModel *approxkfnModel
// K is the number of furthest neighbors to search for. Default 0.
K int
// NumProjections is the number of projections to use in each hash table.
// Default 5.
NumProjections int
// NumTables is the number of hash tables to use. Default 5.
NumTables int
// Query holds the query points. When no query set is given, the reference
// set is used as the query set.
Query *mat.Dense
// Reference holds the reference dataset (the set to search in).
Reference *mat.Dense
// Verbose displays informational messages and the full list of parameters
// and timers at the end of execution.
Verbose bool
}
// ApproxKfnOptions returns a newly allocated ApproxKfnOptionalParam holding
// the default settings for ApproxKfn: Algorithm "ds", NumProjections 5 and
// NumTables 5. All remaining fields keep their zero values (false, 0 or nil),
// which ApproxKfn treats as "not provided".
func ApproxKfnOptions() *ApproxKfnOptionalParam {
	defaults := new(ApproxKfnOptionalParam)
	defaults.Algorithm = "ds"
	defaults.NumProjections = 5
	defaults.NumTables = 5
	return defaults
}
// ApproxKfn runs approximate furthest neighbor search. Two strategies are
// available:
//
//   - 'qdafn', the algorithm from "Approximate Furthest Neighbor in High
//     Dimensions" by R. Pagh, F. Silvestri, J. Sivertsen, and M. Skala, in
//     Similarity Search and Applications 2015 (SISAP).
//   - 'ds', the DrusillaSelect algorithm from "Fast approximate furthest
//     neighbors with data-dependent candidate selection" by R.R. Curtin and
//     A.B. Gardner, in Similarity Search and Applications 2016 (SISAP).
//
// Both strategies give approximate results for the furthest neighbor search
// problem. They can be used as fast replacements for other furthest neighbor
// techniques, such as those found in the mlpack_kfn program. Typically, the
// 'ds' algorithm requires far fewer tables and projections than the 'qdafn'
// algorithm.
//
// Specify a reference set (the set to search in) with "Reference" and a query
// set with "Query". Specify the algorithm parameters with "NumTables" and
// "NumProjections", or leave them unchanged to use the defaults. Select the
// algorithm ('ds', the default, or 'qdafn') with "Algorithm", and the number
// of neighbors to search for with "K". For 'qdafn' in lower dimensions,
// "NumProjections" may need to be set to a high value in order to return
// results for each query point.
//
// If no query set is specified, the reference set is used as the query set.
// The built model is returned as outputModel so that it can be reused. Such a
// model may be supplied through "InputModel" instead of a reference set.
//
// Results for each query point are returned in the distances and neighbors
// matrices. Each row of these matrices holds the k distances or neighbor
// indices for one query point.
//
// For example, to find the 5 approximate furthest neighbors using DrusillaSelect,
// with reference_set as the reference set and query_set as the query set, and
// to keep both the furthest neighbor indices and distances, one could call:
//
//	// Initialize optional parameters for ApproxKfn().
//	param := mlpack.ApproxKfnOptions()
//	param.Query = query_set
//	param.Reference = reference_set
//	param.K = 5
//	param.Algorithm = "ds"
//	distances, neighbors, _ := mlpack.ApproxKfn(param)
//
// To perform approximate all-furthest-neighbors search with k=1 on
// reference_set, keeping only the furthest neighbor distances, one could call:
//
//	// Initialize optional parameters for ApproxKfn().
//	param := mlpack.ApproxKfnOptions()
//	param.Reference = reference_set
//	param.K = 1
//	distances, _, _ := mlpack.ApproxKfn(param)
//
// A trained model can be reused. Suppose model holds a model returned by an
// earlier call. The following call uses it to find the 3 approximate furthest
// neighbors of new_query_set and keeps the furthest neighbor indices:
//
//	// Initialize optional parameters for ApproxKfn().
//	param := mlpack.ApproxKfnOptions()
//	param.InputModel = &model
//	param.Query = new_query_set
//	param.K = 3
//	_, neighbors, _ := mlpack.ApproxKfn(param)
//
// Input parameters are the fields of param. A nil param is treated the same as
// ApproxKfnOptions(), meaning all defaults are used.
//   - Algorithm (string): Algorithm to use: 'ds' or 'qdafn'. Default value
//     'ds'.
//   - CalculateError (bool): If set, calculate the average distance error for
//     the first furthest neighbor only.
//   - ExactDistances (mat.Dense): Matrix containing exact distances to
//     furthest neighbors; this can be used to avoid explicit calculation when
//     CalculateError is set.
//   - InputModel (approxkfnModel): Previously built model to use instead of a
//     reference set.
//   - K (int): Number of furthest neighbors to search for. Default value 0.
//   - NumProjections (int): Number of projections to use in each hash table.
//     Default value 5.
//   - NumTables (int): Number of hash tables to use. Default value 5.
//   - Query (mat.Dense): Matrix containing query points.
//   - Reference (mat.Dense): Matrix containing the reference dataset.
//   - Verbose (bool): Display informational messages and the full list of
//     parameters and timers at the end of execution.
//
// A field is forwarded to mlpack only when it differs from its default value
// (or, for pointers, when it is non-nil). Explicitly setting a field to its
// default is therefore the same as leaving it unset.
//
// Output parameters, returned in this order:
//   - distances (mat.Dense): Furthest neighbor distances.
//   - neighbors (mat.Dense): Furthest neighbor indices.
//   - outputModel (approxkfnModel): The built model, which can be passed back
//     in as InputModel.
func ApproxKfn(param *ApproxKfnOptionalParam) (*mat.Dense, *mat.Dense, approxkfnModel) {
	// Treat a nil options struct as "all defaults" instead of dereferencing
	// nil in the checks below.
	if param == nil {
		param = ApproxKfnOptions()
	}

	params := getParams("approx_kfn")
	timers := getTimers()
	// Release parameter and timer storage even if an output conversion
	// panics. Deferred calls run LIFO after the return values are evaluated.
	// Outputs are therefore read before cleanup, and params is cleaned before
	// timers, as before.
	defer cleanTimers(timers)
	defer cleanParams(params)

	disableBacktrace()
	disableVerbose()

	// Forward only options that differ from the ApproxKfnOptions defaults (or
	// are non-nil), and mark each forwarded option as passed.
	if param.Algorithm != "ds" {
		setParamString(params, "algorithm", param.Algorithm)
		setPassed(params, "algorithm")
	}
	if param.CalculateError {
		setParamBool(params, "calculate_error", param.CalculateError)
		setPassed(params, "calculate_error")
	}
	if param.ExactDistances != nil {
		gonumToArmaMat(params, "exact_distances", param.ExactDistances, false)
		setPassed(params, "exact_distances")
	}
	if param.InputModel != nil {
		setApproxKFNModel(params, "input_model", param.InputModel)
		setPassed(params, "input_model")
	}
	if param.K != 0 {
		setParamInt(params, "k", param.K)
		setPassed(params, "k")
	}
	if param.NumProjections != 5 {
		setParamInt(params, "num_projections", param.NumProjections)
		setPassed(params, "num_projections")
	}
	if param.NumTables != 5 {
		setParamInt(params, "num_tables", param.NumTables)
		setPassed(params, "num_tables")
	}
	if param.Query != nil {
		gonumToArmaMat(params, "query", param.Query, false)
		setPassed(params, "query")
	}
	if param.Reference != nil {
		gonumToArmaMat(params, "reference", param.Reference, false)
		setPassed(params, "reference")
	}
	if param.Verbose {
		setParamBool(params, "verbose", param.Verbose)
		setPassed(params, "verbose")
		enableVerbose()
	}

	// Every output is requested unconditionally.
	setPassed(params, "distances")
	setPassed(params, "neighbors")
	setPassed(params, "output_model")

	// Call the mlpack program.
	C.mlpackApproxKfn(params.mem, timers.mem)

	// Convert the outputs back to Go values.
	var distancesPtr mlpackArma
	distances := distancesPtr.armaToGonumMat(params, "distances")
	var neighborsPtr mlpackArma
	neighbors := neighborsPtr.armaToGonumUmat(params, "neighbors")
	var outputModel approxkfnModel
	outputModel.getApproxKFNModel(params, "output_model")

	return distances, neighbors, outputModel
}