forked from refraction-networking/utls
-
Notifications
You must be signed in to change notification settings - Fork 1
/
u_session_controller.go
359 lines (320 loc) · 16.8 KB
/
u_session_controller.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
package tls
import (
"errors"
"fmt"
)
// LoadSessionTrackerState records whether conn.loadSession has been invoked
// and, once it has returned, whether that call was announced by utls
// beforehand or not.
type LoadSessionTrackerState int

const (
	// NeverCalled means conn.loadSession has not been entered yet.
	NeverCalled LoadSessionTrackerState = iota
	// UtlsAboutToCall means utls announced (via utlsAboutToLoadSession) that
	// it is about to call conn.loadSession.
	UtlsAboutToCall
	// CalledByULoadSession means conn.loadSession returned after a call that
	// utls had announced.
	CalledByULoadSession
	// CalledByGoTLS means conn.loadSession returned after a call that utls had
	// not announced.
	CalledByGoTLS
)
// sessionControllerState is the lifecycle stage of a sessionController.
// An extension first becomes "Initialized" (its values live inside the
// extension) and then "AllSet" (its values have been copied into the
// handshake state and client hello).
type sessionControllerState int

const (
	// NoSession: no session ticket or PSK has been prepared.
	NoSession sessionControllerState = iota
	// SessionTicketExtInitialized: the session ticket extension holds a ticket.
	SessionTicketExtInitialized
	// SessionTicketExtAllSet: the ticket has been written to the UConn.
	SessionTicketExtAllSet
	// PskExtInitialized: the PSK extension holds a pre-shared key.
	PskExtInitialized
	// PskExtAllSet: the PSK has been written to the UConn.
	PskExtAllSet
)
// sessionController is responsible for managing and controlling all session related states. It manages the lifecycle of the session ticket extension and the psk extension, including initialization, removal if the client hello spec doesn't contain any of them, and setting the prepared state to the client hello.
//
// Users should never directly modify the underlying state. Violations will result in undefined behaviors.
//
// Users should never construct sessionController by themselves, use the function `newSessionController` instead.
type sessionController struct {
// sessionTicketExt logically owns the session ticket extension. It is either supplied by the user
// (overrideSessionTicketExt) or adopted from the extension list by syncSessionExts, which resets it
// to nil when the extension list contains no session ticket extension.
sessionTicketExt ISessionTicketExtension
// pskExtension logically owns the psk extension. It is either supplied by the user (overridePskExt)
// or adopted from the extension list by syncSessionExts, which resets it to nil when the extension
// list contains no PreSharedKeyExtension.
pskExtension PreSharedKeyExtension
// uconnRef is a back-reference to the owning UConn; the controller reads its build status, config
// and extension list, and writes session data into its HandshakeState.
uconnRef *UConn
// state represents the internal state of the sessionController. Users are advised to modify the state only through designated methods and avoid direct manipulation, as doing so may result in undefined behavior.
state sessionControllerState
// loadSessionTracker keeps track of how the conn.loadSession method is being utilized; it is
// advanced by utlsAboutToLoadSession and onLoadSessionReturn.
loadSessionTracker LoadSessionTrackerState
// callingLoadSession is true between onEnterLoadSessionCheck and onLoadSessionReturn, i.e. while
// `conn.loadSession` is currently being invoked.
callingLoadSession bool
// locked becomes true once finalCheck has verified that all states are appropriately set. Once
// `locked` is true, further modifications are disallowed (see assertNotLocked), except for the binders.
locked bool
}
// newSessionController returns a controller bound to uconn that owns no
// session-related extensions yet, is unlocked, is in the NoSession state and
// has never seen a call to conn.loadSession.
func newSessionController(uconn *UConn) *sessionController {
	ctrl := new(sessionController)
	ctrl.uconnRef = uconn
	// The remaining fields keep their zero values: nil extensions, false flags.
	ctrl.state = NoSession
	ctrl.loadSessionTracker = NeverCalled
	return ctrl
}
// isSessionLocked reports whether finalCheck has already locked the session state.
func (s *sessionController) isSessionLocked() bool {
	locked := s.locked
	return locked
}
// shouldLoadSessionResult is the action chosen by sessionController.shouldLoadSession.
type shouldLoadSessionResult int

const (
	// shouldReturn: skip loading a session entirely.
	shouldReturn shouldLoadSessionResult = iota
	// shouldSetTicket: write the already-initialized session ticket.
	shouldSetTicket
	// shouldSetPsk: write the already-initialized pre-shared key.
	shouldSetPsk
	// shouldLoad: a session still has to be loaded.
	shouldLoad
)
// shouldLoadSession decides what to do about session resumption when the
// client hello is being prepared:
//   - shouldReturn when the controller owns neither a session ticket nor a PSK
//     extension (e.g. the spec has none), or when the client hello is no
//     longer in the NotBuilt status;
//   - shouldSetTicket when a session ticket extension was initialized up front
//     (initSessionTicketExt or overrideSessionTicketExt);
//   - shouldSetPsk when a PSK extension was initialized up front
//     (initPskExt or overridePskExt);
//   - shouldLoad in every other case.
func (s *sessionController) shouldLoadSession() shouldLoadSessionResult {
	switch {
	case s.sessionTicketExt == nil && s.pskExtension == nil,
		s.uconnRef.clientHelloBuildStatus != NotBuilt:
		// Nothing to resume into, or it is too late to change the hello.
		// Case expressions are evaluated left to right, so the build status is
		// only read when at least one extension is present.
		return shouldReturn
	case s.state == SessionTicketExtInitialized:
		return shouldSetTicket
	case s.state == PskExtInitialized:
		return shouldSetPsk
	default:
		return shouldLoad
	}
}
// utlsAboutToLoadSession announces that utls itself is about to invoke
// conn.loadSession by moving loadSessionTracker to UtlsAboutToCall.
// The controller must be unlocked and in the NoSession state; otherwise the
// utls implementation is wrong and uAssert panics.
func (s *sessionController) utlsAboutToLoadSession() {
	ready := !s.locked && s.state == NoSession
	uAssert(ready, "tls: aboutToLoadSession failed: must only load session when the session of the client hello is not locked and when there's currently no session")
	s.loadSessionTracker = UtlsAboutToCall
}
// assertHelloNotBuilt panics, naming caller, unless the owning UConn's client
// hello is still in the NotBuilt status.
func (s *sessionController) assertHelloNotBuilt(caller string) {
	if s.uconnRef.clientHelloBuildStatus == NotBuilt {
		return
	}
	msg := fmt.Sprintf("tls: %s failed: we can't modify the session after the clientHello is built", caller)
	panic(msg)
}
// assertControllerState panics, naming caller and the current state, unless
// the controller is in desired or in one of moreDesiredStates.
func (s *sessionController) assertControllerState(caller string, desired sessionControllerState, moreDesiredStates ...sessionControllerState) {
	if s.state == desired {
		return
	}
	for _, acceptable := range moreDesiredStates {
		if s.state == acceptable {
			return
		}
	}
	panic(fmt.Sprintf("tls: %s failed: undesired controller state %d", caller, s.state))
}
// assertNotLocked panics, naming caller, once finalCheck has locked the session.
func (s *sessionController) assertNotLocked(caller string) {
	if !s.locked {
		return
	}
	msg := fmt.Sprintf("tls: %s failed: you must not modify the session after it's locked", caller)
	panic(msg)
}
// assertCanSkip is used when the ClientHelloSpec lacks the extension named
// extensionName. It returns silently if the owning UConn allows resumption to
// be skipped in that case (skipResumptionOnNilExtension); otherwise it panics
// with guidance naming caller and extensionName.
func (s *sessionController) assertCanSkip(caller, extensionName string) {
	if s.uconnRef.skipResumptionOnNilExtension {
		return
	}
	panic(fmt.Sprintf("tls: %s failed: session resumption is enabled, but there is no %s in the ClientHelloSpec; Please consider provide one in the ClientHelloSpec; If this is intentional, you may consider disable resumption by setting Config.SessionTicketsDisabled to true, or set Config.PreferSkipResumptionOnNilExtension to true to suppress this exception", caller, extensionName))
}
// finalCheck verifies that the controller ended in a terminal state
// (PskExtAllSet, SessionTicketExtAllSet or NoSession) and then locks it, so
// that afterwards only the binders may still change.
// Any other state indicates a bug in utls and causes a panic.
func (s *sessionController) finalCheck() {
	terminal := []sessionControllerState{SessionTicketExtAllSet, NoSession}
	s.assertControllerState("SessionController.finalCheck", PskExtAllSet, terminal...)
	s.locked = true
}
// initializationGuard runs initializer on extension and asserts, via uAssert,
// that the extension was not initialized before the call and is initialized
// after it.
func initializationGuard[E Initializable, I func(E)](extension E, initializer I) {
	wasInitialized := extension.IsInitialized()
	uAssert(!wasInitialized, "tls: initialization failed: the extension is already initialized")
	initializer(extension)
	nowInitialized := extension.IsInitialized()
	uAssert(nowInitialized, "tls: initialization failed: the extension is not initialized after initialization")
}
// initSessionTicketExt initializes the owned session ticket extension with
// session and ticket, and moves the controller to SessionTicketExtInitialized.
//
// It panics if the session is locked, the client hello is already built, the
// controller is not in the NoSession state, or session/ticket is nil.
// If the controller owns no session ticket extension (the ClientHelloSpec has
// none), the call returns without changes when uconnRef permits skipping
// resumption, and panics otherwise (see assertCanSkip).
func (s *sessionController) initSessionTicketExt(session *SessionState, ticket []byte) {
	const caller = "initSessionTicketExt"
	s.assertNotLocked(caller)
	s.assertHelloNotBuilt(caller)
	s.assertControllerState(caller, NoSession)
	panicOnNil(caller, session, ticket)
	if s.sessionTicketExt == nil {
		s.assertCanSkip(caller, "session ticket extension")
		return
	}
	initializationGuard(s.sessionTicketExt, func(e ISessionTicketExtension) {
		// Initialize the very value the guard checks before and after this call.
		e.InitializeByUtls(session, ticket)
	})
	s.state = SessionTicketExtInitialized
}
// initPskExt initializes the owned PSK extension from a resumable session and
// moves the controller to PskExtInitialized.
//
// The private pskIdentity values are converted to the public PskIdentity type
// before being handed to the extension's InitializeByUtls.
//
// It panics if the session is locked, the client hello is already built, the
// controller is not in the NoSession state, session/earlySecret/pskIdentities
// is nil, or the extension is already initialized. binderKey is passed through
// without a nil check. If the controller owns no PSK extension (the
// ClientHelloSpec has none), the call returns without changes when uconnRef
// permits skipping resumption, and panics otherwise (see assertCanSkip).
func (s *sessionController) initPskExt(session *SessionState, earlySecret []byte, binderKey []byte, pskIdentities []pskIdentity) {
	const caller = "initPskExt"
	s.assertNotLocked(caller)
	s.assertHelloNotBuilt(caller)
	s.assertControllerState(caller, NoSession)
	panicOnNil(caller, session, earlySecret, pskIdentities)
	if s.pskExtension == nil {
		s.assertCanSkip(caller, "pre-shared key extension")
		return
	}
	initializationGuard(s.pskExtension, func(e PreSharedKeyExtension) {
		publicPskIdentities := mapSlice(pskIdentities, func(private pskIdentity) PskIdentity {
			return PskIdentity{
				Label:               private.label,
				ObfuscatedTicketAge: private.obfuscatedTicketAge,
			}
		})
		e.InitializeByUtls(session, earlySecret, binderKey, publicPskIdentities)
	})
	s.state = PskExtInitialized
}
// setSessionTicketToUConn copies the session and ticket held by the
// initialized session ticket extension into the UConn's handshake state and
// client hello, then moves the controller to SessionTicketExtAllSet.
// It panics unless a ticket extension exists and is in SessionTicketExtInitialized.
func (s *sessionController) setSessionTicketToUConn() {
	ready := s.sessionTicketExt != nil && s.state == SessionTicketExtInitialized
	uAssert(ready, "tls: setSessionTicketExt failed: invalid state")
	ext := s.sessionTicketExt
	s.uconnRef.HandshakeState.Session = ext.GetSession()
	s.uconnRef.HandshakeState.Hello.SessionTicket = ext.GetTicket()
	s.state = SessionTicketExtAllSet
}
// setPskToUConn copies the state held by the owned PSK extension into the
// UConn's handshake state and client hello, then moves the controller to
// PskExtAllSet.
//
// In state PskExtInitialized the early secret, session, PSK identities and
// binders are copied over. In state PskExtAllSet (a repeated call) the
// session, early secret and identities must be unchanged; this is asserted,
// including that both identity lists have the same length. Without the length
// check, a longer hello list would cause an index-out-of-range panic, and a
// shorter one would silently pass. In both states the binder key is refreshed.
// It panics if there is no PSK extension or the state is neither of the above.
func (s *sessionController) setPskToUConn() {
	uAssert(s.pskExtension != nil && (s.state == PskExtInitialized || s.state == PskExtAllSet), "tls: setPskToUConn failed: invalid state")
	pskCommon := s.pskExtension.GetPreSharedKeyCommon()
	switch s.state {
	case PskExtInitialized:
		s.uconnRef.HandshakeState.State13.EarlySecret = pskCommon.EarlySecret
		s.uconnRef.HandshakeState.Session = pskCommon.Session
		s.uconnRef.HandshakeState.Hello.PskIdentities = pskCommon.Identities
		s.uconnRef.HandshakeState.Hello.PskBinders = pskCommon.Binders
	case PskExtAllSet:
		helloIdentities := s.uconnRef.HandshakeState.Hello.PskIdentities
		uAssert(s.uconnRef.HandshakeState.Session == pskCommon.Session &&
			sliceEq(s.uconnRef.HandshakeState.State13.EarlySecret, pskCommon.EarlySecret) &&
			// Lengths must match before indexing pskCommon.Identities with
			// indices taken from helloIdentities.
			len(helloIdentities) == len(pskCommon.Identities) &&
			allTrue(helloIdentities, func(i int, psk *PskIdentity) bool {
				return pskCommon.Identities[i].ObfuscatedTicketAge == psk.ObfuscatedTicketAge && sliceEq(pskCommon.Identities[i].Label, psk.Label)
			}), "tls: setPskToUConn failed: only binders are allowed to change on state `PskAllSet`")
	}
	s.uconnRef.HandshakeState.State13.BinderKey = pskCommon.BinderKey
	s.state = PskExtAllSet
}
// shouldUpdateBinders reports whether the controller owns a PSK extension that
// is in PskExtInitialized or PskExtAllSet. Binders may be refreshed in
// PskExtAllSet because BuildHandshakeState can run more than once; everything
// else about the session must not change after the first time.
func (s *sessionController) shouldUpdateBinders() bool {
	hasPsk := s.pskExtension != nil
	pskPrepared := s.state == PskExtInitialized || s.state == PskExtAllSet
	return hasPsk && pskPrepared
}
// updateBinders asks the owned PSK extension to patch the binders of the
// client hello stored in the UConn's handshake state.
//
// It panics if shouldUpdateBinders reports false. An error from
// PatchBuiltHello is wrapped and returned to the caller instead of being
// discarded. Existing call sites that invoke updateBinders as a plain
// statement still compile.
func (s *sessionController) updateBinders() error {
	uAssert(s.shouldUpdateBinders(), "tls: updateBinders failed: shouldn't update binders")
	if err := s.pskExtension.PatchBuiltHello(s.uconnRef.HandshakeState.Hello); err != nil {
		return fmt.Errorf("tls: updateBinders failed: %w", err)
	}
	return nil
}
// overrideExtension installs a user-supplied extension by running override.
// If the extension is already initialized, it also moves the controller to
// initializedState. It panics if extension is nil, the session is locked, or
// the controller is not in NoSession. The returned error is currently always nil.
func (s *sessionController) overrideExtension(extension Initializable, override func(), initializedState sessionControllerState) error {
	const caller = "overrideExtension"
	panicOnNil(caller, extension)
	s.assertNotLocked(caller)
	s.assertControllerState(caller, NoSession)
	override()
	if !extension.IsInitialized() {
		return nil
	}
	s.state = initializedState
	return nil
}
// overridePskExt lets the utls user supply their own PSK extension. It becomes
// the controller's owned PSK extension; if it is already initialized, the
// controller moves to PskExtInitialized.
func (s *sessionController) overridePskExt(pskExt PreSharedKeyExtension) error {
	install := func() { s.pskExtension = pskExt }
	return s.overrideExtension(pskExt, install, PskExtInitialized)
}
// overrideSessionTicketExt allows the user of utls to customize the session
// ticket extension. The given extension becomes the controller's owned session
// ticket extension; if it is already initialized, the controller moves to
// SessionTicketExtInitialized.
func (s *sessionController) overrideSessionTicketExt(sessionTicketExt ISessionTicketExtension) error {
	return s.overrideExtension(sessionTicketExt, func() { s.sessionTicketExt = sessionTicketExt }, SessionTicketExtInitialized)
}
// syncSessionExts synchronizes the sessionController with the session-related
// extensions from the extension list after applying client hello specs.
//
//   - If the extension list is missing the session ticket extension or PSK
//     extension, owned extensions are dropped and the related handshake state
//     (session, ticket, early secret, binder key, PSK identities and binders)
//     is reset.
//   - If the user provides a session ticket extension or PSK extension, the
//     corresponding extension from the extension list will be replaced.
//   - If the user doesn't provide session-related extensions, the extensions
//     from the extension list will be utilized.
//
// This function ensures that there is only one session ticket extension or PSK
// extension, and that the PSK extension is the last extension in the extension
// list. It returns an error if the user initialized a ticket or PSK but the
// spec has no matching extension. It panics if the hello is already built,
// the session is locked, or the controller is in an unexpected state.
func (s *sessionController) syncSessionExts() error {
	// The "already built" check runs first, so that condition is reported
	// ahead of a locked session.
	s.assertHelloNotBuilt("checkSessionExts")
	s.assertNotLocked("checkSessionExts")
	s.assertControllerState("checkSessionExts", NoSession, SessionTicketExtInitialized, PskExtInitialized)
	numSessionExt := 0
	hasPskExt := false
	for i, e := range s.uconnRef.Extensions {
		switch ext := e.(type) {
		case ISessionTicketExtension:
			uAssert(numSessionExt == 0, "tls: checkSessionExts failed: multiple ISessionTicketExtensions in the extension list")
			if s.sessionTicketExt == nil {
				// If there isn't a user-provided session ticket extension, use the one from the spec
				s.sessionTicketExt = ext
			} else {
				// Otherwise, replace the one in the extension list with the user-provided one
				s.uconnRef.Extensions[i] = s.sessionTicketExt
			}
			numSessionExt++
		case PreSharedKeyExtension:
			uAssert(i == len(s.uconnRef.Extensions)-1, "tls: checkSessionExts failed: PreSharedKeyExtension must be the last extension")
			if s.pskExtension == nil {
				// If there isn't a user-provided psk extension, use the one from the spec
				s.pskExtension = ext
			} else {
				// Otherwise, replace the one in the extension list with the user-provided one
				s.uconnRef.Extensions[i] = s.pskExtension
			}
			s.pskExtension.SetOmitEmptyPsk(s.uconnRef.config.OmitEmptyPsk)
			hasPskExt = true
		}
	}
	if numSessionExt == 0 {
		if s.state == SessionTicketExtInitialized {
			return errors.New("tls: checkSessionExts failed: the user provided a session ticket, but the specification doesn't contain one")
		}
		s.sessionTicketExt = nil
		s.uconnRef.HandshakeState.Session = nil
		s.uconnRef.HandshakeState.Hello.SessionTicket = nil
	}
	if !hasPskExt {
		if s.state == PskExtInitialized {
			return errors.New("tls: checkSessionExts failed: the user provided a psk, but the specification doesn't contain one")
		}
		s.pskExtension = nil
		s.uconnRef.HandshakeState.State13.BinderKey = nil
		s.uconnRef.HandshakeState.State13.EarlySecret = nil
		s.uconnRef.HandshakeState.Session = nil
		s.uconnRef.HandshakeState.Hello.PskIdentities = nil
		// Binders belong with the identities; don't leave stale ones behind.
		s.uconnRef.HandshakeState.Hello.PskBinders = nil
	}
	return nil
}
// onEnterLoadSessionCheck must run on entry to conn.loadSession. It marks the
// controller as currently loading a session, and panics if the session is
// locked, if conn.loadSession has already completed once, or if the tracker
// holds an unknown value. A panic here means the utls implementation is wrong.
func (s *sessionController) onEnterLoadSessionCheck() {
	uAssert(!s.locked, "tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: session is set and locked, no call to loadSession is allowed")
	tracker := s.loadSessionTracker
	if tracker == CalledByULoadSession || tracker == CalledByGoTLS {
		panic("tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: you must not call loadSession() twice")
	}
	if tracker != UtlsAboutToCall && tracker != NeverCalled {
		panic("tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: unimplemented state")
	}
	s.callingLoadSession = true
}
// onLoadSessionReturn must run when conn.loadSession returns. It records who
// made the call: UtlsAboutToCall becomes CalledByULoadSession, and NeverCalled
// becomes CalledByGoTLS. It then clears the "currently loading" flag. It panics
// if no load is in progress or the tracker is in any other state, which would
// indicate a bug in utls.
func (s *sessionController) onLoadSessionReturn() {
	uAssert(s.callingLoadSession, "tls: LoadSessionCoordinator.onLoadSessionReturn failed: it's not loading sessions, perhaps this function is not being called by loadSession.")
	var next LoadSessionTrackerState
	switch s.loadSessionTracker {
	case UtlsAboutToCall:
		next = CalledByULoadSession
	case NeverCalled:
		next = CalledByGoTLS
	default:
		panic("tls: LoadSessionCoordinator.onLoadSessionReturn failed: unimplemented state")
	}
	s.loadSessionTracker = next
	s.callingLoadSession = false
}
// shouldLoadSessionWriteBinders tells conn.loadSession whether it should write
// binders and marshal the client hello itself. It returns true when the call
// was not announced by utls (NeverCalled) and false when it was
// (UtlsAboutToCall). It panics if no load is in progress or the tracker is in
// any other state, which would indicate a bug in utls.
func (s *sessionController) shouldLoadSessionWriteBinders() bool {
	uAssert(s.callingLoadSession, "tls: shouldWriteBinders failed: LoadSessionCoordinator isn't loading sessions, perhaps this function is not being called by loadSession.")
	switch s.loadSessionTracker {
	case UtlsAboutToCall:
		return false
	case NeverCalled:
		return true
	}
	panic("tls: shouldWriteBinders failed: unimplemented state")
}