diff --git a/client.go b/client.go index 5ae17ce..ef93a6f 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package webtransport import ( "context" + "crypto/tls" "errors" "fmt" "net/http" @@ -17,12 +18,12 @@ import ( var errNoWebTransport = errors.New("server didn't enable WebTransport") type Dialer struct { - // If not set, reasonable defaults will be used. - // In order for WebTransport to function, this implementation will: - // * overwrite the StreamHijacker and UniStreamHijacker - // * enable datagram support - // * set the MaxIncomingStreams to 100 on the quic.Config, if unset - *http3.RoundTripper + // TLSClientConfig is the TLS client config used when dialing the QUIC connection. + // It must set the h3 ALPN. + TLSClientConfig *tls.Config + + // QUICConfig is the QUIC config used when dialing the QUIC connection. + QUICConfig *quic.Config // StreamReorderingTime is the time an incoming WebTransport stream that cannot be associated // with a session is buffered. @@ -31,6 +32,10 @@ type Dialer struct { // Defaults to 5 seconds. StreamReorderingTimeout time.Duration + // DialAddr is the function used to dial the underlying QUIC connection. + // If unset, quic.DialAddrEarly will be used. + DialAddr func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) + ctx context.Context ctxCancel context.CancelFunc @@ -46,44 +51,6 @@ func (d *Dialer) init() { } d.conns = *newSessionManager(timeout) d.ctx, d.ctxCancel = context.WithCancel(context.Background()) - if d.RoundTripper == nil { - d.RoundTripper = &http3.RoundTripper{} - } - d.RoundTripper.EnableDatagrams = true - if d.RoundTripper.AdditionalSettings == nil { - d.RoundTripper.AdditionalSettings = make(map[uint64]uint64) - } - d.RoundTripper.AdditionalSettings[settingsEnableWebtransport] = 1 - d.RoundTripper.StreamHijacker = func(ft http3.FrameType, conn quic.Connection, str quic.Stream, e error) (hijacked bool, err error) { - if isWebTransportError(e) { - return true, nil - } - if ft != webTransportFrameType { - return false, nil - } - id, err := quicvarint.Read(quicvarint.NewReader(str)) - if err != nil { - if isWebTransportError(err) { - return true, nil - } - return false, err - } - d.conns.AddStream(conn, str, sessionID(id)) - return true, nil - } - d.RoundTripper.UniStreamHijacker = func(st http3.StreamType, conn quic.Connection, str quic.ReceiveStream, err error) (hijacked bool) { - if st != webTransportUniStreamType && !isWebTransportError(err) { - return false - } - d.conns.AddUniStream(conn, str) - return true - } - if d.QuicConfig == nil { - d.QuicConfig = &quic.Config{EnableDatagrams: true} - } - if d.QuicConfig.MaxIncomingStreams == 0 { - d.QuicConfig.MaxIncomingStreams = 100 - } } func (d *Dialer) Dial(ctx context.Context, urlStr string, reqHdr http.Header) (*http.Response, *Session, error) { @@ -91,8 +58,21 @@ func (d *Dialer) Dial(ctx context.Context, urlStr string, reqHdr http.Header) (* // Technically, this is not true. DATAGRAMs could be sent using the Capsule protocol. // However, quic-go currently enforces QUIC datagram support if HTTP/3 datagrams are enabled. - if !d.QuicConfig.EnableDatagrams { - return nil, nil, errors.New("WebTransport requires DATAGRAM support, enable it via QuicConfig.EnableDatagrams") + quicConf := d.QUICConfig + if quicConf == nil { + quicConf = &quic.Config{EnableDatagrams: true} + } else if !d.QUICConfig.EnableDatagrams { + return nil, nil, errors.New("WebTransport requires DATAGRAM support, enable it via QUICConfig.EnableDatagrams") + } + + tlsConf := d.TLSClientConfig + if tlsConf == nil { + tlsConf = &tls.Config{} + } else { + tlsConf = tlsConf.Clone() + } + if len(tlsConf.NextProtos) == 0 { + tlsConf.NextProtos = []string{http3.NextProtoH3} } u, err := url.Parse(urlStr) @@ -112,38 +92,74 @@ func (d *Dialer) Dial(ctx context.Context, urlStr string, reqHdr http.Header) (* } req = req.WithContext(ctx) - rsp, err := d.RoundTripper.RoundTripOpt(req, http3.RoundTripOpt{ - DontCloseRequestStream: true, - CheckSettings: func(settings http3.Settings) error { - if !settings.EnableExtendedConnect { - return errors.New("server didn't enable Extended CONNECT") + dialAddr := d.DialAddr + if dialAddr == nil { + dialAddr = quic.DialAddrEarly + } + qconn, err := dialAddr(ctx, u.Host, tlsConf, quicConf) + if err != nil { + return nil, nil, err + } + rt := &http3.SingleDestinationRoundTripper{ + Connection: qconn, + StreamHijacker: func(ft http3.FrameType, connTracingID quic.ConnectionTracingID, str quic.Stream, e error) (hijacked bool, err error) { + if isWebTransportError(e) { + return true, nil } - if !settings.EnableDatagram { - return errors.New("server didn't enable HTTP/3 datagram support") + if ft != webTransportFrameType { + return false, nil } - if settings.Other == nil { - return errNoWebTransport + id, err := quicvarint.Read(quicvarint.NewReader(str)) + if err != nil { + if isWebTransportError(err) { + return true, nil + } + return false, err } - s, ok := settings.Other[settingsEnableWebtransport] - if !ok || s != 1 { - return errNoWebTransport + d.conns.AddStream(connTracingID, str, sessionID(id)) + return true, nil + }, + UniStreamHijacker: func(st http3.StreamType, connTracingID quic.ConnectionTracingID, str quic.ReceiveStream, err error) (hijacked bool) { + if st != webTransportUniStreamType && !isWebTransportError(err) { + return false } - return nil + d.conns.AddUniStream(connTracingID, str) + return true }, - }) + } + + conn := rt.Start() + requestStr, err := rt.OpenRequestStream(ctx) // TODO: put this on the Connection (maybe introduce a ClientConnection?) + if err != nil { + return nil, nil, err + } + if err := requestStr.SendRequestHeader(req); err != nil { + return nil, nil, err + } + <-conn.ReceivedSettings() // TODO: select + settings := conn.Settings() // TODO: instead of putting the settings on the SingleDestinationRoundTripper, create a way to retrieve the Connection instead + if !settings.EnableExtendedConnect { + return nil, nil, errors.New("server didn't enable Extended CONNECT") + } + if !settings.EnableDatagram { + return nil, nil, errors.New("server didn't enable HTTP/3 datagram support") + } + if settings.Other == nil { + return nil, nil, errNoWebTransport + } + s, ok := settings.Other[settingsEnableWebtransport] + if !ok || s != 1 { + return nil, nil, errNoWebTransport + } + + rsp, err := requestStr.ReadResponse() if err != nil { return nil, nil, err } if rsp.StatusCode < 200 || rsp.StatusCode >= 300 { return rsp, nil, fmt.Errorf("received status %d", rsp.StatusCode) } - str := rsp.Body.(http3.HTTPStreamer).HTTPStream() - conn := d.conns.AddSession( - rsp.Body.(http3.Hijacker).StreamCreator(), - sessionID(str.StreamID()), - str, - ) - return rsp, conn, nil + return rsp, d.conns.AddSession(conn, sessionID(requestStr.StreamID()), requestStr), nil } func (d *Dialer) Close() error { diff --git a/client_test.go b/client_test.go index 4ea1e86..03d47f1 100644 --- a/client_test.go +++ b/client_test.go @@ -107,11 +107,7 @@ func TestClientInvalidResponseHandling(t *testing.T) { } }() - d := webtransport.Dialer{ - RoundTripper: &http3.RoundTripper{ - TLSClientConfig: &tls.Config{RootCAs: certPool}, - }, - } + d := webtransport.Dialer{TLSClientConfig: &tls.Config{RootCAs: certPool}} _, _, err = d.Dial(context.Background(), fmt.Sprintf("https://localhost:%d", s.Addr().(*net.UDPAddr).Port), nil) require.Error(t, err) var sErr error @@ -163,7 +159,7 @@ func TestClientInvalidSettingsHandling(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { tlsConf := tlsConf.Clone() - tlsConf.NextProtos = []string{"h3"} + tlsConf.NextProtos = []string{http3.NextProtoH3} s, err := quic.ListenAddr("localhost:0", tlsConf, &quic.Config{EnableDatagrams: true}) require.NoError(t, err) go func() { @@ -176,11 +172,7 @@ func TestClientInvalidSettingsHandling(t *testing.T) { require.NoError(t, err) }() - d := webtransport.Dialer{ - RoundTripper: &http3.RoundTripper{ - TLSClientConfig: &tls.Config{RootCAs: certPool}, - }, - } + d := webtransport.Dialer{TLSClientConfig: &tls.Config{RootCAs: certPool}} _, _, err = d.Dial(context.Background(), fmt.Sprintf("https://localhost:%d", s.Addr().(*net.UDPAddr).Port), nil) require.Error(t, err) require.ErrorContains(t, err, tc.errorStr) @@ -208,15 +200,13 @@ func TestClientReorderedUpgrade(t *testing.T) { go s.Serve(udpConn) d := webtransport.Dialer{ - RoundTripper: &http3.RoundTripper{ - TLSClientConfig: &tls.Config{RootCAs: certPool}, - Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - conn, err := quic.DialAddrEarly(ctx, addr, tlsCfg, cfg) - if err != nil { - return nil, err - } - return &requestStreamDelayingConn{done: blockUpgrade, EarlyConnection: conn}, nil - }, + TLSClientConfig: &tls.Config{RootCAs: certPool}, + DialAddr: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + conn, err := quic.DialAddrEarly(ctx, addr, tlsCfg, cfg) + if err != nil { + return nil, err + } + return &requestStreamDelayingConn{done: blockUpgrade, EarlyConnection: conn}, nil }, } connChan := make(chan *webtransport.Session) diff --git a/go.mod b/go.mod index 4fff1f1..02702a8 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/quic-go/webtransport-go go 1.21 require ( - github.com/quic-go/quic-go v0.42.0 + github.com/quic-go/quic-go v0.42.1-0.20240411165505-da410a7b5935 github.com/stretchr/testify v1.8.0 go.uber.org/mock v0.4.0 ) diff --git a/go.sum b/go.sum index 3e1411b..e8cf2c3 100644 --- a/go.sum +++ b/go.sum @@ -92,8 +92,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM= -github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M= +github.com/quic-go/quic-go v0.42.1-0.20240411165505-da410a7b5935 h1:gKMPe5jl70yeWH2AW2eHZ4Mva+rmxKIfOG2MKv6tmaU= +github.com/quic-go/quic-go v0.42.1-0.20240411165505-da410a7b5935/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= diff --git a/mock_connection_test.go b/mock_connection_test.go new file mode 100644 index 0000000..bd4608a --- /dev/null +++ b/mock_connection_test.go @@ -0,0 +1,260 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go/http3 (interfaces: Connection) +// +// Generated by this command: +// +// mockgen -package webtransport -destination mock_connection_test.go github.com/quic-go/quic-go/http3 Connection +// + +// Package webtransport is a generated GoMock package. +package webtransport + +import ( + context "context" + net "net" + reflect "reflect" + + quic "github.com/quic-go/quic-go" + http3 "github.com/quic-go/quic-go/http3" + gomock "go.uber.org/mock/gomock" +) + +// MockConnection is a mock of Connection interface. +type MockConnection struct { + ctrl *gomock.Controller + recorder *MockConnectionMockRecorder +} + +// MockConnectionMockRecorder is the mock recorder for MockConnection. +type MockConnectionMockRecorder struct { + mock *MockConnection +} + +// NewMockConnection creates a new mock instance. +func NewMockConnection(ctrl *gomock.Controller) *MockConnection { + mock := &MockConnection{ctrl: ctrl} + mock.recorder = &MockConnectionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnection) EXPECT() *MockConnectionMockRecorder { + return m.recorder +} + +// AcceptStream mocks base method. +func (m *MockConnection) AcceptStream(arg0 context.Context) (quic.Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptStream", arg0) + ret0, _ := ret[0].(quic.Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptStream indicates an expected call of AcceptStream. +func (mr *MockConnectionMockRecorder) AcceptStream(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockConnection)(nil).AcceptStream), arg0) +} + +// AcceptUniStream mocks base method. +func (m *MockConnection) AcceptUniStream(arg0 context.Context) (quic.ReceiveStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptUniStream", arg0) + ret0, _ := ret[0].(quic.ReceiveStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptUniStream indicates an expected call of AcceptUniStream. +func (mr *MockConnectionMockRecorder) AcceptUniStream(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockConnection)(nil).AcceptUniStream), arg0) +} + +// CloseWithError mocks base method. +func (m *MockConnection) CloseWithError(arg0 quic.ApplicationErrorCode, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseWithError indicates an expected call of CloseWithError. +func (mr *MockConnectionMockRecorder) CloseWithError(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockConnection)(nil).CloseWithError), arg0, arg1) +} + +// ConnectionState mocks base method. +func (m *MockConnection) ConnectionState() quic.ConnectionState { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ConnectionState") + ret0, _ := ret[0].(quic.ConnectionState) + return ret0 +} + +// ConnectionState indicates an expected call of ConnectionState. +func (mr *MockConnectionMockRecorder) ConnectionState() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockConnection)(nil).ConnectionState)) +} + +// Context mocks base method. +func (m *MockConnection) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockConnectionMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockConnection)(nil).Context)) +} + +// LocalAddr mocks base method. +func (m *MockConnection) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockConnectionMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockConnection)(nil).LocalAddr)) +} + +// OpenStream mocks base method. +func (m *MockConnection) OpenStream() (quic.Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenStream") + ret0, _ := ret[0].(quic.Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStream indicates an expected call of OpenStream. +func (mr *MockConnectionMockRecorder) OpenStream() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockConnection)(nil).OpenStream)) +} + +// OpenStreamSync mocks base method. +func (m *MockConnection) OpenStreamSync(arg0 context.Context) (quic.Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenStreamSync", arg0) + ret0, _ := ret[0].(quic.Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStreamSync indicates an expected call of OpenStreamSync. +func (mr *MockConnectionMockRecorder) OpenStreamSync(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockConnection)(nil).OpenStreamSync), arg0) +} + +// OpenUniStream mocks base method. +func (m *MockConnection) OpenUniStream() (quic.SendStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenUniStream") + ret0, _ := ret[0].(quic.SendStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenUniStream indicates an expected call of OpenUniStream. +func (mr *MockConnectionMockRecorder) OpenUniStream() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockConnection)(nil).OpenUniStream)) +} + +// OpenUniStreamSync mocks base method. +func (m *MockConnection) OpenUniStreamSync(arg0 context.Context) (quic.SendStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) + ret0, _ := ret[0].(quic.SendStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenUniStreamSync indicates an expected call of OpenUniStreamSync. +func (mr *MockConnectionMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockConnection)(nil).OpenUniStreamSync), arg0) +} + +// ReceiveDatagram mocks base method. +func (m *MockConnection) ReceiveDatagram(arg0 context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceiveDatagram", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReceiveDatagram indicates an expected call of ReceiveDatagram. +func (mr *MockConnectionMockRecorder) ReceiveDatagram(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveDatagram", reflect.TypeOf((*MockConnection)(nil).ReceiveDatagram), arg0) +} + +// ReceivedSettings mocks base method. +func (m *MockConnection) ReceivedSettings() <-chan struct{} { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceivedSettings") + ret0, _ := ret[0].(<-chan struct{}) + return ret0 +} + +// ReceivedSettings indicates an expected call of ReceivedSettings. +func (mr *MockConnectionMockRecorder) ReceivedSettings() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedSettings", reflect.TypeOf((*MockConnection)(nil).ReceivedSettings)) +} + +// RemoteAddr mocks base method. +func (m *MockConnection) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr. +func (mr *MockConnectionMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockConnection)(nil).RemoteAddr)) +} + +// SendDatagram mocks base method. +func (m *MockConnection) SendDatagram(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendDatagram", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendDatagram indicates an expected call of SendDatagram. +func (mr *MockConnectionMockRecorder) SendDatagram(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendDatagram", reflect.TypeOf((*MockConnection)(nil).SendDatagram), arg0) +} + +// Settings mocks base method. +func (m *MockConnection) Settings() *http3.Settings { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Settings") + ret0, _ := ret[0].(*http3.Settings) + return ret0 +} + +// Settings indicates an expected call of Settings. +func (mr *MockConnectionMockRecorder) Settings() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Settings", reflect.TypeOf((*MockConnection)(nil).Settings)) +} diff --git a/mock_stream_creator_test.go b/mock_stream_creator_test.go deleted file mode 100644 index f1614b3..0000000 --- a/mock_stream_creator_test.go +++ /dev/null @@ -1,158 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go/http3 (interfaces: StreamCreator) -// -// Generated by this command: -// -// mockgen -package webtransport -destination mock_stream_creator_test.go github.com/quic-go/quic-go/http3 StreamCreator -// - -// Package webtransport is a generated GoMock package. -package webtransport - -import ( - context "context" - net "net" - reflect "reflect" - - quic "github.com/quic-go/quic-go" - gomock "go.uber.org/mock/gomock" -) - -// MockStreamCreator is a mock of StreamCreator interface. -type MockStreamCreator struct { - ctrl *gomock.Controller - recorder *MockStreamCreatorMockRecorder -} - -// MockStreamCreatorMockRecorder is the mock recorder for MockStreamCreator. -type MockStreamCreatorMockRecorder struct { - mock *MockStreamCreator -} - -// NewMockStreamCreator creates a new mock instance. -func NewMockStreamCreator(ctrl *gomock.Controller) *MockStreamCreator { - mock := &MockStreamCreator{ctrl: ctrl} - mock.recorder = &MockStreamCreatorMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStreamCreator) EXPECT() *MockStreamCreatorMockRecorder { - return m.recorder -} - -// ConnectionState mocks base method. -func (m *MockStreamCreator) ConnectionState() quic.ConnectionState { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(quic.ConnectionState) - return ret0 -} - -// ConnectionState indicates an expected call of ConnectionState. -func (mr *MockStreamCreatorMockRecorder) ConnectionState() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockStreamCreator)(nil).ConnectionState)) -} - -// Context mocks base method. -func (m *MockStreamCreator) Context() context.Context { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Context") - ret0, _ := ret[0].(context.Context) - return ret0 -} - -// Context indicates an expected call of Context. -func (mr *MockStreamCreatorMockRecorder) Context() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStreamCreator)(nil).Context)) -} - -// LocalAddr mocks base method. -func (m *MockStreamCreator) LocalAddr() net.Addr { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LocalAddr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// LocalAddr indicates an expected call of LocalAddr. -func (mr *MockStreamCreatorMockRecorder) LocalAddr() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockStreamCreator)(nil).LocalAddr)) -} - -// OpenStream mocks base method. -func (m *MockStreamCreator) OpenStream() (quic.Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStream") - ret0, _ := ret[0].(quic.Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenStream indicates an expected call of OpenStream. -func (mr *MockStreamCreatorMockRecorder) OpenStream() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockStreamCreator)(nil).OpenStream)) -} - -// OpenStreamSync mocks base method. -func (m *MockStreamCreator) OpenStreamSync(arg0 context.Context) (quic.Stream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStreamSync", arg0) - ret0, _ := ret[0].(quic.Stream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockStreamCreatorMockRecorder) OpenStreamSync(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamCreator)(nil).OpenStreamSync), arg0) -} - -// OpenUniStream mocks base method. -func (m *MockStreamCreator) OpenUniStream() (quic.SendStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStream") - ret0, _ := ret[0].(quic.SendStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenUniStream indicates an expected call of OpenUniStream. -func (mr *MockStreamCreatorMockRecorder) OpenUniStream() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockStreamCreator)(nil).OpenUniStream)) -} - -// OpenUniStreamSync mocks base method. -func (m *MockStreamCreator) OpenUniStreamSync(arg0 context.Context) (quic.SendStream, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) - ret0, _ := ret[0].(quic.SendStream) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockStreamCreatorMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamCreator)(nil).OpenUniStreamSync), arg0) -} - -// RemoteAddr mocks base method. -func (m *MockStreamCreator) RemoteAddr() net.Addr { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoteAddr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// RemoteAddr indicates an expected call of RemoteAddr. -func (mr *MockStreamCreatorMockRecorder) RemoteAddr() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockStreamCreator)(nil).RemoteAddr)) -} diff --git a/server.go b/server.go index 90cf1c3..16718f4 100644 --- a/server.go +++ b/server.go @@ -80,7 +80,7 @@ func (s *Server) init() error { if s.H3.StreamHijacker != nil { return errors.New("StreamHijacker already set") } - s.H3.StreamHijacker = func(ft http3.FrameType, qconn quic.Connection, str quic.Stream, err error) (bool /* hijacked */, error) { + s.H3.StreamHijacker = func(ft http3.FrameType, connTracingID quic.ConnectionTracingID, str quic.Stream, err error) (bool /* hijacked */, error) { if isWebTransportError(err) { return true, nil } @@ -96,14 +96,14 @@ func (s *Server) init() error { } return false, err } - s.conns.AddStream(qconn, str, sessionID(id)) + s.conns.AddStream(connTracingID, str, sessionID(id)) return true, nil } - s.H3.UniStreamHijacker = func(st http3.StreamType, qconn quic.Connection, str quic.ReceiveStream, err error) (hijacked bool) { + s.H3.UniStreamHijacker = func(st http3.StreamType, connTracingID quic.ConnectionTracingID, str quic.ReceiveStream, err error) (hijacked bool) { if st != webTransportUniStreamType && !isWebTransportError(err) { return false } - s.conns.AddUniStream(qconn, str) + s.conns.AddUniStream(connTracingID, str) return true } return nil @@ -172,21 +172,14 @@ func (s *Server) Upgrade(w http.ResponseWriter, r *http.Request) (*Session, erro w.WriteHeader(http.StatusOK) w.(http.Flusher).Flush() - httpStreamer, ok := r.Body.(http3.HTTPStreamer) - if !ok { // should never happen, unless quic-go changed the API - return nil, errors.New("failed to take over HTTP stream") - } + httpStreamer := r.Body.(http3.HTTPStreamer) str := httpStreamer.HTTPStream() sID := sessionID(str.StreamID()) - hijacker, ok := w.(http3.Hijacker) - if !ok { // should never happen, unless quic-go changed the API - return nil, errors.New("failed to hijack") - } return s.conns.AddSession( - hijacker.StreamCreator(), + w.(http3.Hijacker).Connection(), sID, - r.Body.(http3.HTTPStreamer).HTTPStream(), + httpStreamer.HTTPStream(), ), nil } diff --git a/server_test.go b/server_test.go index 39dc322..ea58e72 100644 --- a/server_test.go +++ b/server_test.go @@ -22,6 +22,8 @@ import ( "github.com/stretchr/testify/require" ) +const webTransportFrameType = 0x41 + func scaleDuration(d time.Duration) time.Duration { if os.Getenv("CI") != "" { return 5 * d @@ -67,12 +69,12 @@ func newWebTransportRequest(t *testing.T, addr string) *http.Request { } } -func createStreamAndWrite(t *testing.T, qconn http3.StreamCreator, sessionID uint64, data []byte) quic.Stream { +func createStreamAndWrite(t *testing.T, conn quic.Connection, sessionID uint64, data []byte) quic.Stream { t.Helper() - str, err := qconn.OpenStream() + str, err := conn.OpenStream() require.NoError(t, err) var buf []byte - buf = quicvarint.Append(buf, 0x41) + buf = quicvarint.Append(buf, webTransportFrameType) buf = quicvarint.Append(buf, sessionID) // stream ID of the stream used to establish the WebTransport session. buf = append(buf, data...) _, err = str.Write(buf) @@ -96,28 +98,30 @@ func TestServerReorderedUpgradeRequest(t *testing.T) { port := udpConn.LocalAddr().(*net.UDPAddr).Port go s.Serve(udpConn) - rt := http3.RoundTripper{ - TLSClientConfig: &tls.Config{RootCAs: certPool}, - } - defer rt.Close() - // This sends a request, so that we can hijack the connection. Stream ID: 0. - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/", port), nil) - require.NoError(t, err) - rsp, err := rt.RoundTrip(req) + cconn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", port), + &tls.Config{RootCAs: certPool, NextProtos: []string{http3.NextProtoH3}}, + &quic.Config{EnableDatagrams: true}, + ) require.NoError(t, err) - qconn := rsp.Body.(http3.Hijacker).StreamCreator() - // Open a new stream for a WebTransport session we'll establish later. Stream ID: 4. - createStreamAndWrite(t, qconn, 8, []byte("foobar")) + // Open a new stream for a WebTransport session we'll establish later. Stream ID: 0. + createStreamAndWrite(t, cconn, 4, []byte("foobar")) + rt := http3.SingleDestinationRoundTripper{ + Connection: cconn, + EnableDatagrams: true, + } // make sure this request actually arrives first time.Sleep(scaleDuration(50 * time.Millisecond)) - rsp, err = rt.RoundTripOpt( - newWebTransportRequest(t, fmt.Sprintf("https://localhost:%d/webtransport", port)), - http3.RoundTripOpt{DontCloseRequestStream: true}, - ) + // Create a new WebTransport session. Stream ID: 4. + str, err := rt.OpenRequestStream(context.Background()) require.NoError(t, err) - require.Equal(t, 200, rsp.StatusCode) + require.NoError(t, str.SendRequestHeader(newWebTransportRequest(t, fmt.Sprintf("https://localhost:%d/webtransport", port)))) + rsp, err := str.ReadResponse() + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) sconn := <-connChan defer sconn.CloseWithError(0, "") ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) @@ -129,7 +133,7 @@ func TestServerReorderedUpgradeRequest(t *testing.T) { require.Equal(t, []byte("foobar"), data) // Establish another stream and make sure it's accepted now. - createStreamAndWrite(t, qconn, 8, []byte("raboof")) + createStreamAndWrite(t, cconn, 4, []byte("raboof")) ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() sstr, err = sconn.AcceptStream(ctx) @@ -142,7 +146,7 @@ func TestServerReorderedUpgradeRequest(t *testing.T) { func TestServerReorderedUpgradeRequestTimeout(t *testing.T) { timeout := scaleDuration(100 * time.Millisecond) s := webtransport.Server{ - H3: http3.Server{TLSConfig: tlsConf}, + H3: http3.Server{TLSConfig: tlsConf, EnableDatagrams: true}, StreamReorderingTimeout: timeout, } defer s.Close() @@ -156,21 +160,24 @@ func TestServerReorderedUpgradeRequestTimeout(t *testing.T) { port := udpConn.LocalAddr().(*net.UDPAddr).Port go s.Serve(udpConn) - rt := http3.RoundTripper{ - TLSClientConfig: &tls.Config{RootCAs: certPool}, - } - defer rt.Close() - // This sends a request, so that we can hijack the connection. Stream ID: 0. - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/", port), nil) - require.NoError(t, err) - rsp, err := rt.RoundTrip(req) + cconn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", port), + &tls.Config{RootCAs: certPool, NextProtos: []string{http3.NextProtoH3}}, + &quic.Config{EnableDatagrams: true}, + ) require.NoError(t, err) - qconn := rsp.Body.(http3.Hijacker).StreamCreator() - // Open a new stream for a WebTransport session we'll establish later. Stream ID: 4. - str := createStreamAndWrite(t, qconn, 8, []byte("foobar")) + + // Open a new stream for a WebTransport session we'll establish later. Stream ID: 0. + str := createStreamAndWrite(t, cconn, 4, []byte("foobar")) time.Sleep(2 * timeout) + rt := http3.SingleDestinationRoundTripper{ + Connection: cconn, + EnableDatagrams: true, + } + // Reordering was too long. The stream should now have been reset by the server. _, err = str.Read([]byte{0}) var streamErr *quic.StreamError @@ -178,12 +185,12 @@ func TestServerReorderedUpgradeRequestTimeout(t *testing.T) { require.Equal(t, webtransport.WebTransportBufferedStreamRejectedErrorCode, streamErr.ErrorCode) // Now establish the session. Make sure we don't accept the stream. - rsp, err = rt.RoundTripOpt( - newWebTransportRequest(t, fmt.Sprintf("https://localhost:%d/webtransport", port)), - http3.RoundTripOpt{DontCloseRequestStream: true}, - ) + requestStr, err := rt.OpenRequestStream(context.Background()) require.NoError(t, err) - require.Equal(t, 200, rsp.StatusCode) + require.NoError(t, requestStr.SendRequestHeader(newWebTransportRequest(t, fmt.Sprintf("https://localhost:%d/webtransport", port)))) + rsp, err := requestStr.ReadResponse() + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) sconn := <-connChan defer sconn.CloseWithError(0, "") ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) @@ -192,7 +199,7 @@ func TestServerReorderedUpgradeRequestTimeout(t *testing.T) { require.ErrorIs(t, err, context.DeadlineExceeded) // Establish another stream and make sure it's accepted now. - createStreamAndWrite(t, qconn, 8, []byte("raboof")) + createStreamAndWrite(t, cconn, 4, []byte("raboof")) ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() sstr, err := sconn.AcceptStream(ctx) @@ -205,7 +212,7 @@ func TestServerReorderedUpgradeRequestTimeout(t *testing.T) { func TestServerReorderedMultipleStreams(t *testing.T) { timeout := scaleDuration(150 * time.Millisecond) s := webtransport.Server{ - H3: http3.Server{TLSConfig: tlsConf}, + H3: http3.Server{TLSConfig: tlsConf, EnableDatagrams: true}, StreamReorderingTimeout: timeout, } defer s.Close() @@ -219,24 +226,21 @@ func TestServerReorderedMultipleStreams(t *testing.T) { port := udpConn.LocalAddr().(*net.UDPAddr).Port go s.Serve(udpConn) - rt := http3.RoundTripper{ - TLSClientConfig: &tls.Config{RootCAs: certPool}, - } - defer rt.Close() - // This sends a request, so that we can hijack the connection. Stream ID: 0. - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/", port), nil) - require.NoError(t, err) - rsp, err := rt.RoundTrip(req) + cconn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", port), + &tls.Config{RootCAs: certPool, NextProtos: []string{http3.NextProtoH3}}, + &quic.Config{EnableDatagrams: true}, + ) require.NoError(t, err) - qconn := rsp.Body.(http3.Hijacker).StreamCreator() start := time.Now() - // Open a new stream for a WebTransport session we'll establish later. Stream ID: 4. - str1 := createStreamAndWrite(t, qconn, 12, []byte("foobar")) + // Open a new stream for a WebTransport session we'll establish later. Stream ID: 0. + str1 := createStreamAndWrite(t, cconn, 8, []byte("foobar")) // After a while, open another stream. time.Sleep(timeout / 2) - // Open a new stream for a WebTransport session we'll establish later. Stream ID: 8. - createStreamAndWrite(t, qconn, 12, []byte("raboof")) + // Open a new stream for a WebTransport session we'll establish later. Stream ID: 4. + createStreamAndWrite(t, cconn, 8, []byte("raboof")) // Reordering was too long. The stream should now have been reset by the server. _, err = str1.Read([]byte{0}) @@ -247,13 +251,17 @@ func TestServerReorderedMultipleStreams(t *testing.T) { require.GreaterOrEqual(t, took, timeout) require.Less(t, took, timeout*5/4) + rt := http3.SingleDestinationRoundTripper{ + Connection: cconn, + EnableDatagrams: true, + } // Now establish the session. Make sure we don't accept the stream. - rsp, err = rt.RoundTripOpt( - newWebTransportRequest(t, fmt.Sprintf("https://localhost:%d/webtransport", port)), - http3.RoundTripOpt{DontCloseRequestStream: true}, - ) + requestStr, err := rt.OpenRequestStream(context.Background()) + require.NoError(t, err) + require.NoError(t, requestStr.SendRequestHeader(newWebTransportRequest(t, fmt.Sprintf("https://localhost:%d/webtransport", port)))) + rsp, err := requestStr.ReadResponse() require.NoError(t, err) - require.Equal(t, 200, rsp.StatusCode) + require.Equal(t, http.StatusOK, rsp.StatusCode) sconn := <-connChan defer sconn.CloseWithError(0, "") ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) diff --git a/session.go b/session.go index 3c67fd2..33e2425 100644 --- a/session.go +++ b/session.go @@ -62,7 +62,7 @@ func (q *acceptQueue[T]) Chan() <-chan struct{} { return q.c } type Session struct { sessionID sessionID - qconn http3.StreamCreator + qconn http3.Connection requestStr quic.Stream streamHdr []byte @@ -82,8 +82,8 @@ type Session struct { streams streamsMap } -func newSession(sessionID sessionID, qconn http3.StreamCreator, requestStr quic.Stream) *Session { - tracingID := qconn.Context().Value(quic.ConnectionTracingKey).(uint64) +func newSession(sessionID sessionID, qconn http3.Connection, requestStr quic.Stream) *Session { + tracingID := qconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) ctx, ctxCancel := context.WithCancel(context.WithValue(context.Background(), quic.ConnectionTracingKey, tracingID)) c := &Session{ sessionID: sessionID, diff --git a/session_manager.go b/session_manager.go index 2dbb738..5361019 100644 --- a/session_manager.go +++ b/session_manager.go @@ -25,13 +25,13 @@ type sessionManager struct { timeout time.Duration mx sync.Mutex - conns map[http3.StreamCreator]map[sessionID]*session + conns map[quic.ConnectionTracingID]map[sessionID]*session } func newSessionManager(timeout time.Duration) *sessionManager { m := &sessionManager{ timeout: timeout, - conns: make(map[http3.StreamCreator]map[sessionID]*session), + conns: make(map[quic.ConnectionTracingID]map[sessionID]*session), } m.ctx, m.ctxCancel = context.WithCancel(context.Background()) return m @@ -41,8 +41,8 @@ func newSessionManager(timeout time.Duration) *sessionManager { // If the WebTransport session has not yet been established, // it starts a new go routine and waits for establishment of the session. // If that takes longer than timeout, the stream is reset. -func (m *sessionManager) AddStream(qconn http3.StreamCreator, str quic.Stream, id sessionID) { - sess, isExisting := m.getOrCreateSession(qconn, id) +func (m *sessionManager) AddStream(connTracingID quic.ConnectionTracingID, str quic.Stream, id sessionID) { + sess, isExisting := m.getOrCreateSession(connTracingID, id) if isExisting { sess.conn.addIncomingStream(str) return @@ -60,19 +60,19 @@ func (m *sessionManager) AddStream(qconn http3.StreamCreator, str quic.Stream, i // Once no more streams are waiting for this session to be established, // and this session is still outstanding, delete it from the map. if sess.counter == 0 && sess.conn == nil { - m.maybeDelete(qconn, id) + m.maybeDelete(connTracingID, id) } }() } -func (m *sessionManager) maybeDelete(qconn http3.StreamCreator, id sessionID) { - sessions, ok := m.conns[qconn] +func (m *sessionManager) maybeDelete(connTracingID quic.ConnectionTracingID, id sessionID) { + sessions, ok := m.conns[connTracingID] if !ok { // should never happen return } delete(sessions, id) if len(sessions) == 0 { - delete(m.conns, qconn) + delete(m.conns, connTracingID) } } @@ -80,14 +80,14 @@ func (m *sessionManager) maybeDelete(qconn http3.StreamCreator, id sessionID) { // If the WebTransport session has not yet been established, // it starts a new go routine and waits for establishment of the session. // If that takes longer than timeout, the stream is reset. -func (m *sessionManager) AddUniStream(qconn http3.StreamCreator, str quic.ReceiveStream) { +func (m *sessionManager) AddUniStream(connTracingID quic.ConnectionTracingID, str quic.ReceiveStream) { idv, err := quicvarint.Read(quicvarint.NewReader(str)) if err != nil { str.CancelRead(1337) } id := sessionID(idv) - sess, isExisting := m.getOrCreateSession(qconn, id) + sess, isExisting := m.getOrCreateSession(connTracingID, id) if isExisting { sess.conn.addIncomingUniStream(str) return @@ -105,19 +105,19 @@ func (m *sessionManager) AddUniStream(qconn http3.StreamCreator, str quic.Receiv // Once no more streams are waiting for this session to be established, // and this session is still outstanding, delete it from the map. if sess.counter == 0 && sess.conn == nil { - m.maybeDelete(qconn, id) + m.maybeDelete(connTracingID, id) } }() } -func (m *sessionManager) getOrCreateSession(qconn http3.StreamCreator, id sessionID) (sess *session, existed bool) { +func (m *sessionManager) getOrCreateSession(connTracingID quic.ConnectionTracingID, id sessionID) (sess *session, existed bool) { m.mx.Lock() defer m.mx.Unlock() - sessions, ok := m.conns[qconn] + sessions, ok := m.conns[connTracingID] if !ok { sessions = make(map[sessionID]*session) - m.conns[qconn] = sessions + m.conns[connTracingID] = sessions } sess, ok = sessions[id] @@ -164,16 +164,17 @@ func (m *sessionManager) handleUniStream(str quic.ReceiveStream, sess *session) } // AddSession adds a new WebTransport session. -func (m *sessionManager) AddSession(qconn http3.StreamCreator, id sessionID, requestStr quic.Stream) *Session { +func (m *sessionManager) AddSession(qconn http3.Connection, id sessionID, requestStr quic.Stream) *Session { conn := newSession(id, qconn, requestStr) + connTracingID := qconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) m.mx.Lock() defer m.mx.Unlock() - sessions, ok := m.conns[qconn] + sessions, ok := m.conns[connTracingID] if !ok { sessions = make(map[sessionID]*session) - m.conns[qconn] = sessions + m.conns[connTracingID] = sessions } if sess, ok := sessions[id]; ok { // We might already have an entry of this session. diff --git a/session_test.go b/session_test.go index 8071e48..1046244 100644 --- a/session_test.go +++ b/session_test.go @@ -12,7 +12,7 @@ import ( "go.uber.org/mock/gomock" ) -//go:generate sh -c "go run go.uber.org/mock/mockgen -package webtransport -destination mock_stream_creator_test.go github.com/quic-go/quic-go/http3 StreamCreator" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package webtransport -destination mock_connection_test.go github.com/quic-go/quic-go/http3 Connection" //go:generate sh -c "go run go.uber.org/mock/mockgen -package webtransport -destination mock_stream_test.go github.com/quic-go/quic-go Stream && cat mock_stream_test.go | sed s@protocol\\.StreamID@quic.StreamID@g | sed s@qerr\\.StreamErrorCode@quic.StreamErrorCode@g > tmp.go && mv tmp.go mock_stream_test.go && goimports -w mock_stream_test.go" type mockRequestStream struct { @@ -41,8 +41,8 @@ func (s *mockRequestStream) Write(b []byte) (int, error) { return len(b), nil } func TestCloseStreamsOnClose(t *testing.T) { ctrl := gomock.NewController(t) - mockSess := NewMockStreamCreator(ctrl) - mockSess.EXPECT().Context().Return(context.WithValue(context.Background(), quic.ConnectionTracingKey, uint64(1337))) + mockSess := NewMockConnection(ctrl) + mockSess.EXPECT().Context().Return(context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1337))) sess := newSession(42, mockSess, newMockRequestStream(ctrl)) str := NewMockStream(ctrl) @@ -66,8 +66,8 @@ func TestOpenStreamSyncCancel(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockSess := NewMockStreamCreator(ctrl) - mockSess.EXPECT().Context().Return(context.WithValue(context.Background(), quic.ConnectionTracingKey, uint64(1337))) + mockSess := NewMockConnection(ctrl) + mockSess.EXPECT().Context().Return(context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1337))) sess := newSession(42, mockSess, newMockRequestStream(ctrl)) defer sess.CloseWithError(0, "") @@ -101,8 +101,8 @@ func TestAddStreamAfterSessionClose(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockSess := NewMockStreamCreator(ctrl) - mockSess.EXPECT().Context().Return(context.WithValue(context.Background(), quic.ConnectionTracingKey, uint64(1337))) + mockSess := NewMockConnection(ctrl) + mockSess.EXPECT().Context().Return(context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1337))) sess := newSession(42, mockSess, newMockRequestStream(ctrl)) require.NoError(t, sess.CloseWithError(0, "")) @@ -121,8 +121,8 @@ func TestOpenStreamAfterSessionClose(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockSess := NewMockStreamCreator(ctrl) - mockSess.EXPECT().Context().Return(context.WithValue(context.Background(), quic.ConnectionTracingKey, uint64(1337))) + mockSess := NewMockConnection(ctrl) + mockSess.EXPECT().Context().Return(context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1337))) wait := make(chan struct{}) streamOpen := make(chan struct{}) mockSess.EXPECT().OpenStreamSync(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { @@ -153,8 +153,8 @@ func TestOpenUniStreamAfterSessionClose(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockSess := NewMockStreamCreator(ctrl) - mockSess.EXPECT().Context().Return(context.WithValue(context.Background(), quic.ConnectionTracingKey, uint64(1337))) + mockSess := NewMockConnection(ctrl) + mockSess.EXPECT().Context().Return(context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1337))) wait := make(chan struct{}) streamOpen := make(chan struct{}) mockSess.EXPECT().OpenUniStreamSync(gomock.Any()).DoAndReturn(func(context.Context) (quic.SendStream, error) { diff --git a/webtransport_test.go b/webtransport_test.go index 2259a33..f89f2e0 100644 --- a/webtransport_test.go +++ b/webtransport_test.go @@ -71,17 +71,15 @@ func establishSession(t *testing.T, handler func(*webtransport.Session)) (sess * s := &webtransport.Server{ H3: http3.Server{ TLSConfig: tlsConf, - QuicConfig: &quic.Config{Tracer: getQlogger(t), EnableDatagrams: true}, + QUICConfig: &quic.Config{Tracer: getQlogger(t), EnableDatagrams: true}, }, } addHandler(t, s, handler) addr, closeServer := runServer(t, s) d := webtransport.Dialer{ - RoundTripper: &http3.RoundTripper{ - TLSClientConfig: &tls.Config{RootCAs: certPool}, - QuicConfig: &quic.Config{Tracer: getQlogger(t), EnableDatagrams: true}, - }, + TLSClientConfig: &tls.Config{RootCAs: certPool}, + QUICConfig: &quic.Config{Tracer: getQlogger(t), EnableDatagrams: true}, } defer d.Close() url := fmt.Sprintf("https://localhost:%d/webtransport", addr.Port) @@ -91,7 +89,7 @@ func establishSession(t *testing.T, handler func(*webtransport.Session)) (sess * return sess, func() { closeServer() s.Close() - d.RoundTripper.Close() + d.Close() } } @@ -221,9 +219,9 @@ func TestStreamsImmediateClose(t *testing.T) { t.Run("unidirectional", func(t *testing.T) { t.Run("client-initiated", func(t *testing.T) { - sess, closeServer := establishSession(t, func(c *webtransport.Session) { - defer c.CloseWithError(0, "") - str, err := c.AcceptUniStream(context.Background()) + sess, closeServer := establishSession(t, func(sess *webtransport.Session) { + defer sess.CloseWithError(0, "") + str, err := sess.AcceptUniStream(context.Background()) require.NoError(t, err) n, err := str.Read([]byte{0}) require.Zero(t, n) @@ -238,8 +236,8 @@ func TestStreamsImmediateClose(t *testing.T) { }) t.Run("server-initiated", func(t *testing.T) { - sess, closeServer := establishSession(t, func(c *webtransport.Session) { - str, err := c.OpenUniStream() + sess, closeServer := establishSession(t, func(sess *webtransport.Session) { + str, err := sess.OpenUniStream() require.NoError(t, err) require.NoError(t, str.Close()) }) @@ -342,10 +340,8 @@ func TestMultipleClients(t *testing.T) { go func() { defer wg.Done() d := webtransport.Dialer{ - RoundTripper: &http3.RoundTripper{ - TLSClientConfig: &tls.Config{RootCAs: certPool}, - QuicConfig: &quic.Config{Tracer: getQlogger(t), EnableDatagrams: true}, - }, + TLSClientConfig: &tls.Config{RootCAs: certPool}, + QUICConfig: &quic.Config{Tracer: getQlogger(t), EnableDatagrams: true}, } defer d.Close() url := fmt.Sprintf("https://localhost:%d/webtransport", addr.Port) @@ -522,10 +518,8 @@ func TestCheckOrigin(t *testing.T) { defer closeServer() d := webtransport.Dialer{ - RoundTripper: &http3.RoundTripper{ - TLSClientConfig: &tls.Config{RootCAs: certPool}, - QuicConfig: &quic.Config{Tracer: getQlogger(t), EnableDatagrams: true}, - }, + TLSClientConfig: &tls.Config{RootCAs: certPool}, + QUICConfig: &quic.Config{Tracer: getQlogger(t), EnableDatagrams: true}, } defer d.Close() url := fmt.Sprintf("https://localhost:%d/webtransport", addr.Port)