diff --git a/client.go b/client.go index 5e675fd3..0e1754f3 100644 --- a/client.go +++ b/client.go @@ -41,6 +41,7 @@ type ClientConfig struct { Password string Realm string Software string + Protocol string RTO time.Duration Conn net.PacketConn // Listening socket (net.PacketConn) LoggerFactory logging.LoggerFactory @@ -49,25 +50,27 @@ type ClientConfig struct { // Client is a STUN server client type Client struct { - conn net.PacketConn // read-only - stunServ net.Addr // read-only - turnServ net.Addr // read-only - stunServStr string // read-only, used for de-multiplexing - turnServStr string // read-only, used for de-multiplexing - username stun.Username // read-only - password string // read-only - realm stun.Realm // read-only - integrity stun.MessageIntegrity // read-only - software stun.Software // read-only - trMap *client.TransactionMap // thread-safe - rto time.Duration // read-only - relayedConn *client.UDPConn // protected by mutex *** - allocTryLock client.TryLock // thread-safe - listenTryLock client.TryLock // thread-safe - net transport.Net // read-only - mutex sync.RWMutex // thread-safe - mutexTrMap sync.Mutex // thread-safe - log logging.LeveledLogger // read-only + conn net.PacketConn // read-only // connection to the regular turn server. + stunServ net.Addr // read-only + turnServ net.Addr // read-only + stunServStr string // read-only, used for de-multiplexing + turnServStr string // read-only, used for de-multiplexing + username stun.Username // read-only + password string // read-only + realm stun.Realm // read-only + integrity stun.MessageIntegrity // read-only + software stun.Software // read-only + trMap *client.TransactionMap // thread-safe + rto time.Duration // read-only + relayedConn *client.UDPConn // protected by mutex *** + relayedTCPConn *client.TCPConn // protected by mutex *** + allocTryLock client.TryLock // thread-safe + listenTryLock client.TryLock // thread-safe + net transport.Net // read-only + mutex sync.RWMutex // thread-safe + mutexTrMap sync.Mutex // thread-safe + log logging.LeveledLogger // read-only + protocol proto.Protocol // read-only } // NewClient returns a new Client instance. listeningAddress is the address and port to listen on, default "0.0.0.0:0" @@ -93,24 +96,47 @@ func NewClient(config *ClientConfig) (*Client, error) { log.Warn("Virtual network is enabled") } + protocol := proto.ProtoUDP + if config.Protocol == "tcp" { + protocol = proto.ProtoTCP + } + var stunServ, turnServ net.Addr var stunServStr, turnServStr string if len(config.STUNServerAddr) > 0 { log.Debugf("resolving %s", config.STUNServerAddr) - stunServ, err = config.Net.ResolveUDPAddr("udp4", config.STUNServerAddr) - if err != nil { - return nil, err + switch protocol { + case proto.ProtoUDP: + stunServ, err = config.Net.ResolveUDPAddr("udp4", config.STUNServerAddr) + if err != nil { + return nil, err + } + stunServStr = stunServ.String() + case proto.ProtoTCP: + stunServ, err = config.Net.ResolveTCPAddr("tcp4", config.STUNServerAddr) + if err != nil { + return nil, err + } + stunServStr = stunServ.String() } - stunServStr = stunServ.String() log.Debugf("stunServ: %s", stunServStr) } if len(config.TURNServerAddr) > 0 { log.Debugf("resolving %s", config.TURNServerAddr) - turnServ, err = config.Net.ResolveUDPAddr("udp4", config.TURNServerAddr) - if err != nil { - return nil, err + switch protocol { + case proto.ProtoUDP: + turnServ, err = config.Net.ResolveUDPAddr("udp4", config.TURNServerAddr) + if err != nil { + return nil, err + } + turnServStr = turnServ.String() + case proto.ProtoTCP: + turnServ, err = config.Net.ResolveTCPAddr("tcp4", config.TURNServerAddr) + if err != nil { + return nil, err + } + turnServStr = turnServ.String() } - turnServStr = turnServ.String() log.Debugf("turnServ: %s", turnServStr) } @@ -133,6 +159,7 @@ func NewClient(config *ClientConfig) (*Client, error) { trMap: client.NewTransactionMap(), rto: rto, log: log, + protocol: protocol, } return c, nil @@ -238,42 +265,34 @@ func (c *Client) SendBindingRequest() (net.Addr, error) { return c.SendBindingRequestTo(c.stunServ) } -// Allocate sends a TURN allocation request to the given transport address -func (c *Client) Allocate() (net.PacketConn, error) { - if err := c.allocTryLock.Lock(); err != nil { - return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error()) - } - defer c.allocTryLock.Unlock() - - relayedConn := c.relayedUDPConn() - if relayedConn != nil { - return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String()) - } +func (c *Client) sendAllocateRequest() (proto.RelayedAddress, proto.Lifetime, stun.Nonce, error) { + var relayed proto.RelayedAddress + var lifetime proto.Lifetime + var nonce stun.Nonce msg, err := stun.Build( stun.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassRequest), - proto.RequestedTransport{Protocol: proto.ProtoUDP}, + proto.RequestedTransport{Protocol: c.protocol}, stun.Fingerprint, ) if err != nil { - return nil, err + return relayed, lifetime, nonce, err } trRes, err := c.PerformTransaction(msg, c.turnServ, false) if err != nil { - return nil, err + return relayed, lifetime, nonce, err } res := trRes.Msg // Anonymous allocate failed, trying to authenticate. - var nonce stun.Nonce if err = nonce.GetFrom(res); err != nil { - return nil, err + return relayed, lifetime, nonce, err } if err = c.realm.GetFrom(res); err != nil { - return nil, err + return relayed, lifetime, nonce, err } c.realm = append([]byte(nil), c.realm...) c.integrity = stun.NewLongTermIntegrity( @@ -283,7 +302,7 @@ func (c *Client) Allocate() (net.PacketConn, error) { msg, err = stun.Build( stun.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassRequest), - proto.RequestedTransport{Protocol: proto.ProtoUDP}, + proto.RequestedTransport{Protocol: c.protocol}, &c.username, &c.realm, &nonce, @@ -291,40 +310,93 @@ func (c *Client) Allocate() (net.PacketConn, error) { stun.Fingerprint, ) if err != nil { - return nil, err + return relayed, lifetime, nonce, err } trRes, err = c.PerformTransaction(msg, c.turnServ, false) if err != nil { - return nil, err + return relayed, lifetime, nonce, err } res = trRes.Msg if res.Type.Class == stun.ClassErrorResponse { var code stun.ErrorCodeAttribute if err = code.GetFrom(res); err == nil { - return nil, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113 + return relayed, lifetime, nonce, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113 } - return nil, fmt.Errorf("%s", res.Type) //nolint:goerr113 + return relayed, lifetime, nonce, fmt.Errorf("%s", res.Type) //nolint:goerr113 } // Getting relayed addresses from response. - var relayed proto.RelayedAddress if err := relayed.GetFrom(res); err != nil { + return relayed, lifetime, nonce, err + } + + // Getting lifetime from response + if err := lifetime.GetFrom(res); err != nil { + return relayed, lifetime, nonce, err + } + return relayed, lifetime, nonce, nil +} + +// Allocate sends a TURN allocation request to the given transport address +func (c *Client) Allocate() (net.PacketConn, error) { + if err := c.allocTryLock.Lock(); err != nil { + return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error()) + } + defer c.allocTryLock.Unlock() + + relayedConn := c.getRelayedUDPConn() + if relayedConn != nil { + return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String()) + } + + relayed, lifetime, nonce, err := c.sendAllocateRequest() + if err != nil { return nil, err } + relayedAddr := &net.UDPAddr{ IP: relayed.IP, Port: relayed.Port, } - // Getting lifetime from response - var lifetime proto.Lifetime - if err := lifetime.GetFrom(res); err != nil { + relayedConn = client.NewUDPConn(&client.ConnConfig{ + Observer: c, + RelayedAddr: relayedAddr, + Integrity: c.integrity, + Nonce: nonce, + Lifetime: lifetime.Duration, + Log: c.log, + }) + c.setRelayedUDPConn(relayedConn) + + return relayedConn, nil +} + +// Allocate TCP +func (c *Client) AllocateTCP() (*client.TCPConn, error) { + if err := c.allocTryLock.Lock(); err != nil { + return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error()) + } + defer c.allocTryLock.Unlock() + + relayedConn := c.getRelayedTCPConn() + if relayedConn != nil { + return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String()) + } + + relayed, lifetime, nonce, err := c.sendAllocateRequest() + if err != nil { return nil, err } - relayedConn = client.NewUDPConn(&client.UDPConnConfig{ + relayedAddr := &net.TCPAddr{ + IP: relayed.IP, + Port: relayed.Port, + } + + relayedConn = client.NewTCPConn(&client.ConnConfig{ Observer: c, RelayedAddr: relayedAddr, Integrity: c.integrity, @@ -333,7 +405,7 @@ func (c *Client) Allocate() (net.PacketConn, error) { Log: c.log, }) - c.setRelayedUDPConn(relayedConn) + c.setRelayedTCPConn(relayedConn) return relayedConn, nil } @@ -341,7 +413,13 @@ func (c *Client) Allocate() (net.PacketConn, error) { // CreatePermission Issues a CreatePermission request for the supplied addresses // as described in https://datatracker.ietf.org/doc/html/rfc5766#section-9 func (c *Client) CreatePermission(addrs ...net.Addr) error { - return c.relayedUDPConn().CreatePermissions(addrs...) + switch c.protocol { + case proto.ProtoUDP: + return c.getRelayedUDPConn().CreatePermissions(addrs...) + case proto.ProtoTCP: + return c.getRelayedTCPConn().CreatePermissions(addrs...) + } + return nil } // PerformTransaction performs STUN transaction @@ -445,7 +523,8 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { } if msg.Type.Class == stun.ClassIndication { - if msg.Type.Method == stun.MethodData { + switch msg.Type.Method { + case stun.MethodData: var peerAddr proto.PeerAddress if err := peerAddr.GetFrom(msg); err != nil { return err @@ -462,13 +541,37 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { c.log.Debugf("data indication received from %s", from.String()) - relayedConn := c.relayedUDPConn() + relayedConn := c.getRelayedUDPConn() if relayedConn == nil { c.log.Debug("no relayed conn allocated") return nil // silently discard } - relayedConn.HandleInbound(data, from) + case stun.MethodConnectionAttempt: + var peerAddr proto.PeerAddress + if err := peerAddr.GetFrom(msg); err != nil { + return err + } + + from = &net.TCPAddr{ + IP: peerAddr.IP, + Port: peerAddr.Port, + } + + var cid proto.ConnectionID + if err := cid.GetFrom(msg); err != nil { + return err + } + + c.log.Debugf("connection attempt from %s", from.String()) + + relayedConn := c.getRelayedTCPConn() + if relayedConn == nil { + c.log.Debug("no relayed conn allocated") + return nil // silently discard + } + + relayedConn.HandleConnectionAttempt(data, from, cid) } return nil } @@ -514,7 +617,7 @@ func (c *Client) handleChannelData(data []byte) error { return err } - relayedConn := c.relayedUDPConn() + relayedConn := c.getRelayedUDPConn() if relayedConn == nil { c.log.Debug("no relayed conn allocated") return nil // silently discard @@ -573,9 +676,23 @@ func (c *Client) setRelayedUDPConn(conn *client.UDPConn) { c.relayedConn = conn } -func (c *Client) relayedUDPConn() *client.UDPConn { +func (c *Client) getRelayedUDPConn() *client.UDPConn { c.mutex.RLock() defer c.mutex.RUnlock() return c.relayedConn } + +func (c *Client) setRelayedTCPConn(conn *client.TCPConn) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.relayedTCPConn = conn +} + +func (c *Client) getRelayedTCPConn() *client.TCPConn { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.relayedTCPConn +} diff --git a/client_test.go b/client_test.go index e799f498..074830ba 100644 --- a/client_test.go +++ b/client_test.go @@ -187,3 +187,54 @@ func TestClientNonceExpiration(t *testing.T) { assert.NoError(t, conn.Close()) assert.NoError(t, server.Close()) } + +// Create a tcp-based allocation and verify allocation can be created +func TestTCPClient(t *testing.T) { + // Setup server + tcpListener, err := net.Listen("tcp4", "0.0.0.0:3478") + assert.NoError(t, err) + + server, err := NewServer(ServerConfig{ + AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { + return GenerateAuthKey(username, realm, "pass"), true + }, + ListenerConfigs: []ListenerConfig{ + { + Listener: tcpListener, + RelayAddressGenerator: &RelayAddressGeneratorStatic{ + RelayAddress: net.ParseIP("127.0.0.1"), + Address: "0.0.0.0", + }, + }, + }, + Realm: "pion.ly", + }) + assert.NoError(t, err) + + // Setup clients + conn, err := net.Dial("tcp", "127.0.0.1:3478") + assert.NoError(t, err) + + client, err := NewClient(&ClientConfig{ + Conn: NewSTUNConn(conn), + STUNServerAddr: "127.0.0.1:3478", + TURNServerAddr: "127.0.0.1:3478", + Username: "foo", + Password: "pass", + Protocol: "tcp", + }) + assert.NoError(t, err) + assert.NoError(t, client.Listen()) + + allocation, err := client.AllocateTCP() + assert.NoError(t, err) + + // TODO: Implement server side handling of Connect and ConnectionBind + // _, err = allocation.Dial(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}) + // assert.NoError(t, err) + + // Shutdown + assert.NoError(t, allocation.Close()) + assert.NoError(t, conn.Close()) + assert.NoError(t, server.Close()) +} diff --git a/errors.go b/errors.go index 12d5e0e8..b4fa640e 100644 --- a/errors.go +++ b/errors.go @@ -25,4 +25,5 @@ var ( errNonSTUNMessage = errors.New("non-STUN message from STUN server") errFailedToDecodeSTUN = errors.New("failed to decode STUN message") errUnexpectedSTUNRequestMessage = errors.New("unexpected STUN request message") + errInvalidProtocol = errors.New("unexpected STUN request message") ) diff --git a/examples/turn-client/tcp/main.go b/examples/turn-client/tcp/main.go index 041d1ab4..6e99bb84 100644 --- a/examples/turn-client/tcp/main.go +++ b/examples/turn-client/tcp/main.go @@ -2,23 +2,67 @@ package main import ( + "bufio" "flag" "fmt" "log" "net" + "strconv" "strings" - "time" "github.com/pion/logging" "github.com/pion/turn/v2" ) +func handleSignaling(server bool, addrCh chan string, relayAddr string) { + addr := "127.0.0.1:5000" + if server { + go func() { + listen, err := net.Listen("tcp", addr) + if err != nil { + log.Panicf("Failed to create signaling server: %s", err) + } + defer listen.Close() + for { + conn, err := listen.Accept() + if err != nil { + log.Panicf("Failed to accept: %s", err) + } + go func() { + message, err := bufio.NewReader(conn).ReadString('\n') + if err != nil { + log.Panicf("Failed to read relayAddr: %s", err) + } + addrCh <- message[:len(message)-1] + }() + if _, err = conn.Write([]byte(fmt.Sprintf("%s\n", relayAddr))); err != nil { + log.Panicf("Failed to write relayAddr: %s", err) + } + } + }() + } else { + conn, err := net.Dial("tcp", addr) + if err != nil { + log.Panicf("Error dialing: %s", err) + } + message, err := bufio.NewReader(conn).ReadString('\n') + if err != nil { + log.Panicf("Failed to read relayAddr: %s", err) + } + addrCh <- message[:len(message)-1] + if _, err = conn.Write([]byte(fmt.Sprintf("%s\n", relayAddr))); err != nil { + log.Panicf("Failed to write relayAddr: %s", err) + } + } +} + func main() { host := flag.String("host", "", "TURN Server name.") port := flag.Int("port", 3478, "Listening port.") user := flag.String("user", "", "A pair of username and password (e.g. \"user=pass\")") realm := flag.String("realm", "pion.ly", "Realm (defaults to \"pion.ly\")") - ping := flag.Bool("ping", false, "Run ping test") + server := flag.Bool("server", false, "Whether to start signaling server") + flag.Parse() if len(*host) == 0 { @@ -47,6 +91,7 @@ func main() { Username: cred[0], Password: cred[1], Realm: *realm, + Protocol: "tcp", LoggerFactory: logging.NewDefaultLoggerFactory(), } @@ -63,9 +108,9 @@ func main() { } // Allocate a relay socket on the TURN server. On success, it - // will return a net.PacketConn which represents the remote + // will return a client.RelayConn which represents the remote // socket. - relayConn, err := client.Allocate() + relayConn, err := client.AllocateTCP() if err != nil { log.Panicf("Failed to allocate: %s", err) } @@ -75,94 +120,46 @@ func main() { } }() - // The relayConn's local address is actually the transport - // address assigned on the TURN server. log.Printf("relayed-address=%s", relayConn.LocalAddr().String()) - // If you provided `-ping`, perform a ping test against the - // relayConn we have just allocated. - if *ping { - err = doPingTest(client, relayConn) - if err != nil { - log.Panicf("Failed to ping: %s", err) - } - } -} - -func doPingTest(client *turn.Client, relayConn net.PacketConn) error { - // Send BindingRequest to learn our external IP - mappedAddr, err := client.SendBindingRequest() - if err != nil { - return err - } - - // Set up pinger socket (pingerConn) - pingerConn, err := net.ListenPacket("udp4", "0.0.0.0:0") - if err != nil { - log.Panicf("Failed to listen: %s", err) - } - defer func() { - if closeErr := pingerConn.Close(); closeErr != nil { - log.Panicf("Failed to close connection: %s", closeErr) - } - }() - - // Punch a UDP hole for the relayConn by sending a data to the mappedAddr. - // This will trigger a TURN client to generate a permission request to the - // TURN server. After this, packets from the IP address will be accepted by - // the TURN server. - _, err = relayConn.WriteTo([]byte("Hello"), mappedAddr) - if err != nil { - return err - } - - // Start read-loop on pingerConn - go func() { - buf := make([]byte, 1600) - for { - n, from, pingerErr := pingerConn.ReadFrom(buf) - if pingerErr != nil { - break + // Learn the peers relay address via signaling channel + addrCh := make(chan string, 5) + handleSignaling(*server, addrCh, relayConn.LocalAddr().String()) + + for { + peerAddrString := <-addrCh + res := strings.Split(peerAddrString, ":") + ip := res[0] + port, _ := strconv.Atoi(res[1]) + + log.Printf("Recieved peer address: %s", peerAddrString) + + buf := make([]byte, 4096) + peerAddr := net.TCPAddr{IP: net.ParseIP(ip), Port: port} + var conn net.Conn + var n int + if *server { + conn, err = relayConn.Dial(&peerAddr) + if err != nil { + fmt.Println("Error connecting:", err) } - - msg := string(buf[:n]) - if sentAt, pingerErr := time.Parse(time.RFC3339Nano, msg); pingerErr == nil { - rtt := time.Since(sentAt) - log.Printf("%d bytes from from %s time=%d ms\n", n, from.String(), int(rtt.Seconds()*1000)) - } - } - }() - - // Start read-loop on relayConn - go func() { - buf := make([]byte, 1600) - for { - n, from, readerErr := relayConn.ReadFrom(buf) - if readerErr != nil { + conn.Write([]byte("hello!")) + n, err = conn.Read(buf) + if err != nil { + log.Println("Error reading from relay conn:", err) break } - - // Echo back - if _, readerErr = relayConn.WriteTo(buf[:n], from); readerErr != nil { + } else { + relayConn.CreatePermissions(&peerAddr) + conn = relayConn.Accept() + n, err = conn.Read(buf) + if err != nil { + log.Println("Error reading from relay conn:", err) break } + conn.Write([]byte("hello back!")) } - }() - - time.Sleep(500 * time.Millisecond) - - // Send 10 packets from relayConn to the echo server - for i := 0; i < 10; i++ { - msg := time.Now().Format(time.RFC3339Nano) - _, err = pingerConn.WriteTo([]byte(msg), relayConn.LocalAddr()) - if err != nil { - return err - } - - // For simplicity, this example does not wait for the pong (reply). - // Instead, sleep 1 second. - time.Sleep(time.Second) + log.Println("Read message:", string(buf[:n])) + conn.Close() } - - return nil } diff --git a/internal/client/conn.go b/internal/client/conn.go index 8aeb742c..0311637f 100644 --- a/internal/client/conn.go +++ b/internal/client/conn.go @@ -35,8 +35,8 @@ type inboundData struct { from net.Addr } -// UDPConnObserver is an interface to UDPConn observer -type UDPConnObserver interface { +// ConnObserver is an interface to UDPConn observer +type ConnObserver interface { TURNServerAddr() net.Addr Username() stun.Username Realm() stun.Realm @@ -45,9 +45,9 @@ type UDPConnObserver interface { OnDeallocated(relayedAddr net.Addr) } -// UDPConnConfig is a set of configuration params use by NewUDPConn -type UDPConnConfig struct { - Observer UDPConnObserver +// ConnConfig is a set of configuration params use by NewUDPConn +type ConnConfig struct { + Observer ConnObserver RelayedAddr net.Addr Integrity stun.MessageIntegrity Nonce stun.Nonce @@ -55,39 +55,45 @@ type UDPConnConfig struct { Log logging.LeveledLogger } -// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections. -// comatible with net.PacketConn and net.Conn -type UDPConn struct { - obs UDPConnObserver // read-only +type RelayConn struct { + obs ConnObserver // read-only relayedAddr net.Addr // read-only permMap *permissionMap // thread-safe - bindingMgr *bindingManager // thread-safe integrity stun.MessageIntegrity // read-only _nonce stun.Nonce // needs mutex x _lifetime time.Duration // needs mutex x - readCh chan *inboundData // thread-safe - closeCh chan struct{} // thread-safe - readTimer *time.Timer // thread-safe refreshAllocTimer *PeriodicTimer // thread-safe refreshPermsTimer *PeriodicTimer // thread-safe mutex sync.RWMutex // thread-safe log logging.LeveledLogger // read-only } +// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections. +// comatible with net.PacketConn and net.Conn +type UDPConn struct { + bindingMgr *bindingManager // thread-safe + readCh chan *inboundData // thread-safe + closeCh chan struct{} // thread-safe + readTimer *time.Timer // thread-safe + RelayConn +} + // NewUDPConn creates a new instance of UDPConn -func NewUDPConn(config *UDPConnConfig) *UDPConn { +func NewUDPConn(config *ConnConfig) *UDPConn { c := &UDPConn{ - obs: config.Observer, - relayedAddr: config.RelayedAddr, - permMap: newPermissionMap(), - bindingMgr: newBindingManager(), - integrity: config.Integrity, - _nonce: config.Nonce, - _lifetime: config.Lifetime, - readCh: make(chan *inboundData, maxReadQueueSize), - closeCh: make(chan struct{}), - readTimer: time.NewTimer(time.Duration(math.MaxInt64)), - log: config.Log, + bindingMgr: newBindingManager(), + readCh: make(chan *inboundData, maxReadQueueSize), + closeCh: make(chan struct{}), + readTimer: time.NewTimer(time.Duration(math.MaxInt64)), + RelayConn: RelayConn{ + obs: config.Observer, + relayedAddr: config.RelayedAddr, + permMap: newPermissionMap(), + integrity: config.Integrity, + _nonce: config.Nonce, + _lifetime: config.Lifetime, + log: config.Log, + }, } c.log.Debugf("initial lifetime: %d seconds", int(c.lifetime().Seconds())) @@ -153,7 +159,7 @@ func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } } -func (c *UDPConn) createPermission(perm *permission, addr net.Addr) error { +func (c *RelayConn) createPermission(perm *permission, addr net.Addr) error { perm.mutex.Lock() defer perm.mutex.Unlock() @@ -302,7 +308,7 @@ func (c *UDPConn) Close() error { } // LocalAddr returns the local network address. -func (c *UDPConn) LocalAddr() net.Addr { +func (c *RelayConn) LocalAddr() net.Addr { return c.relayedAddr } @@ -365,7 +371,7 @@ func addr2PeerAddress(addr net.Addr) proto.PeerAddress { // CreatePermissions Issues a CreatePermission request for the supplied addresses // as described in https://datatracker.ietf.org/doc/html/rfc5766#section-9 -func (c *UDPConn) CreatePermissions(addrs ...net.Addr) error { +func (c *RelayConn) CreatePermissions(addrs ...net.Addr) error { setters := []stun.Setter{ stun.TransactionID, stun.NewType(stun.MethodCreatePermission, stun.ClassRequest), @@ -433,7 +439,7 @@ func (c *UDPConn) FindAddrByChannelNumber(chNum uint16) (net.Addr, bool) { return b.addr, true } -func (c *UDPConn) setNonceFromMsg(msg *stun.Message) { +func (c *RelayConn) setNonceFromMsg(msg *stun.Message) { // Update nonce var nonce stun.Nonce if err := nonce.GetFrom(msg); err == nil { @@ -444,7 +450,7 @@ func (c *UDPConn) setNonceFromMsg(msg *stun.Message) { } } -func (c *UDPConn) refreshAllocation(lifetime time.Duration, dontWait bool) error { +func (c *RelayConn) refreshAllocation(lifetime time.Duration, dontWait bool) error { msg, err := stun.Build( stun.TransactionID, stun.NewType(stun.MethodRefresh, stun.ClassRequest), @@ -496,7 +502,7 @@ func (c *UDPConn) refreshAllocation(lifetime time.Duration, dontWait bool) error return nil } -func (c *UDPConn) refreshPermissions() error { +func (c *RelayConn) refreshPermissions() error { addrs := c.permMap.addrs() if len(addrs) == 0 { c.log.Debug("no permission to refresh") @@ -562,7 +568,7 @@ func (c *UDPConn) sendChannelData(data []byte, chNum uint16) (int, error) { return len(data), nil } -func (c *UDPConn) onRefreshTimers(id int) { +func (c *RelayConn) onRefreshTimers(id int) { c.log.Debugf("refresh timer %d expired", id) switch id { case timerIDRefreshAlloc: @@ -593,14 +599,14 @@ func (c *UDPConn) onRefreshTimers(id int) { } } -func (c *UDPConn) nonce() stun.Nonce { +func (c *RelayConn) nonce() stun.Nonce { c.mutex.RLock() defer c.mutex.RUnlock() return c._nonce } -func (c *UDPConn) setNonce(nonce stun.Nonce) { +func (c *RelayConn) setNonce(nonce stun.Nonce) { c.mutex.Lock() defer c.mutex.Unlock() @@ -608,14 +614,14 @@ func (c *UDPConn) setNonce(nonce stun.Nonce) { c._nonce = nonce } -func (c *UDPConn) lifetime() time.Duration { +func (c *RelayConn) lifetime() time.Duration { c.mutex.RLock() defer c.mutex.RUnlock() return c._lifetime } -func (c *UDPConn) setLifetime(lifetime time.Duration) { +func (c *RelayConn) setLifetime(lifetime time.Duration) { c.mutex.Lock() defer c.mutex.Unlock() diff --git a/internal/client/conn_test.go b/internal/client/conn_test.go index c4ce6c35..36142161 100644 --- a/internal/client/conn_test.go +++ b/internal/client/conn_test.go @@ -64,7 +64,9 @@ func TestUDPConn(t *testing.T) { }) conn := UDPConn{ - obs: obs, + RelayConn: RelayConn{ + obs: obs, + }, bindingMgr: bm, } @@ -99,8 +101,10 @@ func TestUDPConn(t *testing.T) { binding.setState(bindingStateReady) conn := UDPConn{ - obs: obs, - permMap: pm, + RelayConn: RelayConn{ + obs: obs, + permMap: pm, + }, bindingMgr: bm, } diff --git a/internal/client/errors.go b/internal/client/errors.go index 7fc816fd..c0f32b1e 100644 --- a/internal/client/errors.go +++ b/internal/client/errors.go @@ -8,6 +8,7 @@ var ( errFake = errors.New("fake error") errTryAgain = errors.New("try again") errClosed = errors.New("use of closed network connection") + errTCPAddrCast = errors.New("addr is not a net.TCPAddr") errUDPAddrCast = errors.New("addr is not a net.UDPAddr") errAlreadyClosed = errors.New("already closed") errDoubleLock = errors.New("try-lock is already locked") diff --git a/internal/client/permission.go b/internal/client/permission.go index 5546a22e..9ff97ae5 100644 --- a/internal/client/permission.go +++ b/internal/client/permission.go @@ -14,6 +14,7 @@ const ( ) type permission struct { + addr net.Addr st permState // thread-safe (atomic op) mutex sync.RWMutex // thread-safe } @@ -32,42 +33,35 @@ type permissionMap struct { mutex sync.RWMutex } +func addr2IPFingerprint(addr net.Addr) string { + switch a := addr.(type) { + case *net.UDPAddr: + return a.IP.String() + case *net.TCPAddr: + return a.IP.String() + } + return "" // should never happen +} + func (m *permissionMap) insert(addr net.Addr, p *permission) bool { m.mutex.Lock() defer m.mutex.Unlock() - - udpAddr, ok := addr.(*net.UDPAddr) - if !ok { - return false - } - - m.permMap[udpAddr.IP.String()] = p + p.addr = addr + m.permMap[addr2IPFingerprint(addr)] = p return true } func (m *permissionMap) find(addr net.Addr) (*permission, bool) { m.mutex.RLock() defer m.mutex.RUnlock() - - udpAddr, ok := addr.(*net.UDPAddr) - if !ok { - return nil, false - } - - p, ok := m.permMap[udpAddr.IP.String()] + p, ok := m.permMap[addr2IPFingerprint(addr)] return p, ok } func (m *permissionMap) delete(addr net.Addr) { m.mutex.Lock() defer m.mutex.Unlock() - - udpAddr, ok := addr.(*net.UDPAddr) - if !ok { - return - } - - delete(m.permMap, udpAddr.IP.String()) + delete(m.permMap, addr2IPFingerprint(addr)) } func (m *permissionMap) addrs() []net.Addr { @@ -75,10 +69,8 @@ func (m *permissionMap) addrs() []net.Addr { defer m.mutex.RUnlock() addrs := []net.Addr{} - for k := range m.permMap { - addrs = append(addrs, &net.UDPAddr{ - IP: net.ParseIP(k), - }) + for _, p := range m.permMap { + addrs = append(addrs, p.addr) } return addrs } diff --git a/internal/client/permission_test.go b/internal/client/permission_test.go index cbbf0875..9699f308 100644 --- a/internal/client/permission_test.go +++ b/internal/client/permission_test.go @@ -2,6 +2,8 @@ package client import ( "net" + "sort" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -25,15 +27,16 @@ func TestPermissionMap(t *testing.T) { perm1 := &permission{st: permStateIdle} perm2 := &permission{st: permStatePermitted} + perm3 := &permission{st: permStateIdle} udpAddr1, _ := net.ResolveUDPAddr("udp", "1.2.3.4:5000") udpAddr2, _ := net.ResolveUDPAddr("udp", "5.6.7.8:8888") - tcpAddr, _ := net.ResolveTCPAddr("tcp", "1.2.3.4:5000") + tcpAddr, _ := net.ResolveTCPAddr("tcp", "7.8.9.10:5000") assert.True(t, pm.insert(udpAddr1, perm1)) assert.Equal(t, 1, len(pm.permMap)) assert.True(t, pm.insert(udpAddr2, perm2)) assert.Equal(t, 2, len(pm.permMap)) - assert.False(t, pm.insert(tcpAddr, perm1)) - assert.Equal(t, 2, len(pm.permMap)) + assert.True(t, pm.insert(tcpAddr, perm3)) + assert.Equal(t, 3, len(pm.permMap)) p, ok := pm.find(udpAddr1) assert.True(t, ok) @@ -45,25 +48,35 @@ func TestPermissionMap(t *testing.T) { assert.Equal(t, perm2, p) assert.Equal(t, permStatePermitted, p.st) - _, ok = pm.find(tcpAddr) - assert.False(t, ok) + p, ok = pm.find(tcpAddr) + assert.True(t, ok) + assert.Equal(t, perm3, p) + assert.Equal(t, permStateIdle, p.st) addrs := pm.addrs() ips := []net.IP{} + for _, addr := range addrs { - udpAddr, err := net.ResolveUDPAddr(addr.Network(), addr.String()) - assert.NoError(t, err) - assert.Equal(t, 0, udpAddr.Port) - ips = append(ips, udpAddr.IP) + switch addr.(type) { + case *net.UDPAddr: + udpAddr, err := net.ResolveUDPAddr(addr.Network(), addr.String()) + assert.NoError(t, err) + ips = append(ips, udpAddr.IP) + case *net.TCPAddr: + tcpAddr, err := net.ResolveTCPAddr(addr.Network(), addr.String()) + assert.NoError(t, err) + ips = append(ips, tcpAddr.IP) + } } - assert.Equal(t, 2, len(ips)) - if ips[0].Equal(udpAddr1.IP) { - assert.True(t, ips[1].Equal(udpAddr2.IP)) - } else { - assert.True(t, ips[0].Equal(udpAddr2.IP)) - assert.True(t, ips[1].Equal(udpAddr1.IP)) - } + assert.Equal(t, 3, len(ips)) + sort.Slice(ips, func(i, j int) bool { + return strings.Compare(ips[i].String(), ips[j].String()) < 0 + }) + + assert.True(t, ips[0].Equal(udpAddr1.IP)) + assert.True(t, ips[1].Equal(udpAddr2.IP)) + assert.True(t, ips[2].Equal(tcpAddr.IP)) pm.delete(tcpAddr) assert.Equal(t, 2, len(pm.permMap)) diff --git a/internal/client/tcp_conn.go b/internal/client/tcp_conn.go new file mode 100644 index 00000000..baf80eeb --- /dev/null +++ b/internal/client/tcp_conn.go @@ -0,0 +1,256 @@ +// Package client implements the API for a TURN client +package client + +import ( + "encoding/binary" + "errors" + "fmt" + "net" + + "github.com/pion/stun" + "github.com/pion/turn/v2/internal/proto" +) + +// TCPConn is the implementation of the Conn and PacketConn interfaces for TCP network connections. +type TCPConn struct { + connCh chan net.Conn + RelayConn +} + +var ( + errInvalidTURNFrame = errors.New("data is not a valid TURN frame, no STUN or ChannelData found") + errIncompleteTURNFrame = errors.New("data contains incomplete STUN or TURN frame") +) + +const ( + stunHeaderSize = 20 +) + +// NewTCPConn creates a new instance of TCPConn +func NewTCPConn(config *ConnConfig) *TCPConn { + c := &TCPConn{ + connCh: make(chan net.Conn, 10), + RelayConn: RelayConn{ + obs: config.Observer, + relayedAddr: config.RelayedAddr, + permMap: newPermissionMap(), + integrity: config.Integrity, + _nonce: config.Nonce, + _lifetime: config.Lifetime, + log: config.Log, + }, + } + + c.log.Debugf("initial lifetime: %d seconds", int(c.lifetime().Seconds())) + + c.refreshAllocTimer = NewPeriodicTimer( + timerIDRefreshAlloc, + c.onRefreshTimers, + c.lifetime()/2, + ) + + c.refreshPermsTimer = NewPeriodicTimer( + timerIDRefreshPerms, + c.onRefreshTimers, + permRefreshInterval, + ) + + if c.refreshAllocTimer.Start() { + c.log.Debugf("refreshAllocTimer started") + } + if c.refreshPermsTimer.Start() { + c.log.Debugf("refreshPermsTimer started") + } + + return c +} + +func (c *TCPConn) Accept() net.Conn { + return <-c.connCh +} + +func (c *TCPConn) connect(peer net.Addr) (proto.ConnectionID, error) { + setters := []stun.Setter{ + stun.TransactionID, + stun.NewType(stun.MethodConnect, stun.ClassRequest), + addr2PeerAddress(peer), + c.obs.Username(), + c.obs.Realm(), + c.nonce(), + c.integrity, + stun.Fingerprint, + } + + msg, err := stun.Build(setters...) + if err != nil { + return 0, err + } + + c.log.Debugf("send connect request (peer=%v)", peer) + trRes, err := c.obs.PerformTransaction(msg, c.obs.TURNServerAddr(), false) + if err != nil { + return 0, err + } + res := trRes.Msg + + if res.Type.Class == stun.ClassErrorResponse { + var code stun.ErrorCodeAttribute + if err = code.GetFrom(res); err == nil { + return 0, fmt.Errorf("%s (error %s)", res.Type, code) + } + return 0, fmt.Errorf("%s", res.Type) + } + + var cid proto.ConnectionID + if err := cid.GetFrom(res); err != nil { + return 0, err + } + + c.log.Debugf("connect request successful (cid=%v)", cid) + return cid, nil +} + +func (c *TCPConn) connectionBind(dataConn net.Conn, cid proto.ConnectionID) error { + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodConnectionBind, stun.ClassRequest), + cid, + c.obs.Username(), + c.obs.Realm(), + c.nonce(), + c.integrity, + stun.Fingerprint, + ) + if err != nil { + return err + } + + c.log.Debugf("send connectionBind request (cid=%v)", cid) + _, err = dataConn.Write(msg.Raw) + if err != nil { + return err + } + + // read exactly one STUN message, + // any data after belongs to the user + b := make([]byte, stunHeaderSize) + n, err := dataConn.Read(b) + if n != stunHeaderSize { + return errIncompleteTURNFrame + } else if err != nil { + return err + } + if !stun.IsMessage(b) { + return errInvalidTURNFrame + } + + datagramSize := binary.BigEndian.Uint16(b[2:4]) + stunHeaderSize + raw := make([]byte, datagramSize) + copy(raw, b) + _, err = dataConn.Read(raw[stunHeaderSize:]) + if err != nil { + return err + } + res := &stun.Message{Raw: raw} + if err := res.Decode(); err != nil { + return fmt.Errorf("failed to decode STUN message: %s", err.Error()) + } + + switch res.Type.Class { + case stun.ClassErrorResponse: + var code stun.ErrorCodeAttribute + if err = code.GetFrom(res); err == nil { + return fmt.Errorf("%s (error %s)", res.Type, code) + } + return fmt.Errorf("%s", res.Type) + case stun.ClassSuccessResponse: + c.log.Debug("connectionBind request successful") + return nil + default: + return fmt.Errorf("unexpected STUN request message: %s", res.String()) + } +} + +func (c *TCPConn) Dial(addr net.Addr) (net.Conn, error) { + var err error + _, ok := addr.(*net.TCPAddr) + if !ok { + return nil, errTCPAddrCast + } + + // check if we have a permission for the destination IP addr + perm, ok := c.permMap.find(addr) + if !ok { + perm = &permission{} + c.permMap.insert(addr, perm) + } + + for i := 0; i < maxRetryAttempts; i++ { + if err = c.createPermission(perm, addr); !errors.Is(err, errTryAgain) { + break + } + } + if err != nil { + return nil, err + } + + // Send connect request if haven't done so. + cid, err := c.connect(addr) + if err != nil { + return nil, err + } + + // create a data connection if doesn't exist, and send connbind over it. + conn, err := net.Dial("tcp", c.obs.TURNServerAddr().String()) + if err != nil { + return nil, err + } + + if err = c.connectionBind(conn, cid); err != nil { + return nil, err + } + + return conn, nil +} + +// Close closes the connection. +// Any blocked ReadFrom or WriteTo operations will be unblocked and return errors. +func (c *TCPConn) Close() error { + c.refreshAllocTimer.Stop() + c.refreshPermsTimer.Stop() + + c.obs.OnDeallocated(c.relayedAddr) + return c.refreshAllocation(0, true /* dontWait=true */) +} + +func (c *TCPConn) HandleConnectionAttempt(data []byte, from net.Addr, cid proto.ConnectionID) error { + // set up + + // If the client + // wishes to accept this connection, it MUST initiate a new TCP + // connection to the server, utilizing the same destination transport + // address to which the control connection was established. This + // connection MUST be made using a different local transport address. + + conn, err := net.Dial("tcp", c.obs.TURNServerAddr().String()) + if err != nil { + return err + } + + // Authentication of the client by the server MUST use the same method + // and credentials as for the control connection. Once established, the + // client MUST send a ConnectionBind request over the new connection. + // That request MUST include the CONNECTION-ID attribute, echoed from + // the ConnectionAttempt indication. + + if err = c.connectionBind(conn, cid); err != nil { + return err + } + + // When a response to the + // ConnectionBind request is received, if it is a success, the TCP + // connection on which it was sent is called the client data connection + // corresponding to the peer. + c.connCh <- conn + return nil +} diff --git a/internal/client/tcp_conn_test.go b/internal/client/tcp_conn_test.go new file mode 100644 index 00000000..db250609 --- /dev/null +++ b/internal/client/tcp_conn_test.go @@ -0,0 +1,96 @@ +package client + +import ( + "net" + "testing" + + "github.com/pion/logging" + "github.com/pion/stun" + "github.com/pion/turn/v2/internal/proto" + "github.com/stretchr/testify/assert" +) + +type dummyConnObserver struct { + turnServerAddr net.Addr + username stun.Username + realm stun.Realm + _writeTo func(data []byte, to net.Addr) (int, error) + _performTransaction func(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error) + _onDeallocated func(relayedAddr net.Addr) +} + +func (obs *dummyConnObserver) TURNServerAddr() net.Addr { + return obs.turnServerAddr +} + +func (obs *dummyConnObserver) Username() stun.Username { + return obs.username +} + +func (obs *dummyConnObserver) Realm() stun.Realm { + return obs.realm +} + +func (obs *dummyConnObserver) WriteTo(data []byte, to net.Addr) (int, error) { + if obs._writeTo != nil { + return obs._writeTo(data, to) + } + return 0, nil +} + +func (obs *dummyConnObserver) PerformTransaction(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error) { + if obs._performTransaction != nil { + return obs._performTransaction(msg, to, dontWait) + } + return TransactionResult{}, nil +} + +func (obs *dummyConnObserver) OnDeallocated(relayedAddr net.Addr) { + if obs._onDeallocated != nil { + obs._onDeallocated(relayedAddr) + } +} + +func TestTCPConn(t *testing.T) { + t.Run("connect()", func(t *testing.T) { + var serverCid proto.ConnectionID = 4567 + obs := &dummyConnObserver{ + _performTransaction: func(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error) { + if msg.Type.Class == stun.ClassRequest && msg.Type.Method == stun.MethodConnect { + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodConnect, stun.ClassSuccessResponse), + serverCid, + ) + assert.NoError(t, err) + return TransactionResult{Msg: msg}, nil + } + return TransactionResult{}, errFake + }, + } + + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 1234, + } + + pm := newPermissionMap() + assert.True(t, pm.insert(addr, &permission{ + st: permStatePermitted, + })) + + loggerFactory := logging.NewDefaultLoggerFactory() + log := loggerFactory.NewLogger("test") + alloc := TCPConn{ + RelayConn: RelayConn{ + obs: obs, + permMap: pm, + log: log, + }, + } + + cid, err := alloc.connect(addr) + assert.Equal(t, serverCid, cid) + assert.NoError(t, err) + }) +} diff --git a/internal/proto/reqtrans.go b/internal/proto/reqtrans.go index cc73a471..553fdb13 100644 --- a/internal/proto/reqtrans.go +++ b/internal/proto/reqtrans.go @@ -10,6 +10,8 @@ import ( type Protocol byte const ( + // ProtoTCP is IANA assigned protocol number for TCP. + ProtoTCP Protocol = 6 // ProtoUDP is IANA assigned protocol number for UDP. ProtoUDP Protocol = 17 ) diff --git a/internal/server/errors.go b/internal/server/errors.go index 13f8ee1a..fc9a45b1 100644 --- a/internal/server/errors.go +++ b/internal/server/errors.go @@ -15,7 +15,7 @@ var ( errFailedToCreateSTUNPacket = errors.New("failed to create stun message from packet") errFailedToCreateChannelData = errors.New("failed to create channel data from packet") errRelayAlreadyAllocatedForFiveTuple = errors.New("relay already allocated for 5-TUPLE") - errRequestedTransportMustBeUDP = errors.New("RequestedTransport must be UDP") + errUnsupportedTransportProtocol = errors.New("RequestedTransport must be UDP or TCP") errNoDontFragmentSupport = errors.New("no support for DONT-FRAGMENT") errRequestWithReservationTokenAndEvenPort = errors.New("Request must not contain RESERVATION-TOKEN and EVEN-PORT") errNoAllocationFound = errors.New("no allocation found") diff --git a/internal/server/turn.go b/internal/server/turn.go index 4e7d25db..3ea75ecb 100644 --- a/internal/server/turn.go +++ b/internal/server/turn.go @@ -53,14 +53,14 @@ func handleAllocateRequest(r Request, m *stun.Message) error { // attribute. If the REQUESTED-TRANSPORT attribute is not included // or is malformed, the server rejects the request with a 400 (Bad // Request) error. Otherwise, if the attribute is included but - // specifies a protocol other that UDP, the server rejects the + // specifies a protocol other that UDP/TCP, the server rejects the // request with a 442 (Unsupported Transport Protocol) error. var requestedTransport proto.RequestedTransport if err = requestedTransport.GetFrom(m); err != nil { return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) - } else if requestedTransport.Protocol != proto.ProtoUDP { + } else if requestedTransport.Protocol != proto.ProtoUDP && requestedTransport.Protocol != proto.ProtoTCP { msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeUnsupportedTransProto}) - return buildAndSendErr(r.Conn, r.SrcAddr, errRequestedTransportMustBeUDP, msg...) + return buildAndSendErr(r.Conn, r.SrcAddr, errUnsupportedTransportProtocol, msg...) } // 4. The request may contain a DONT-FRAGMENT attribute. If it does,