-
Notifications
You must be signed in to change notification settings - Fork 1
/
frnn.hpp
38 lines (30 loc) · 971 Bytes
/
frnn.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
// Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#ifndef FRNN_HPP
#define FRNN_HPP
#include "fructose.hpp"
/**
* Fructose-Neural-Networks (fr::nn), additional functions for implementing neural nets.
*/
namespace fr::nn {

// Ops

/// @brief Rectified linear unit applied to @p tensor.
/// NOTE(review): presumably element-wise max(x, 0) -- confirm against the
/// implementation.
/// @param tensor Input tensor.
/// @return The resulting tensor.
Tensor relu(const Tensor& tensor);

/// @brief Softmax cross-entropy loss between @p logits and @p labels.
/// NOTE(review): the expected label encoding (one-hot vs. class index), the
/// reduction applied over the batch and the required shapes are not visible
/// from this header -- TODO document once confirmed.
/// @param logits Unnormalised scores.
/// @param labels Target labels.
/// @return The loss tensor.
Tensor softmaxCrossEntropy(const Tensor& logits, const Tensor& labels);

/// @brief Dropout applied to @p a.
/// @param a Input tensor.
/// @param dropProbability Probability of dropping an element; presumably
///        expected in [0, 1). TODO confirm whether surviving values are
///        rescaled by 1 / (1 - dropProbability).
/// @return The resulting tensor.
Tensor dropout(const Tensor& a, float dropProbability);

// Optimisers

/// @brief Stochastic-gradient-descent update of @p tensor.
/// The function returns nothing, so the update is presumably applied to
/// @p tensor in place (Tensor appears to be a handle type, since it is taken
/// by const reference) using its gradient -- TODO confirm.
/// @param tensor Parameter tensor to update.
/// @param learningRate Learning-rate tensor.
void sgd(const Tensor& tensor, const Tensor& learningRate);

/// Hyper-parameters shared by adamStepSizeAutoIncrement() and adam().
///
/// Every field is value-initialised (to 0.0f). This prevents a
/// default-constructed `AdamParams params;` from carrying indeterminate values
/// into the optimiser. Aggregate initialisation such as
/// `AdamParams{0.9f, 0.999f, 1e-8f, 0.0f}` is unaffected.
/// Zero is only a safe placeholder, not a sensible optimiser setting.
/// Callers should set every field explicitly.
struct AdamParams {
    float betaM{};        ///< Presumably the decay rate of the first-moment (momentum) estimate.
    float betaV{};        ///< Presumably the decay rate of the second-moment (variance) estimate.
    float epsilon{};      ///< Presumably the numerical-stability term added to the denominator.
    float weightDecay{};  ///< Weight-decay coefficient; presumably 0 disables it -- TODO confirm.
};

/// @brief Derives the step size consumed by adam() from @p step and
///        @p learningRate.
/// NOTE(review): judging by the name, this presumably also increments
/// @p step as a side effect and applies Adam's bias correction using
/// @p params.betaM / @p params.betaV -- TODO confirm against the implementation.
/// @param step Step counter tensor.
/// @param learningRate Learning-rate tensor.
/// @param params Adam hyper-parameters.
/// @return The step-size tensor, intended for adam()'s @p stepSize argument.
Tensor adamStepSizeAutoIncrement(const Tensor& step,
                                 const Tensor& learningRate,
                                 const AdamParams& params);

/// @brief Adam optimiser update of @p tensor.
/// The function returns nothing, so @p tensor, @p momentum and @p variance are
/// presumably updated in place -- TODO confirm.
/// @param tensor Parameter tensor to update.
/// @param momentum First-moment state for @p tensor.
/// @param variance Second-moment state for @p tensor.
/// @param stepSize Step size, e.g. as produced by adamStepSizeAutoIncrement().
/// @param params Adam hyper-parameters.
void adam(const Tensor& tensor,
          const Tensor& momentum,
          const Tensor& variance,
          const Tensor& stepSize,
          const AdamParams& params);

} // namespace fr::nn
#endif // FRNN_HPP