Skip to content

Commit

Permalink
Implement client side for TCP allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
AndyLc committed May 4, 2023
1 parent 7abfa3b commit e4ae7e4
Show file tree
Hide file tree
Showing 13 changed files with 917 additions and 116 deletions.
161 changes: 129 additions & 32 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type Client struct {
trMap *client.TransactionMap // thread-safe
rto time.Duration // read-only
relayedConn *client.UDPConn // protected by mutex ***
tcpAllocation *client.TCPAllocation // protected by mutex ***
allocTryLock client.TryLock // thread-safe
listenTryLock client.TryLock // thread-safe
net transport.Net // read-only
Expand Down Expand Up @@ -238,42 +239,34 @@ func (c *Client) SendBindingRequest() (net.Addr, error) {
return c.SendBindingRequestTo(c.stunServ)
}

// Allocate sends a TURN allocation request to the given transport address
func (c *Client) Allocate() (net.PacketConn, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
}
defer c.allocTryLock.Unlock()

relayedConn := c.relayedUDPConn()
if relayedConn != nil {
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String())
}
func (c *Client) sendAllocateRequest(protocol proto.Protocol) (proto.RelayedAddress, proto.Lifetime, stun.Nonce, error) {
var relayed proto.RelayedAddress
var lifetime proto.Lifetime
var nonce stun.Nonce

msg, err := stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodAllocate, stun.ClassRequest),
proto.RequestedTransport{Protocol: proto.ProtoUDP},
proto.RequestedTransport{Protocol: protocol},
stun.Fingerprint,
)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}

trRes, err := c.PerformTransaction(msg, c.turnServ, false)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}

res := trRes.Msg

// Anonymous allocate failed, trying to authenticate.
var nonce stun.Nonce
if err = nonce.GetFrom(res); err != nil {
return nil, err
return relayed, lifetime, nonce, err
}
if err = c.realm.GetFrom(res); err != nil {
return nil, err
return relayed, lifetime, nonce, err
}
c.realm = append([]byte(nil), c.realm...)
c.integrity = stun.NewLongTermIntegrity(
Expand All @@ -283,48 +276,101 @@ func (c *Client) Allocate() (net.PacketConn, error) {
msg, err = stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodAllocate, stun.ClassRequest),
proto.RequestedTransport{Protocol: proto.ProtoUDP},
proto.RequestedTransport{Protocol: protocol},
&c.username,
&c.realm,
&nonce,
&c.integrity,
stun.Fingerprint,
)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}

trRes, err = c.PerformTransaction(msg, c.turnServ, false)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}
res = trRes.Msg

if res.Type.Class == stun.ClassErrorResponse {
var code stun.ErrorCodeAttribute
if err = code.GetFrom(res); err == nil {
return nil, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
return relayed, lifetime, nonce, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
}
return nil, fmt.Errorf("%s", res.Type) //nolint:goerr113
return relayed, lifetime, nonce, fmt.Errorf("%s", res.Type) //nolint:goerr113
}

// Getting relayed addresses from response.
var relayed proto.RelayedAddress
if err := relayed.GetFrom(res); err != nil {
return relayed, lifetime, nonce, err
}

// Getting lifetime from response
if err := lifetime.GetFrom(res); err != nil {
return relayed, lifetime, nonce, err
}
return relayed, lifetime, nonce, nil
}

// Allocate sends a TURN allocation request to the given transport address
func (c *Client) Allocate() (net.PacketConn, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
}
defer c.allocTryLock.Unlock()

relayedConn := c.relayedUDPConn()
if relayedConn != nil {
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String())
}

relayed, lifetime, nonce, err := c.sendAllocateRequest(proto.ProtoUDP)
if err != nil {
return nil, err
}

relayedAddr := &net.UDPAddr{
IP: relayed.IP,
Port: relayed.Port,
}

// Getting lifetime from response
var lifetime proto.Lifetime
if err := lifetime.GetFrom(res); err != nil {
relayedConn = client.NewUDPConn(&client.ConnConfig{
Observer: c,
RelayedAddr: relayedAddr,
Integrity: c.integrity,
Nonce: nonce,
Lifetime: lifetime.Duration,
Log: c.log,
})
c.setRelayedUDPConn(relayedConn)

return relayedConn, nil
}

// Allocate TCP
func (c *Client) AllocateTCP() (*client.TCPAllocation, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
}
defer c.allocTryLock.Unlock()

allocation := c.getTCPAllocation()
if allocation != nil {
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, allocation.Addr().String())
}

relayed, lifetime, nonce, err := c.sendAllocateRequest(proto.ProtoTCP)
if err != nil {
return nil, err
}

relayedConn = client.NewUDPConn(&client.UDPConnConfig{
relayedAddr := &net.TCPAddr{
IP: relayed.IP,
Port: relayed.Port,
}

allocation = client.NewTCPAllocation(&client.ConnConfig{
Observer: c,
RelayedAddr: relayedAddr,
Integrity: c.integrity,
Expand All @@ -333,15 +379,26 @@ func (c *Client) Allocate() (net.PacketConn, error) {
Log: c.log,
})

c.setRelayedUDPConn(relayedConn)
c.setTCPAllocation(allocation)

return relayedConn, nil
return allocation, nil
}

// CreatePermission Issues a CreatePermission request for the supplied addresses
// as described in https://datatracker.ietf.org/doc/html/rfc5766#section-9
func (c *Client) CreatePermission(addrs ...net.Addr) error {
return c.relayedUDPConn().CreatePermissions(addrs...)
if conn := c.relayedUDPConn(); conn != nil {
if err := conn.CreatePermissions(addrs...); err != nil {
return err
}
}

if allocation := c.getTCPAllocation(); allocation != nil {
if err := allocation.CreatePermissions(addrs...); err != nil {
return err
}
}
return nil
}

// PerformTransaction performs STUN transaction
Expand Down Expand Up @@ -387,6 +444,7 @@ func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, ignoreResult
// (Called by UDPConn)
func (c *Client) OnDeallocated(relayedAddr net.Addr) {
c.setRelayedUDPConn(nil)
c.setTCPAllocation(nil)
}

// HandleInbound handles data received.
Expand Down Expand Up @@ -445,7 +503,8 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
}

if msg.Type.Class == stun.ClassIndication {
if msg.Type.Method == stun.MethodData {
switch msg.Type.Method {
case stun.MethodData:
var peerAddr proto.PeerAddress
if err := peerAddr.GetFrom(msg); err != nil {
return err
Expand All @@ -467,8 +526,32 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
c.log.Debug("no relayed conn allocated")
return nil // silently discard
}

relayedConn.HandleInbound(data, from)
case stun.MethodConnectionAttempt:
var peerAddr proto.PeerAddress
if err := peerAddr.GetFrom(msg); err != nil {
return err
}

addr := &net.TCPAddr{
IP: peerAddr.IP,
Port: peerAddr.Port,
}

var cid proto.ConnectionID
if err := cid.GetFrom(msg); err != nil {
return err
}

c.log.Debugf("connection attempt from %s", addr.String())

allocation := c.getTCPAllocation()
if allocation == nil {
c.log.Debug("no TCP allocation exists")
return nil // silently discard
}

allocation.HandleConnectionAttempt(addr, cid)
}
return nil
}
Expand Down Expand Up @@ -579,3 +662,17 @@ func (c *Client) relayedUDPConn() *client.UDPConn {

return c.relayedConn
}

func (c *Client) setTCPAllocation(alloc *client.TCPAllocation) {
c.mutex.Lock()
defer c.mutex.Unlock()

c.tcpAllocation = alloc
}

func (c *Client) getTCPAllocation() *client.TCPAllocation {
c.mutex.RLock()
defer c.mutex.RUnlock()

return c.tcpAllocation
}
51 changes: 51 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/pion/logging"
"github.com/pion/transport/v2/stdnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func createListeningTestClient(t *testing.T, loggerFactory logging.LoggerFactory) (*Client, net.PacketConn, bool) {
Expand Down Expand Up @@ -187,3 +188,53 @@ func TestClientNonceExpiration(t *testing.T) {
assert.NoError(t, conn.Close())
assert.NoError(t, server.Close())
}

// Create a TCP-based allocation and verify allocation can be created
func TestTCPClient(t *testing.T) {
// Setup server
tcpListener, err := net.Listen("tcp4", "0.0.0.0:13478")
require.NoError(t, err)

server, err := NewServer(ServerConfig{
AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) {
return GenerateAuthKey(username, realm, "pass"), true
},
ListenerConfigs: []ListenerConfig{
{
Listener: tcpListener,
RelayAddressGenerator: &RelayAddressGeneratorStatic{
RelayAddress: net.ParseIP("127.0.0.1"),
Address: "0.0.0.0",
},
},
},
Realm: "pion.ly",
})
require.NoError(t, err)

// Setup clients
conn, err := net.Dial("tcp", "127.0.0.1:13478")
require.NoError(t, err)

client, err := NewClient(&ClientConfig{
Conn: NewSTUNConn(conn),
STUNServerAddr: "127.0.0.1:13478",
TURNServerAddr: "127.0.0.1:13478",
Username: "foo",
Password: "pass",
})
require.NoError(t, err)
require.NoError(t, client.Listen())

allocation, err := client.AllocateTCP()
require.NoError(t, err)

// TODO: Implement server side handling of Connect and ConnectionBind
// _, err = allocation.Dial(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080})
// assert.NoError(t, err)

// Shutdown
require.NoError(t, allocation.Close())
require.NoError(t, conn.Close())
require.NoError(t, server.Close())
}
Loading

0 comments on commit e4ae7e4

Please sign in to comment.