Skip to content

Commit

Permalink
feat: internalize transport connections handling
Browse files Browse the repository at this point in the history
  • Loading branch information
certaintls committed Jun 19, 2022
1 parent d14385d commit 6edcfff
Showing 1 changed file with 120 additions and 19 deletions.
139 changes: 120 additions & 19 deletions pkg/core/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,25 @@ type HysteriaTransport struct {
disconnectFunc DisconnectFunc
}

type TransportServer struct {
transport *transport.ServerTransport
sendBPS, recvBPS uint64
congestionFactory CongestionFactory
disableUDP bool
aclEngine *acl.Engine

connectFunc ConnectFunc
disconnectFunc DisconnectFunc
tcpRequestFunc TCPRequestFunc
tcpErrorFunc TCPErrorFunc
udpRequestFunc UDPRequestFunc
udpErrorFunc UDPErrorFunc

listener quic.Listener
allStreams chan *quicConn
isListening bool
}

func NewServer(addr string, protocol string, tlsConfig *tls.Config, quicConfig *quic.Config, transport *transport.ServerTransport,
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine,
obfuscator obfs.Obfuscator, connectFunc ConnectFunc, disconnectFunc DisconnectFunc,
Expand Down Expand Up @@ -198,7 +217,7 @@ func (t *HysteriaTransport) Listen() (net.Listener, error) {
if err != nil {
return nil, err
}
s := &Server{
s := &TransportServer{
listener: listener,
transport: t.transport,
sendBPS: t.sendBPS,
Expand All @@ -207,53 +226,135 @@ func (t *HysteriaTransport) Listen() (net.Listener, error) {
disableUDP: t.disableUDP,
connectFunc: t.connectFunc,
disconnectFunc: t.disconnectFunc,
allStreams: make(chan *quicConn),
isListening: false,
}

return s, nil
}

// Addr returns the listener's network address.
func (s *Server) Addr() net.Addr {
func (s *TransportServer) Addr() net.Addr {
return s.listener.Addr()
}

func (s *Server) Accept() (net.Conn, error) {
cs, err := s.listener.Accept(context.Background())
if err != nil {
return nil, err
func (s *TransportServer) Close() error {
s.isListening = false
return s.listener.Close()
}

func (s *TransportServer) Accept() (net.Conn, error) {
if !s.isListening {
s.isListening = true
go acceptConn(s)
}
// Return the next stream
select {
case stream := <-s.allStreams:
return stream, nil
}
}

// An internal goroutine for accepting connections. Then for each accepted
// connection, start a goroutine for handling the control stream & accepting
// streams. Put those streams into a channel
func acceptConn(s *TransportServer) {
for {
cs, err := s.listener.Accept(context.Background())
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return
}
go acceptStream(cs, s)
}
}

func acceptStream(cs quic.Connection, s *TransportServer) {
// Expect the client to create a control stream to send its own information
ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
stream, err := cs.AcceptStream(ctx)
ctxCancel()
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return nil, err
return
}
// Handle the control stream
_, ok, _, err := s.handleControlStream(cs, stream)
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return nil, err
return
}
if !ok {
_ = cs.CloseWithError(closeErrorCodeAuth, "auth error")
return nil, err
return
}
// Close the control stream
stream.Close()

// Accept the next stream
stream, err = cs.AcceptStream(context.Background())
if err != nil {
return nil, err
}
for {
// Accept the next stream
stream, err = cs.AcceptStream(context.Background())
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return
}

conn := &quicConn{
Orig: stream,
PseudoLocalAddr: cs.LocalAddr(),
PseudoRemoteAddr: cs.RemoteAddr(),
conn := &quicConn{
Orig: stream,
PseudoLocalAddr: cs.LocalAddr(),
PseudoRemoteAddr: cs.RemoteAddr(),
}
s.allStreams <- conn
}
}

return conn, nil
// Auth & negotiate speed
// Copy from (s *Server) handleControlStream, TODO: refactor
func (s *TransportServer) handleControlStream(cs quic.Connection, stream quic.Stream) ([]byte, bool, bool, error) {
// Check version
vb := make([]byte, 1)
_, err := stream.Read(vb)
if err != nil {
return nil, false, false, err
}
if vb[0] != protocolVersion && vb[0] != protocolVersionV2 {
return nil, false, false, fmt.Errorf("unsupported protocol version %d, expecting %d/%d",
vb[0], protocolVersionV2, protocolVersion)
}
// Parse client hello
var ch clientHello
err = struc.Unpack(stream, &ch)
if err != nil {
return nil, false, false, err
}
// Speed
if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 {
return nil, false, false, errors.New("invalid rate from client")
}
serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS
if s.sendBPS > 0 && serverSendBPS > s.sendBPS {
serverSendBPS = s.sendBPS
}
if s.recvBPS > 0 && serverRecvBPS > s.recvBPS {
serverRecvBPS = s.recvBPS
}
// Auth
ok, msg := s.connectFunc(cs.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS)
// Response
err = struc.Pack(stream, &serverHello{
OK: ok,
Rate: transmissionRate{
SendBPS: serverSendBPS,
RecvBPS: serverRecvBPS,
},
Message: msg,
})
if err != nil {
return nil, false, false, err
}
// Set the congestion accordingly
if ok && s.congestionFactory != nil {
cs.SetCongestionControl(s.congestionFactory(serverSendBPS))
}
return ch.Auth, ok, vb[0] == protocolVersionV2, nil
}

0 comments on commit 6edcfff

Please sign in to comment.