-
Notifications
You must be signed in to change notification settings - Fork 0
/
Engine.hs
36 lines (23 loc) · 1.52 KB
/
Engine.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
module Engine where
import Data
-- | Reverse-mode differentiation entry point: seed the root node's
-- gradient with 1, then propagate gradients to every node in the graph.
backward :: Node -> Node
backward node = case node of
  Leaf v _             -> Leaf v 1
  InnerNode op v _ l r -> backward_ (InnerNode op v 1 l r)
-- | Propagate an already-seeded gradient from a node down into its children.
--
-- For 'Pow', the gradient flows only into the base; the exponent subtree is
-- kept intact and simply not updated. (The previous version rebuilt the
-- exponent as @Leaf (value expo) 0@, which flattened the exponent's entire
-- subtree and silently dropped its leaves from 'getLeafs' results.)
backward_ :: Node -> Node
backward_ leaf@(Leaf _ _) = leaf
backward_ parent@(InnerNode Pow v g base expo) =
  -- gradients do not flow to the exponent for now; leave its subtree untouched
  InnerNode Pow v g (backward_ $ updateGradient base expo parent) expo
backward_ parent@(InnerNode op v g c1 c2) =
  InnerNode op v g
    (backward_ $ updateGradient c1 c2 parent)
    (backward_ $ updateGradient c2 c1 parent)
-- | Return a copy of the node with its gradient field replaced by the value
-- derived from its sibling and parent; children are left unchanged.
updateGradient :: Node -> Node -> Node -> Node
updateGradient node sibling parent =
  let g' = computeGradient node sibling parent
  in case node of
       Leaf v _             -> Leaf v g'
       InnerNode op v _ l r -> InnerNode op v g' l r
-- | Chain rule: accumulate onto the node's current gradient the contribution
-- flowing down from its parent, given the node's sibling in the graph.
computeGradient :: Node -> Node -> Node -> Gradient
computeGradient n _ (Leaf _ _) = gradient n
computeGradient n sibling (InnerNode op _ gParent _ _) =
  gradient n + contribution
  where
    contribution = case op of
      Plus  -> gParent
      Times -> gParent * value sibling
      -- d(b^e)/db = e * b^(e-1); gradients do not flow to the exponent for now
      Pow   -> gParent * (value sibling * value n ** (value sibling - 1))
-- | Read the gradient stored on a node.
gradient :: Node -> Gradient
gradient node = case node of
  Leaf _ g            -> g
  InnerNode _ _ g _ _ -> g
-- | Collect every leaf node of the graph, in left-to-right order.
getLeafs :: Node -> [Node]
getLeafs leaf@(Leaf _ _)           = [leaf]
getLeafs (InnerNode _ _ _ lhs rhs) = getLeafs lhs ++ getLeafs rhs