-
Notifications
You must be signed in to change notification settings - Fork 4
/
channelmix.go
83 lines (66 loc) · 2.37 KB
/
channelmix.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
// Copyright 2023 NLP Odyssey Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rwkv
import (
"encoding/gob"
"github.com/nlpodyssey/spago/ag"
"github.com/nlpodyssey/spago/mat"
"github.com/nlpodyssey/spago/mat/float"
"github.com/nlpodyssey/spago/nn"
)
// Compile-time check that *ChannelMix satisfies nn.Model.
var _ nn.Model = &ChannelMix{}

// ChannelMix implements the channel mix module.
//
// It holds three weight matrices (Key, Value, Receptance) and two mixing
// vectors (TimeMixK, TimeMixR). The forward methods blend the current input
// with the previous one using the mixing vectors, feed the results through
// Key/Value and Receptance, and combine both branches with ag.Prod.
type ChannelMix struct {
	nn.Module
	// Key is applied (ag.Mul) to the K-mixed input; NewChannelMix allocates
	// it with dimensions (4*DModel, DModel).
	Key nn.Param `spago:"type:weights"`
	// Value is applied (ag.Mul) to the squared-ReLU output of the Key branch;
	// NewChannelMix allocates it with dimensions (DModel, 4*DModel).
	Value nn.Param `spago:"type:weights"`
	// Receptance is applied (ag.Mul) to the R-mixed input before the sigmoid
	// that gates the output; allocated with dimensions (DModel, DModel).
	Receptance nn.Param `spago:"type:weights"`
	// TimeMixK holds the coefficients used to blend the current and previous
	// inputs for the Key branch; a vector of size DModel.
	TimeMixK nn.Param `spago:"type:weights"`
	// TimeMixR holds the coefficients used to blend the current and previous
	// inputs for the Receptance branch; a vector of size DModel.
	TimeMixR nn.Param `spago:"type:weights"`
}
// init registers the *ChannelMix type with encoding/gob, so that gob can
// identify the concrete type when a value is sent or received through an
// interface variable.
func init() {
	gob.Register(new(ChannelMix))
}
// NewChannelMix returns a new ChannelMix whose parameters are sized from
// c.DModel, using T as the element type of the underlying matrices.
//
// The inner (hidden) dimension is fixed at four times c.DModel. Every
// parameter is created through the "empty" mat constructors; the caller is
// presumably expected to initialize or load the actual weights afterwards.
// The second argument is ignored.
func NewChannelMix[T float.DType](c Config, _ int) *ChannelMix {
	dModel := c.DModel
	dHidden := 4 * dModel

	cm := &ChannelMix{}
	cm.Key = nn.NewParam(mat.NewEmptyDense[T](dHidden, dModel))
	cm.Value = nn.NewParam(mat.NewEmptyDense[T](dModel, dHidden))
	cm.Receptance = nn.NewParam(mat.NewEmptyDense[T](dModel, dModel))
	cm.TimeMixK = nn.NewParam(mat.NewEmptyVecDense[T](dModel))
	cm.TimeMixR = nn.NewParam(mat.NewEmptyVecDense[T](dModel))
	return cm
}
// ForwardSingle performs the forward step for a single node.
//
// The input x is blended with the previous input held in state.FfnXX, once
// with the TimeMixK coefficients and once with the TimeMixR coefficients.
// state.FfnXX is then overwritten with x, so the next call sees it as the
// previous input. The returned node is the sigmoid-gated Receptance branch
// multiplied (ag.Prod) with the Key/Value branch.
func (m *ChannelMix) ForwardSingle(x ag.Node, state *LayerState) (rkv ag.Node) {
	prev := state.FfnXX

	// Input of the Key branch: x weighted by TimeMixK plus the previous input
	// weighted by ReverseSub(TimeMixK, one) (presumably 1 - TimeMixK).
	curK := ag.Prod(x, m.TimeMixK)
	prevK := ag.Prod(ag.ReverseSub(m.TimeMixK, one), prev)
	keyIn := ag.Add(curK, prevK)

	// Input of the Receptance branch, built the same way from TimeMixR.
	curR := ag.Prod(x, m.TimeMixR)
	prevR := ag.Prod(ag.ReverseSub(m.TimeMixR, one), prev)
	recIn := ag.Add(curR, prevR)

	// Remember the current input for the next step.
	state.FfnXX = x

	// Key projection, squared ReLU, then Value projection.
	hidden := ag.Square(ag.ReLU(ag.Mul(m.Key, keyIn)))
	projected := ag.Mul(m.Value, hidden)

	// Gate the projected value with the sigmoid of the Receptance branch.
	gate := ag.Sigmoid(ag.Mul(m.Receptance, recIn))
	return ag.Prod(gate, projected)
}
// ForwardSequence performs the forward step for a sequence of nodes.
//
// Each element x[i] is processed like in ForwardSingle, using x[i-1] as the
// previous input (token shift); the first element uses state.FfnXX. After the
// call, the state is updated with the last node of the sequence.
//
// If x is empty, nil is returned and the state is left untouched.
func (m *ChannelMix) ForwardSequence(x []ag.Node, state *LayerState) (rkv []ag.Node) {
	// Guard against an empty sequence: there is no "last node" to store and
	// nothing to compute.
	if len(x) == 0 {
		return nil
	}

	rkv = make([]ag.Node, len(x))

	// Precompute the coefficients applied to the previous input; they do not
	// depend on the position in the sequence.
	tmk := ag.ReverseSub(m.TimeMixK, one)
	tmr := ag.ReverseSub(m.TimeMixR, one)

	// Token shift: prev is the input that precedes xi. Tracking it in a
	// variable avoids building a shifted copy of the whole sequence.
	var prev ag.Node = state.FfnXX
	for i, xi := range x {
		xk := ag.Add(ag.Prod(xi, m.TimeMixK), ag.Prod(tmk, prev))
		xr := ag.Add(ag.Prod(xi, m.TimeMixR), ag.Prod(tmr, prev))

		// Key projection, squared ReLU, then Value projection.
		k := ag.Square(ag.ReLU(ag.Mul(m.Key, xk)))
		kv := ag.Mul(m.Value, k)

		// Gate the result with the sigmoid of the Receptance branch.
		rkv[i] = ag.Prod(ag.Sigmoid(ag.Mul(m.Receptance, xr)), kv)

		prev = xi
	}

	// prev now holds the last node of the sequence.
	state.FfnXX = prev
	return rkv
}