diff --git a/client.go b/client.go index 5e675fd3..b4e5f9b6 100644 --- a/client.go +++ b/client.go @@ -62,6 +62,7 @@ type Client struct { trMap *client.TransactionMap // thread-safe rto time.Duration // read-only relayedConn *client.UDPConn // protected by mutex *** + tcpAllocation *client.TCPAllocation // protected by mutex *** allocTryLock client.TryLock // thread-safe listenTryLock client.TryLock // thread-safe net transport.Net // read-only @@ -238,42 +239,34 @@ func (c *Client) SendBindingRequest() (net.Addr, error) { return c.SendBindingRequestTo(c.stunServ) } -// Allocate sends a TURN allocation request to the given transport address -func (c *Client) Allocate() (net.PacketConn, error) { - if err := c.allocTryLock.Lock(); err != nil { - return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error()) - } - defer c.allocTryLock.Unlock() - - relayedConn := c.relayedUDPConn() - if relayedConn != nil { - return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String()) - } +func (c *Client) sendAllocateRequest(protocol proto.Protocol) (proto.RelayedAddress, proto.Lifetime, stun.Nonce, error) { + var relayed proto.RelayedAddress + var lifetime proto.Lifetime + var nonce stun.Nonce msg, err := stun.Build( stun.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassRequest), - proto.RequestedTransport{Protocol: proto.ProtoUDP}, + proto.RequestedTransport{Protocol: protocol}, stun.Fingerprint, ) if err != nil { - return nil, err + return relayed, lifetime, nonce, err } trRes, err := c.PerformTransaction(msg, c.turnServ, false) if err != nil { - return nil, err + return relayed, lifetime, nonce, err } res := trRes.Msg // Anonymous allocate failed, trying to authenticate. - var nonce stun.Nonce if err = nonce.GetFrom(res); err != nil { - return nil, err + return relayed, lifetime, nonce, err } if err = c.realm.GetFrom(res); err != nil { - return nil, err + return relayed, lifetime, nonce, err } c.realm = append([]byte(nil), c.realm...) c.integrity = stun.NewLongTermIntegrity( @@ -283,7 +276,7 @@ func (c *Client) Allocate() (net.PacketConn, error) { msg, err = stun.Build( stun.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassRequest), - proto.RequestedTransport{Protocol: proto.ProtoUDP}, + proto.RequestedTransport{Protocol: protocol}, &c.username, &c.realm, &nonce, @@ -291,40 +284,93 @@ func (c *Client) Allocate() (net.PacketConn, error) { stun.Fingerprint, ) if err != nil { - return nil, err + return relayed, lifetime, nonce, err } trRes, err = c.PerformTransaction(msg, c.turnServ, false) if err != nil { - return nil, err + return relayed, lifetime, nonce, err } res = trRes.Msg if res.Type.Class == stun.ClassErrorResponse { var code stun.ErrorCodeAttribute if err = code.GetFrom(res); err == nil { - return nil, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113 + return relayed, lifetime, nonce, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113 } - return nil, fmt.Errorf("%s", res.Type) //nolint:goerr113 + return relayed, lifetime, nonce, fmt.Errorf("%s", res.Type) //nolint:goerr113 } // Getting relayed addresses from response. - var relayed proto.RelayedAddress if err := relayed.GetFrom(res); err != nil { + return relayed, lifetime, nonce, err + } + + // Getting lifetime from response + if err := lifetime.GetFrom(res); err != nil { + return relayed, lifetime, nonce, err + } + return relayed, lifetime, nonce, nil +} + +// Allocate sends a TURN allocation request to the given transport address +func (c *Client) Allocate() (net.PacketConn, error) { + if err := c.allocTryLock.Lock(); err != nil { + return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error()) + } + defer c.allocTryLock.Unlock() + + relayedConn := c.relayedUDPConn() + if relayedConn != nil { + return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String()) + } + + relayed, lifetime, nonce, err := c.sendAllocateRequest(proto.ProtoUDP) + if err != nil { return nil, err } + relayedAddr := &net.UDPAddr{ IP: relayed.IP, Port: relayed.Port, } - // Getting lifetime from response - var lifetime proto.Lifetime - if err := lifetime.GetFrom(res); err != nil { + relayedConn = client.NewUDPConn(&client.ConnConfig{ + Observer: c, + RelayedAddr: relayedAddr, + Integrity: c.integrity, + Nonce: nonce, + Lifetime: lifetime.Duration, + Log: c.log, + }) + c.setRelayedUDPConn(relayedConn) + + return relayedConn, nil +} + +// Allocate TCP +func (c *Client) AllocateTCP() (*client.TCPAllocation, error) { + if err := c.allocTryLock.Lock(); err != nil { + return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error()) + } + defer c.allocTryLock.Unlock() + + allocation := c.getTCPAllocation() + if allocation != nil { + return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, allocation.Addr().String()) + } + + relayed, lifetime, nonce, err := c.sendAllocateRequest(proto.ProtoTCP) + if err != nil { return nil, err } - relayedConn = client.NewUDPConn(&client.UDPConnConfig{ + relayedAddr := &net.TCPAddr{ + IP: relayed.IP, + Port: relayed.Port, + } + + allocation = client.NewTCPAllocation(&client.ConnConfig{ Observer: c, RelayedAddr: relayedAddr, Integrity: c.integrity, @@ -333,15 +379,26 @@ func (c *Client) Allocate() (net.PacketConn, error) { Log: c.log, }) - c.setRelayedUDPConn(relayedConn) + c.setTCPAllocation(allocation) - return relayedConn, nil + return allocation, nil } // CreatePermission Issues a CreatePermission request for the supplied addresses // as described in https://datatracker.ietf.org/doc/html/rfc5766#section-9 func (c *Client) CreatePermission(addrs ...net.Addr) error { - return c.relayedUDPConn().CreatePermissions(addrs...) + if conn := c.relayedUDPConn(); conn != nil { + if err := conn.CreatePermissions(addrs...); err != nil { + return err + } + } + + if allocation := c.getTCPAllocation(); allocation != nil { + if err := allocation.CreatePermissions(addrs...); err != nil { + return err + } + } + return nil } // PerformTransaction performs STUN transaction @@ -387,6 +444,7 @@ func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, ignoreResult // (Called by UDPConn) func (c *Client) OnDeallocated(relayedAddr net.Addr) { c.setRelayedUDPConn(nil) + c.setTCPAllocation(nil) } // HandleInbound handles data received. @@ -445,7 +503,8 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { } if msg.Type.Class == stun.ClassIndication { - if msg.Type.Method == stun.MethodData { + switch msg.Type.Method { + case stun.MethodData: var peerAddr proto.PeerAddress if err := peerAddr.GetFrom(msg); err != nil { return err @@ -467,8 +526,32 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { c.log.Debug("no relayed conn allocated") return nil // silently discard } - relayedConn.HandleInbound(data, from) + case stun.MethodConnectionAttempt: + var peerAddr proto.PeerAddress + if err := peerAddr.GetFrom(msg); err != nil { + return err + } + + addr := &net.TCPAddr{ + IP: peerAddr.IP, + Port: peerAddr.Port, + } + + var cid proto.ConnectionID + if err := cid.GetFrom(msg); err != nil { + return err + } + + c.log.Debugf("connection attempt from %s", addr.String()) + + allocation := c.getTCPAllocation() + if allocation == nil { + c.log.Debug("no TCP allocation exists") + return nil // silently discard + } + + allocation.HandleConnectionAttempt(addr, cid) } return nil } @@ -579,3 +662,17 @@ func (c *Client) relayedUDPConn() *client.UDPConn { return c.relayedConn } + +func (c *Client) setTCPAllocation(alloc *client.TCPAllocation) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.tcpAllocation = alloc +} + +func (c *Client) getTCPAllocation() *client.TCPAllocation { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.tcpAllocation +} diff --git a/client_test.go b/client_test.go index e799f498..ac50c0ea 100644 --- a/client_test.go +++ b/client_test.go @@ -11,6 +11,7 @@ import ( "github.com/pion/logging" "github.com/pion/transport/v2/stdnet" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func createListeningTestClient(t *testing.T, loggerFactory logging.LoggerFactory) (*Client, net.PacketConn, bool) { @@ -187,3 +188,53 @@ func TestClientNonceExpiration(t *testing.T) { assert.NoError(t, conn.Close()) assert.NoError(t, server.Close()) } + +// Create a TCP-based allocation and verify allocation can be created +func TestTCPClient(t *testing.T) { + // Setup server + tcpListener, err := net.Listen("tcp4", "0.0.0.0:13478") + require.NoError(t, err) + + server, err := NewServer(ServerConfig{ + AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { + return GenerateAuthKey(username, realm, "pass"), true + }, + ListenerConfigs: []ListenerConfig{ + { + Listener: tcpListener, + RelayAddressGenerator: &RelayAddressGeneratorStatic{ + RelayAddress: net.ParseIP("127.0.0.1"), + Address: "0.0.0.0", + }, + }, + }, + Realm: "pion.ly", + }) + require.NoError(t, err) + + // Setup clients + conn, err := net.Dial("tcp", "127.0.0.1:13478") + require.NoError(t, err) + + client, err := NewClient(&ClientConfig{ + Conn: NewSTUNConn(conn), + STUNServerAddr: "127.0.0.1:13478", + TURNServerAddr: "127.0.0.1:13478", + Username: "foo", + Password: "pass", + }) + require.NoError(t, err) + require.NoError(t, client.Listen()) + + allocation, err := client.AllocateTCP() + require.NoError(t, err) + + // TODO: Implement server side handling of Connect and ConnectionBind + // _, err = allocation.Dial(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}) + // assert.NoError(t, err) + + // Shutdown + require.NoError(t, allocation.Close()) + require.NoError(t, conn.Close()) + require.NoError(t, server.Close()) +} diff --git a/examples/turn-client/tcp_alloc/main.go b/examples/turn-client/tcp_alloc/main.go new file mode 100644 index 00000000..ed3c14e7 --- /dev/null +++ b/examples/turn-client/tcp_alloc/main.go @@ -0,0 +1,165 @@ +// Package main implements a TURN client with support for TCP +package main + +import ( + "bufio" + "flag" + "fmt" + "log" + "net" + "strconv" + "strings" + + "github.com/pion/logging" + "github.com/pion/turn/v2" +) + +func setupSignalingChannel(addrCh chan string, signaling bool, relayAddr string) { + addr := "127.0.0.1:5000" + if signaling { + go func() { + listen, err := net.Listen("tcp", addr) + if err != nil { + log.Panicf("Failed to create signaling server: %s", err) + } + defer listen.Close() + for { + conn, err := listen.Accept() + if err != nil { + log.Panicf("Failed to accept: %s", err) + } + go func() { + message, err := bufio.NewReader(conn).ReadString('\n') + if err != nil { + log.Panicf("Failed to read relayAddr: %s", err) + } + addrCh <- message[:len(message)-1] + }() + if _, err = conn.Write([]byte(fmt.Sprintf("%s\n", relayAddr))); err != nil { + log.Panicf("Failed to write relayAddr: %s", err) + } + } + }() + } else { + conn, err := net.Dial("tcp", addr) + if err != nil { + log.Panicf("Error dialing: %s", err) + } + message, err := bufio.NewReader(conn).ReadString('\n') + if err != nil { + log.Panicf("Failed to read relayAddr: %s", err) + } + addrCh <- message[:len(message)-1] + if _, err = conn.Write([]byte(fmt.Sprintf("%s\n", relayAddr))); err != nil { + log.Panicf("Failed to write relayAddr: %s", err) + } + } +} + +func main() { + host := flag.String("host", "", "TURN Server name.") + port := flag.Int("port", 3478, "Listening port.") + user := flag.String("user", "", "A pair of username and password (e.g. \"user=pass\")") + realm := flag.String("realm", "pion.ly", "Realm (defaults to \"pion.ly\")") + signaling := flag.Bool("signaling", false, "Whether to start signaling server otherwise connect") + + flag.Parse() + + if len(*host) == 0 { + log.Fatalf("'host' is required") + } + + if len(*user) == 0 { + log.Fatalf("'user' is required") + } + + // Dial TURN Server + turnServerAddr := fmt.Sprintf("%s:%d", *host, *port) + conn, err := net.Dial("tcp", turnServerAddr) + if err != nil { + log.Panicf("Failed to connect to TURN server: %s", err) + } + + cred := strings.SplitN(*user, "=", 2) + + // Start a new TURN Client and wrap our net.Conn in a STUNConn + // This allows us to simulate datagram based communication over a net.Conn + cfg := &turn.ClientConfig{ + STUNServerAddr: turnServerAddr, + TURNServerAddr: turnServerAddr, + Conn: turn.NewSTUNConn(conn), + Username: cred[0], + Password: cred[1], + Realm: *realm, + LoggerFactory: logging.NewDefaultLoggerFactory(), + } + + client, err := turn.NewClient(cfg) + if err != nil { + log.Panicf("Failed to create TURN client: %s", err) + } + defer client.Close() + + // Start listening on the conn provided. + err = client.Listen() + if err != nil { + log.Panicf("Failed to listen: %s", err) + } + + // Allocate a relay socket on the TURN server. On success, it + // will return a client.TCPAllocation which represents the remote + // socket. + allocation, err := client.AllocateTCP() + if err != nil { + log.Panicf("Failed to allocate: %s", err) + } + defer func() { + if closeErr := allocation.Close(); closeErr != nil { + log.Fatalf("Failed to close connection: %s", closeErr) + } + }() + + log.Printf("relayed-address=%s", allocation.Addr().String()) + + // Learn the peers relay address via signaling channel + addrCh := make(chan string, 5) + setupSignalingChannel(addrCh, *signaling, allocation.Addr().String()) + + // Get peer address + peerAddrString := <-addrCh + res := strings.Split(peerAddrString, ":") + peerIp := res[0] + peerPort, _ := strconv.Atoi(res[1]) + + log.Printf("Recieved peer address: %s", peerAddrString) + + buf := make([]byte, 4096) + peerAddr := net.TCPAddr{IP: net.ParseIP(peerIp), Port: peerPort} + var n int + if *signaling { + conn, err = allocation.Dial("tcp", peerAddrString) + if err != nil { + fmt.Println("Error connecting:", err) + } + conn.Write([]byte("hello!")) + n, err = conn.Read(buf) + if err != nil { + log.Println("Error reading from relay conn:", err) + } + conn.Close() + } else { + client.CreatePermission(&peerAddr) + conn, err := allocation.AcceptTCP() + if err != nil { + log.Println("Error accepting:", err) + } + log.Println("Accepted from:", conn.RemoteAddr()) + n, err = conn.Read(buf) + if err != nil { + log.Println("Error reading from relay conn:", err) + } + conn.Write([]byte("hello back!")) + conn.Close() + } + log.Println("Read message:", string(buf[:n])) +} diff --git a/internal/client/conn.go b/internal/client/conn.go index 8aeb742c..5065a1e4 100644 --- a/internal/client/conn.go +++ b/internal/client/conn.go @@ -35,8 +35,8 @@ type inboundData struct { from net.Addr } -// UDPConnObserver is an interface to UDPConn observer -type UDPConnObserver interface { +// ConnObserver is an interface to UDPConn observer +type ConnObserver interface { TURNServerAddr() net.Addr Username() stun.Username Realm() stun.Realm @@ -45,9 +45,9 @@ type UDPConnObserver interface { OnDeallocated(relayedAddr net.Addr) } -// UDPConnConfig is a set of configuration params use by NewUDPConn -type UDPConnConfig struct { - Observer UDPConnObserver +// ConnConfig is a set of configuration params use by NewUDPConn and NewTCPConn +type ConnConfig struct { + Observer ConnObserver RelayedAddr net.Addr Integrity stun.MessageIntegrity Nonce stun.Nonce @@ -55,39 +55,45 @@ type UDPConnConfig struct { Log logging.LeveledLogger } -// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections. -// comatible with net.PacketConn and net.Conn -type UDPConn struct { - obs UDPConnObserver // read-only +type RelayConnContext struct { + obs ConnObserver // read-only relayedAddr net.Addr // read-only permMap *permissionMap // thread-safe - bindingMgr *bindingManager // thread-safe integrity stun.MessageIntegrity // read-only _nonce stun.Nonce // needs mutex x _lifetime time.Duration // needs mutex x - readCh chan *inboundData // thread-safe - closeCh chan struct{} // thread-safe - readTimer *time.Timer // thread-safe refreshAllocTimer *PeriodicTimer // thread-safe refreshPermsTimer *PeriodicTimer // thread-safe + readTimer *time.Timer // thread-safe mutex sync.RWMutex // thread-safe log logging.LeveledLogger // read-only } +// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections. +// comatible with net.PacketConn and net.Conn +type UDPConn struct { + bindingMgr *bindingManager // thread-safe + readCh chan *inboundData // thread-safe + closeCh chan struct{} // thread-safe + RelayConnContext +} + // NewUDPConn creates a new instance of UDPConn -func NewUDPConn(config *UDPConnConfig) *UDPConn { +func NewUDPConn(config *ConnConfig) *UDPConn { c := &UDPConn{ - obs: config.Observer, - relayedAddr: config.RelayedAddr, - permMap: newPermissionMap(), - bindingMgr: newBindingManager(), - integrity: config.Integrity, - _nonce: config.Nonce, - _lifetime: config.Lifetime, - readCh: make(chan *inboundData, maxReadQueueSize), - closeCh: make(chan struct{}), - readTimer: time.NewTimer(time.Duration(math.MaxInt64)), - log: config.Log, + bindingMgr: newBindingManager(), + readCh: make(chan *inboundData, maxReadQueueSize), + closeCh: make(chan struct{}), + RelayConnContext: RelayConnContext{ + obs: config.Observer, + relayedAddr: config.RelayedAddr, + readTimer: time.NewTimer(time.Duration(math.MaxInt64)), + permMap: newPermissionMap(), + integrity: config.Integrity, + _nonce: config.Nonce, + _lifetime: config.Lifetime, + log: config.Log, + }, } c.log.Debugf("initial lifetime: %d seconds", int(c.lifetime().Seconds())) @@ -153,7 +159,7 @@ func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } } -func (c *UDPConn) createPermission(perm *permission, addr net.Addr) error { +func (c *RelayConnContext) createPermission(perm *permission, addr net.Addr) error { perm.mutex.Lock() defer perm.mutex.Unlock() @@ -365,7 +371,7 @@ func addr2PeerAddress(addr net.Addr) proto.PeerAddress { // CreatePermissions Issues a CreatePermission request for the supplied addresses // as described in https://datatracker.ietf.org/doc/html/rfc5766#section-9 -func (c *UDPConn) CreatePermissions(addrs ...net.Addr) error { +func (c *RelayConnContext) CreatePermissions(addrs ...net.Addr) error { setters := []stun.Setter{ stun.TransactionID, stun.NewType(stun.MethodCreatePermission, stun.ClassRequest), @@ -433,7 +439,7 @@ func (c *UDPConn) FindAddrByChannelNumber(chNum uint16) (net.Addr, bool) { return b.addr, true } -func (c *UDPConn) setNonceFromMsg(msg *stun.Message) { +func (c *RelayConnContext) setNonceFromMsg(msg *stun.Message) { // Update nonce var nonce stun.Nonce if err := nonce.GetFrom(msg); err == nil { @@ -444,7 +450,7 @@ func (c *UDPConn) setNonceFromMsg(msg *stun.Message) { } } -func (c *UDPConn) refreshAllocation(lifetime time.Duration, dontWait bool) error { +func (c *RelayConnContext) refreshAllocation(lifetime time.Duration, dontWait bool) error { msg, err := stun.Build( stun.TransactionID, stun.NewType(stun.MethodRefresh, stun.ClassRequest), @@ -496,7 +502,7 @@ func (c *UDPConn) refreshAllocation(lifetime time.Duration, dontWait bool) error return nil } -func (c *UDPConn) refreshPermissions() error { +func (c *RelayConnContext) refreshPermissions() error { addrs := c.permMap.addrs() if len(addrs) == 0 { c.log.Debug("no permission to refresh") @@ -562,7 +568,7 @@ func (c *UDPConn) sendChannelData(data []byte, chNum uint16) (int, error) { return len(data), nil } -func (c *UDPConn) onRefreshTimers(id int) { +func (c *RelayConnContext) onRefreshTimers(id int) { c.log.Debugf("refresh timer %d expired", id) switch id { case timerIDRefreshAlloc: @@ -593,14 +599,14 @@ func (c *UDPConn) onRefreshTimers(id int) { } } -func (c *UDPConn) nonce() stun.Nonce { +func (c *RelayConnContext) nonce() stun.Nonce { c.mutex.RLock() defer c.mutex.RUnlock() return c._nonce } -func (c *UDPConn) setNonce(nonce stun.Nonce) { +func (c *RelayConnContext) setNonce(nonce stun.Nonce) { c.mutex.Lock() defer c.mutex.Unlock() @@ -608,14 +614,14 @@ func (c *UDPConn) setNonce(nonce stun.Nonce) { c._nonce = nonce } -func (c *UDPConn) lifetime() time.Duration { +func (c *RelayConnContext) lifetime() time.Duration { c.mutex.RLock() defer c.mutex.RUnlock() return c._lifetime } -func (c *UDPConn) setLifetime(lifetime time.Duration) { +func (c *RelayConnContext) setLifetime(lifetime time.Duration) { c.mutex.Lock() defer c.mutex.Unlock() diff --git a/internal/client/conn_test.go b/internal/client/conn_test.go index c4ce6c35..bf4ddbab 100644 --- a/internal/client/conn_test.go +++ b/internal/client/conn_test.go @@ -64,7 +64,9 @@ func TestUDPConn(t *testing.T) { }) conn := UDPConn{ - obs: obs, + RelayConnContext: RelayConnContext{ + obs: obs, + }, bindingMgr: bm, } @@ -99,8 +101,10 @@ func TestUDPConn(t *testing.T) { binding.setState(bindingStateReady) conn := UDPConn{ - obs: obs, - permMap: pm, + RelayConnContext: RelayConnContext{ + obs: obs, + permMap: pm, + }, bindingMgr: bm, } diff --git a/internal/client/errors.go b/internal/client/errors.go index 7fc816fd..6488249f 100644 --- a/internal/client/errors.go +++ b/internal/client/errors.go @@ -8,7 +8,8 @@ var ( errFake = errors.New("fake error") errTryAgain = errors.New("try again") errClosed = errors.New("use of closed network connection") - errUDPAddrCast = errors.New("addr is not a net.UDPAddr") + errTCPAddrCast = errors.New("addr is not a tcp address") + errUDPAddrCast = errors.New("addr is not a udp address") errAlreadyClosed = errors.New("already closed") errDoubleLock = errors.New("try-lock is already locked") errTransactionClosed = errors.New("transaction closed") diff --git a/internal/client/permission.go b/internal/client/permission.go index 5546a22e..9ff97ae5 100644 --- a/internal/client/permission.go +++ b/internal/client/permission.go @@ -14,6 +14,7 @@ const ( ) type permission struct { + addr net.Addr st permState // thread-safe (atomic op) mutex sync.RWMutex // thread-safe } @@ -32,42 +33,35 @@ type permissionMap struct { mutex sync.RWMutex } +func addr2IPFingerprint(addr net.Addr) string { + switch a := addr.(type) { + case *net.UDPAddr: + return a.IP.String() + case *net.TCPAddr: + return a.IP.String() + } + return "" // should never happen +} + func (m *permissionMap) insert(addr net.Addr, p *permission) bool { m.mutex.Lock() defer m.mutex.Unlock() - - udpAddr, ok := addr.(*net.UDPAddr) - if !ok { - return false - } - - m.permMap[udpAddr.IP.String()] = p + p.addr = addr + m.permMap[addr2IPFingerprint(addr)] = p return true } func (m *permissionMap) find(addr net.Addr) (*permission, bool) { m.mutex.RLock() defer m.mutex.RUnlock() - - udpAddr, ok := addr.(*net.UDPAddr) - if !ok { - return nil, false - } - - p, ok := m.permMap[udpAddr.IP.String()] + p, ok := m.permMap[addr2IPFingerprint(addr)] return p, ok } func (m *permissionMap) delete(addr net.Addr) { m.mutex.Lock() defer m.mutex.Unlock() - - udpAddr, ok := addr.(*net.UDPAddr) - if !ok { - return - } - - delete(m.permMap, udpAddr.IP.String()) + delete(m.permMap, addr2IPFingerprint(addr)) } func (m *permissionMap) addrs() []net.Addr { @@ -75,10 +69,8 @@ func (m *permissionMap) addrs() []net.Addr { defer m.mutex.RUnlock() addrs := []net.Addr{} - for k := range m.permMap { - addrs = append(addrs, &net.UDPAddr{ - IP: net.ParseIP(k), - }) + for _, p := range m.permMap { + addrs = append(addrs, p.addr) } return addrs } diff --git a/internal/client/permission_test.go b/internal/client/permission_test.go index cbbf0875..9699f308 100644 --- a/internal/client/permission_test.go +++ b/internal/client/permission_test.go @@ -2,6 +2,8 @@ package client import ( "net" + "sort" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -25,15 +27,16 @@ func TestPermissionMap(t *testing.T) { perm1 := &permission{st: permStateIdle} perm2 := &permission{st: permStatePermitted} + perm3 := &permission{st: permStateIdle} udpAddr1, _ := net.ResolveUDPAddr("udp", "1.2.3.4:5000") udpAddr2, _ := net.ResolveUDPAddr("udp", "5.6.7.8:8888") - tcpAddr, _ := net.ResolveTCPAddr("tcp", "1.2.3.4:5000") + tcpAddr, _ := net.ResolveTCPAddr("tcp", "7.8.9.10:5000") assert.True(t, pm.insert(udpAddr1, perm1)) assert.Equal(t, 1, len(pm.permMap)) assert.True(t, pm.insert(udpAddr2, perm2)) assert.Equal(t, 2, len(pm.permMap)) - assert.False(t, pm.insert(tcpAddr, perm1)) - assert.Equal(t, 2, len(pm.permMap)) + assert.True(t, pm.insert(tcpAddr, perm3)) + assert.Equal(t, 3, len(pm.permMap)) p, ok := pm.find(udpAddr1) assert.True(t, ok) @@ -45,25 +48,35 @@ func TestPermissionMap(t *testing.T) { assert.Equal(t, perm2, p) assert.Equal(t, permStatePermitted, p.st) - _, ok = pm.find(tcpAddr) - assert.False(t, ok) + p, ok = pm.find(tcpAddr) + assert.True(t, ok) + assert.Equal(t, perm3, p) + assert.Equal(t, permStateIdle, p.st) addrs := pm.addrs() ips := []net.IP{} + for _, addr := range addrs { - udpAddr, err := net.ResolveUDPAddr(addr.Network(), addr.String()) - assert.NoError(t, err) - assert.Equal(t, 0, udpAddr.Port) - ips = append(ips, udpAddr.IP) + switch addr.(type) { + case *net.UDPAddr: + udpAddr, err := net.ResolveUDPAddr(addr.Network(), addr.String()) + assert.NoError(t, err) + ips = append(ips, udpAddr.IP) + case *net.TCPAddr: + tcpAddr, err := net.ResolveTCPAddr(addr.Network(), addr.String()) + assert.NoError(t, err) + ips = append(ips, tcpAddr.IP) + } } - assert.Equal(t, 2, len(ips)) - if ips[0].Equal(udpAddr1.IP) { - assert.True(t, ips[1].Equal(udpAddr2.IP)) - } else { - assert.True(t, ips[0].Equal(udpAddr2.IP)) - assert.True(t, ips[1].Equal(udpAddr1.IP)) - } + assert.Equal(t, 3, len(ips)) + sort.Slice(ips, func(i, j int) bool { + return strings.Compare(ips[i].String(), ips[j].String()) < 0 + }) + + assert.True(t, ips[0].Equal(udpAddr1.IP)) + assert.True(t, ips[1].Equal(udpAddr2.IP)) + assert.True(t, ips[2].Equal(tcpAddr.IP)) pm.delete(tcpAddr) assert.Equal(t, 2, len(pm.permMap)) diff --git a/internal/client/tcp_conn.go b/internal/client/tcp_conn.go new file mode 100644 index 00000000..ab561394 --- /dev/null +++ b/internal/client/tcp_conn.go @@ -0,0 +1,352 @@ +// Package client implements the API for a TURN client +package client + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "net" + "time" + + "github.com/pion/stun" + "github.com/pion/transport/v2" + "github.com/pion/turn/v2/internal/proto" +) + +var ( + errInvalidTURNFrame = errors.New("data is not a valid TURN frame, no STUN or ChannelData found") + errIncompleteTURNFrame = errors.New("data contains incomplete STUN or TURN frame") +) + +const ( + stunHeaderSize = 20 +) + +type TCPAllocation struct { + connAttemptCh chan *ConnectionAttempt + acceptTimer *time.Timer + RelayConnContext +} + +// TCPConn wraps a net.TCPConn and returns the allocations relayed +// transport address in response to TCPConn.LocalAddress() +type TCPConn struct { + *net.TCPConn + remoteAddress *net.TCPAddr + allocation *TCPAllocation + acceptDeadline time.Duration + ConnectionID proto.ConnectionID +} + +var _ transport.TCPListener = (*TCPAllocation)(nil) // Includes type check for net.Listener +var _ transport.TCPConn = (*TCPConn)(nil) // Includes type check for net.Conn +var _ transport.Dialer = (*TCPAllocation)(nil) + +type ConnectionAttempt struct { + from *net.TCPAddr + cid proto.ConnectionID +} + +func (c *TCPConn) LocalAddress() net.Addr { + return c.allocation.Addr() +} + +func (c *TCPConn) RemoteAddress() net.Addr { + return c.remoteAddress +} + +// NewTCPConn creates a new instance of TCPConn +func NewTCPAllocation(config *ConnConfig) *TCPAllocation { + a := &TCPAllocation{ + connAttemptCh: make(chan *ConnectionAttempt, 10), + acceptTimer: time.NewTimer(time.Duration(math.MaxInt64)), + RelayConnContext: RelayConnContext{ + obs: config.Observer, + relayedAddr: config.RelayedAddr, + permMap: newPermissionMap(), + integrity: config.Integrity, + _nonce: config.Nonce, + _lifetime: config.Lifetime, + log: config.Log, + }, + } + + a.log.Debugf("initial lifetime: %d seconds", int(a.lifetime().Seconds())) + + a.refreshAllocTimer = NewPeriodicTimer( + timerIDRefreshAlloc, + a.onRefreshTimers, + a.lifetime()/2, + ) + + a.refreshPermsTimer = NewPeriodicTimer( + timerIDRefreshPerms, + a.onRefreshTimers, + permRefreshInterval, + ) + + if a.refreshAllocTimer.Start() { + a.log.Debugf("refreshAllocTimer started") + } + if a.refreshPermsTimer.Start() { + a.log.Debugf("refreshPermsTimer started") + } + + return a +} + +func (a *TCPAllocation) Connect(peer net.Addr) (proto.ConnectionID, error) { + setters := []stun.Setter{ + stun.TransactionID, + stun.NewType(stun.MethodConnect, stun.ClassRequest), + addr2PeerAddress(peer), + a.obs.Username(), + a.obs.Realm(), + a.nonce(), + a.integrity, + stun.Fingerprint, + } + + msg, err := stun.Build(setters...) + if err != nil { + return 0, err + } + + a.log.Debugf("send connect request (peer=%v)", peer) + trRes, err := a.obs.PerformTransaction(msg, a.obs.TURNServerAddr(), false) + if err != nil { + return 0, err + } + res := trRes.Msg + + if res.Type.Class == stun.ClassErrorResponse { + var code stun.ErrorCodeAttribute + if err = code.GetFrom(res); err == nil { + return 0, fmt.Errorf("%s (error %s)", res.Type, code) + } + return 0, fmt.Errorf("%s", res.Type) + } + + var cid proto.ConnectionID + if err := cid.GetFrom(res); err != nil { + return 0, err + } + + a.log.Debugf("connect request successful (cid=%v)", cid) + return cid, nil +} + +// Dial connects to the address on the named network. +func (a *TCPAllocation) Dial(network, address string) (net.Conn, error) { + conn, err := net.Dial(network, a.obs.TURNServerAddr().String()) + if err != nil { + return nil, err + } + + dataConn, err := a.DialWithConn(conn, network, address) + if err != nil { + conn.Close() + } + return dataConn, err +} + +func (a *TCPAllocation) DialWithConn(conn net.Conn, network, address string) (*TCPConn, error) { + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + + // Check if we have a permission for the destination IP addr + perm, ok := a.permMap.find(addr) + if !ok { + perm = &permission{} + a.permMap.insert(addr, perm) + } + + for i := 0; i < maxRetryAttempts; i++ { + if err = a.createPermission(perm, addr); !errors.Is(err, errTryAgain) { + break + } + } + if err != nil { + return nil, err + } + + // Send connect request if haven't done so. + cid, err := a.Connect(addr) + if err != nil { + return nil, err + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return nil, errTCPAddrCast + } + + dataConn := &TCPConn{ + TCPConn: tcpConn, + remoteAddress: addr, + allocation: a, + } + + if err := a.BindConnection(dataConn, cid); err != nil { + return nil, fmt.Errorf("failed to bind connection: %w", err) + } + + return dataConn, nil + +} + +// BindConnection associates the provided connection +func (a *TCPAllocation) BindConnection(dataConn *TCPConn, cid proto.ConnectionID) error { + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodConnectionBind, stun.ClassRequest), + cid, + a.obs.Username(), + a.obs.Realm(), + a.nonce(), + a.integrity, + stun.Fingerprint, + ) + if err != nil { + return err + } + + a.log.Debugf("send connectionBind request (cid=%v)", cid) + _, err = dataConn.Write(msg.Raw) + if err != nil { + return err + } + + // Read exactly one STUN message, + // any data after belongs to the user + b := make([]byte, stunHeaderSize) + n, err := dataConn.Read(b) + if n != stunHeaderSize { + return errIncompleteTURNFrame + } else if err != nil { + return err + } + if !stun.IsMessage(b) { + return errInvalidTURNFrame + } + + datagramSize := binary.BigEndian.Uint16(b[2:4]) + stunHeaderSize + raw := make([]byte, datagramSize) + copy(raw, b) + _, err = dataConn.Read(raw[stunHeaderSize:]) + if err != nil { + return err + } + res := &stun.Message{Raw: raw} + if err := res.Decode(); err != nil { + return fmt.Errorf("failed to decode STUN message: %s", err.Error()) + } + + switch res.Type.Class { + case stun.ClassErrorResponse: + var code stun.ErrorCodeAttribute + if err = code.GetFrom(res); err == nil { + return fmt.Errorf("%s (error %s)", res.Type, code) + } + return fmt.Errorf("%s", res.Type) + case stun.ClassSuccessResponse: + a.log.Debug("connectionBind request successful") + return nil + default: + return fmt.Errorf("unexpected STUN request message: %s", res.String()) + } +} + +// Accept waits for and returns the next connection to the listener. +func (a *TCPAllocation) Accept() (net.Conn, error) { + return a.AcceptTCP() +} + +// AcceptTCP accepts the next incoming call and returns the new connection. +func (a *TCPAllocation) AcceptTCP() (transport.TCPConn, error) { + addr, err := net.ResolveTCPAddr("tcp4", a.obs.TURNServerAddr().String()) + if err != nil { + return nil, err + } + + tcpConn, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return nil, err + } + + dataConn, err := a.AcceptTCPWithConn(tcpConn) + if err != nil { + tcpConn.Close() + } + + return dataConn, err +} + +// AcceptTCP accepts the next incoming call and returns the new connection. +func (a *TCPAllocation) AcceptTCPWithConn(conn net.Conn) (transport.TCPConn, error) { + select { + case attempt := <-a.connAttemptCh: + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return nil, errTCPAddrCast + } + + dataConn := &TCPConn{ + TCPConn: tcpConn, + ConnectionID: attempt.cid, + remoteAddress: attempt.from, + allocation: a, + } + + if err := a.BindConnection(dataConn, attempt.cid); err != nil { + return nil, fmt.Errorf("failed to bind connection: %w", err) + } + + return dataConn, nil + case <-a.acceptTimer.C: + return nil, &net.OpError{ + Op: "accept", + Net: a.Addr().Network(), + Addr: a.Addr(), + Err: newTimeoutError("i/o timeout"), + } + } +} + +func (a *TCPAllocation) SetDeadline(t time.Time) error { + var d time.Duration + if t == noDeadline() { + d = time.Duration(math.MaxInt64) + } else { + d = time.Until(t) + } + a.acceptTimer.Reset(d) + return nil +} + +// Close releases the allocation +// Any blocked Accept operations will be unblocked and return errors. +// Any opened connection via Dial/Accept will be closed. +func (a *TCPAllocation) Close() error { + a.refreshAllocTimer.Stop() + a.refreshPermsTimer.Stop() + + a.obs.OnDeallocated(a.relayedAddr) + return a.refreshAllocation(0, true /* dontWait=true */) +} + +func (a *TCPAllocation) Addr() net.Addr { + return a.relayedAddr +} + +func (a *TCPAllocation) HandleConnectionAttempt(from *net.TCPAddr, cid proto.ConnectionID) error { + a.connAttemptCh <- &ConnectionAttempt{ + from: from, + cid: cid, + } + return nil +} diff --git a/internal/client/tcp_conn_test.go b/internal/client/tcp_conn_test.go new file mode 100644 index 00000000..9f24bf8e --- /dev/null +++ b/internal/client/tcp_conn_test.go @@ -0,0 +1,116 @@ +package client + +import ( + "net" + "testing" + "time" + + "github.com/pion/logging" + "github.com/pion/stun" + "github.com/pion/turn/v2/internal/proto" + "github.com/stretchr/testify/assert" +) + +type dummyConnObserver struct { + turnServerAddr net.Addr + username stun.Username + realm stun.Realm + _writeTo func(data []byte, to net.Addr) (int, error) + _performTransaction func(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error) + _onDeallocated func(relayedAddr net.Addr) +} + +func (obs *dummyConnObserver) TURNServerAddr() net.Addr { + return obs.turnServerAddr +} + +func (obs *dummyConnObserver) Username() stun.Username { + return obs.username +} + +func (obs *dummyConnObserver) Realm() stun.Realm { + return obs.realm +} + +func (obs *dummyConnObserver) WriteTo(data []byte, to net.Addr) (int, error) { + if obs._writeTo != nil { + return obs._writeTo(data, to) + } + return 0, nil +} + +func (obs *dummyConnObserver) PerformTransaction(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error) { + if obs._performTransaction != nil { + return obs._performTransaction(msg, to, dontWait) + } + return TransactionResult{}, nil +} + +func (obs *dummyConnObserver) OnDeallocated(relayedAddr net.Addr) { + if obs._onDeallocated != nil { + obs._onDeallocated(relayedAddr) + } +} + +func TestTCPConn(t *testing.T) { + t.Run("connect()", func(t *testing.T) { + var serverCid proto.ConnectionID = 4567 + obs := &dummyConnObserver{ + _performTransaction: func(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error) { + if msg.Type.Class == stun.ClassRequest && msg.Type.Method == stun.MethodConnect { + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodConnect, stun.ClassSuccessResponse), + serverCid, + ) + assert.NoError(t, err) + return TransactionResult{Msg: msg}, nil + } + return TransactionResult{}, errFake + }, + } + + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 1234, + } + + pm := newPermissionMap() + assert.True(t, pm.insert(addr, &permission{ + st: permStatePermitted, + })) + + loggerFactory := logging.NewDefaultLoggerFactory() + log := loggerFactory.NewLogger("test") + alloc := TCPAllocation{ + RelayConnContext: RelayConnContext{ + obs: obs, + permMap: pm, + log: log, + }, + } + + cid, err := alloc.Connect(addr) + assert.Equal(t, serverCid, cid) + assert.NoError(t, err) + }) + + t.Run("SetDeadline()", func(t *testing.T) { + relayedAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:13478") + assert.NoError(t, err) + + loggerFactory := logging.NewDefaultLoggerFactory() + obs := &dummyConnObserver{} + alloc := NewTCPAllocation(&ConnConfig{ + Observer: obs, + Lifetime: time.Second, + Log: loggerFactory.NewLogger("test"), + RelayedAddr: relayedAddr, + }) + + alloc.SetDeadline(time.Now()) + cid, err := alloc.AcceptTCPWithConn(nil) + assert.Nil(t, cid) + assert.Contains(t, err.Error(), "i/o timeout") + }) +} diff --git a/internal/proto/reqtrans.go b/internal/proto/reqtrans.go index cc73a471..0cd9129e 100644 --- a/internal/proto/reqtrans.go +++ b/internal/proto/reqtrans.go @@ -10,12 +10,16 @@ import ( type Protocol byte const ( + // ProtoTCP is IANA assigned protocol number for TCP. + ProtoTCP Protocol = 6 // ProtoUDP is IANA assigned protocol number for UDP. ProtoUDP Protocol = 17 ) func (p Protocol) String() string { switch p { + case ProtoTCP: + return "TCP" case ProtoUDP: return "UDP" default: diff --git a/internal/server/errors.go b/internal/server/errors.go index 13f8ee1a..fc9a45b1 100644 --- a/internal/server/errors.go +++ b/internal/server/errors.go @@ -15,7 +15,7 @@ var ( errFailedToCreateSTUNPacket = errors.New("failed to create stun message from packet") errFailedToCreateChannelData = errors.New("failed to create channel data from packet") errRelayAlreadyAllocatedForFiveTuple = errors.New("relay already allocated for 5-TUPLE") - errRequestedTransportMustBeUDP = errors.New("RequestedTransport must be UDP") + errUnsupportedTransportProtocol = errors.New("RequestedTransport must be UDP or TCP") errNoDontFragmentSupport = errors.New("no support for DONT-FRAGMENT") errRequestWithReservationTokenAndEvenPort = errors.New("Request must not contain RESERVATION-TOKEN and EVEN-PORT") errNoAllocationFound = errors.New("no allocation found") diff --git a/internal/server/turn.go b/internal/server/turn.go index 4e7d25db..3ea75ecb 100644 --- a/internal/server/turn.go +++ b/internal/server/turn.go @@ -53,14 +53,14 @@ func handleAllocateRequest(r Request, m *stun.Message) error { // attribute. If the REQUESTED-TRANSPORT attribute is not included // or is malformed, the server rejects the request with a 400 (Bad // Request) error. Otherwise, if the attribute is included but - // specifies a protocol other that UDP, the server rejects the + // specifies a protocol other that UDP/TCP, the server rejects the // request with a 442 (Unsupported Transport Protocol) error. var requestedTransport proto.RequestedTransport if err = requestedTransport.GetFrom(m); err != nil { return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) - } else if requestedTransport.Protocol != proto.ProtoUDP { + } else if requestedTransport.Protocol != proto.ProtoUDP && requestedTransport.Protocol != proto.ProtoTCP { msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeUnsupportedTransProto}) - return buildAndSendErr(r.Conn, r.SrcAddr, errRequestedTransportMustBeUDP, msg...) + return buildAndSendErr(r.Conn, r.SrcAddr, errUnsupportedTransportProtocol, msg...) } // 4. The request may contain a DONT-FRAGMENT attribute. If it does,