-
Notifications
You must be signed in to change notification settings - Fork 5
/
gmm_probability.go
85 lines (66 loc) · 2.4 KB
/
gmm_probability.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
package mlpack
/*
#cgo CFLAGS: -I./capi -Wall
#cgo LDFLAGS: -L. -lmlpack_go_gmm_probability
#include <capi/gmm_probability.h>
#include <stdlib.h>
*/
import "C"
import "gonum.org/v1/gonum/mat"
// GmmProbabilityOptionalParam holds the optional arguments accepted by
// GmmProbability. GmmProbabilityOptions returns one populated with defaults.
type GmmProbabilityOptionalParam struct {
	// Verbose displays informational messages and the full list of
	// parameters and timers at the end of execution.
	Verbose bool
}
// GmmProbabilityOptions returns a freshly allocated set of optional
// parameters for GmmProbability, initialized to their defaults
// (Verbose is false).
func GmmProbabilityOptions() *GmmProbabilityOptionalParam {
	opts := new(GmmProbabilityOptionalParam)
	opts.Verbose = false
	return opts
}
// GmmProbability calculates the probability that the given points came from a
// given GMM (that is, P(X | gmm)). The points are passed as input, the GMM is
// passed as inputModel, and the calculated probabilities are returned.
//
// For example, to calculate the probabilities of each point in points coming
// from the pre-trained GMM gmm, storing those probabilities in probs, the
// following could be used:
//
//	// Initialize optional parameters for GmmProbability().
//	param := mlpack.GmmProbabilityOptions()
//	probs := mlpack.GmmProbability(points, &gmm, param)
//
// Input parameters:
//
//   - input (mat.Dense): Input matrix to calculate probabilities of.
//   - inputModel (gmm): Input GMM to use as model.
//   - Verbose (bool): Display informational messages and the full list of
//     parameters and timers at the end of execution.
//
// A nil param is treated the same as GmmProbabilityOptions() (all defaults).
//
// Output parameters:
//
//   - output (mat.Dense): Matrix to store calculated probabilities in.
func GmmProbability(input *mat.Dense, inputModel *gmm, param *GmmProbabilityOptionalParam) *mat.Dense {
	// Fall back to the defaults instead of dereferencing a nil pointer below.
	if param == nil {
		param = GmmProbabilityOptions()
	}

	params := getParams("gmm_probability")
	timers := getTimers()
	// Deferred calls run LIFO: cleanParams first, then cleanTimers, matching
	// the original cleanup order, and they still run if a later step panics.
	defer cleanTimers(timers)
	defer cleanParams(params)

	disableBacktrace()
	disableVerbose()

	// Required parameters: always set and marked as passed.
	gonumToArmaMat(params, "input", input, false)
	setPassed(params, "input")
	setGMM(params, "input_model", inputModel)
	setPassed(params, "input_model")

	// Optional parameters: only forwarded when they differ from the default.
	if param.Verbose {
		setParamBool(params, "verbose", param.Verbose)
		setPassed(params, "verbose")
		enableVerbose()
	}

	// Mark all output options as passed.
	setPassed(params, "output")

	// Call the mlpack program.
	C.mlpackGmmProbability(params.mem, timers.mem)

	// Read the output matrix now; the deferred cleanParams runs only after
	// this return value has been computed.
	var outputPtr mlpackArma
	return outputPtr.armaToGonumMat(params, "output")
}