Skip to content

Commit

Permalink
Implement client side for TCP allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
AndyLc committed Mar 15, 2023
1 parent 7abfa3b commit 9bee828
Show file tree
Hide file tree
Showing 14 changed files with 788 additions and 123 deletions.
148 changes: 109 additions & 39 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type ClientConfig struct {

// Client is a STUN server client
type Client struct {
conn net.PacketConn // read-only
conn net.PacketConn // read-only // connection to the regular turn server.
stunServ net.Addr // read-only
turnServ net.Addr // read-only
stunServStr string // read-only, used for de-multiplexing
Expand All @@ -61,7 +61,7 @@ type Client struct {
software stun.Software // read-only
trMap *client.TransactionMap // thread-safe
rto time.Duration // read-only
relayedConn *client.UDPConn // protected by mutex ***
relayedConn client.Conn // protected by mutex ***
allocTryLock client.TryLock // thread-safe
listenTryLock client.TryLock // thread-safe
net transport.Net // read-only
Expand Down Expand Up @@ -238,42 +238,34 @@ func (c *Client) SendBindingRequest() (net.Addr, error) {
return c.SendBindingRequestTo(c.stunServ)
}

// Allocate sends a TURN allocation request to the given transport address
func (c *Client) Allocate() (net.PacketConn, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
}
defer c.allocTryLock.Unlock()

relayedConn := c.relayedUDPConn()
if relayedConn != nil {
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String())
}
func (c *Client) sendAllocateRequest(protocol proto.Protocol) (proto.RelayedAddress, proto.Lifetime, stun.Nonce, error) {
var relayed proto.RelayedAddress
var lifetime proto.Lifetime
var nonce stun.Nonce

msg, err := stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodAllocate, stun.ClassRequest),
proto.RequestedTransport{Protocol: proto.ProtoUDP},
proto.RequestedTransport{Protocol: protocol},
stun.Fingerprint,
)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}

trRes, err := c.PerformTransaction(msg, c.turnServ, false)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}

res := trRes.Msg

// Anonymous allocate failed, trying to authenticate.
var nonce stun.Nonce
if err = nonce.GetFrom(res); err != nil {
return nil, err
return relayed, lifetime, nonce, err
}
if err = c.realm.GetFrom(res); err != nil {
return nil, err
return relayed, lifetime, nonce, err
}
c.realm = append([]byte(nil), c.realm...)
c.integrity = stun.NewLongTermIntegrity(
Expand All @@ -283,48 +275,101 @@ func (c *Client) Allocate() (net.PacketConn, error) {
msg, err = stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodAllocate, stun.ClassRequest),
proto.RequestedTransport{Protocol: proto.ProtoUDP},
proto.RequestedTransport{Protocol: protocol},
&c.username,
&c.realm,
&nonce,
&c.integrity,
stun.Fingerprint,
)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}

trRes, err = c.PerformTransaction(msg, c.turnServ, false)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}
res = trRes.Msg

if res.Type.Class == stun.ClassErrorResponse {
var code stun.ErrorCodeAttribute
if err = code.GetFrom(res); err == nil {
return nil, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
return relayed, lifetime, nonce, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
}
return nil, fmt.Errorf("%s", res.Type) //nolint:goerr113
return relayed, lifetime, nonce, fmt.Errorf("%s", res.Type) //nolint:goerr113
}

// Getting relayed addresses from response.
var relayed proto.RelayedAddress
if err := relayed.GetFrom(res); err != nil {
return relayed, lifetime, nonce, err
}

// Getting lifetime from response
if err := lifetime.GetFrom(res); err != nil {
return relayed, lifetime, nonce, err
}
return relayed, lifetime, nonce, nil
}

// Allocate sends a TURN allocation request to the given transport address
func (c *Client) Allocate() (net.PacketConn, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
}
defer c.allocTryLock.Unlock()

relayedConn := c.getRelayedConn()
if relayedConn != nil {
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String())
}

relayed, lifetime, nonce, err := c.sendAllocateRequest(proto.ProtoUDP)
if err != nil {
return nil, err
}

relayedAddr := &net.UDPAddr{
IP: relayed.IP,
Port: relayed.Port,
}

// Getting lifetime from response
var lifetime proto.Lifetime
if err := lifetime.GetFrom(res); err != nil {
relayedConn = client.NewUDPConn(&client.ConnConfig{
Observer: c,
RelayedAddr: relayedAddr,
Integrity: c.integrity,
Nonce: nonce,
Lifetime: lifetime.Duration,
Log: c.log,
})
c.setRelayedConn(relayedConn)

return relayedConn.(*client.UDPConn), nil
}

// Allocate TCP
func (c *Client) AllocateTCP() (*client.TCPConn, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
}
defer c.allocTryLock.Unlock()

relayedConn := c.getRelayedConn()
if relayedConn != nil {
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String())
}

relayed, lifetime, nonce, err := c.sendAllocateRequest(proto.ProtoTCP)
if err != nil {
return nil, err
}

relayedConn = client.NewUDPConn(&client.UDPConnConfig{
relayedAddr := &net.TCPAddr{
IP: relayed.IP,
Port: relayed.Port,
}

relayedConn = client.NewTCPConn(&client.ConnConfig{
Observer: c,
RelayedAddr: relayedAddr,
Integrity: c.integrity,
Expand All @@ -333,15 +378,15 @@ func (c *Client) Allocate() (net.PacketConn, error) {
Log: c.log,
})

c.setRelayedUDPConn(relayedConn)
c.setRelayedConn(relayedConn)

return relayedConn, nil
return relayedConn.(*client.TCPConn), nil
}

// CreatePermission Issues a CreatePermission request for the supplied addresses
// as described in https://datatracker.ietf.org/doc/html/rfc5766#section-9
func (c *Client) CreatePermission(addrs ...net.Addr) error {
return c.relayedUDPConn().CreatePermissions(addrs...)
return c.getRelayedConn().CreatePermissions(addrs...)
}

// PerformTransaction performs STUN transaction
Expand Down Expand Up @@ -386,7 +431,7 @@ func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, ignoreResult
// OnDeallocated is called when de-allocation of relay address has been complete.
// (Called by UDPConn)
func (c *Client) OnDeallocated(relayedAddr net.Addr) {
c.setRelayedUDPConn(nil)
c.setRelayedConn(nil)
}

// HandleInbound handles data received.
Expand Down Expand Up @@ -445,7 +490,8 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
}

if msg.Type.Class == stun.ClassIndication {
if msg.Type.Method == stun.MethodData {
switch msg.Type.Method {
case stun.MethodData:
var peerAddr proto.PeerAddress
if err := peerAddr.GetFrom(msg); err != nil {
return err
Expand All @@ -462,13 +508,37 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {

c.log.Debugf("data indication received from %s", from.String())

relayedConn := c.relayedUDPConn()
relayedConn := c.getRelayedConn().(*client.UDPConn)
if relayedConn == nil {
c.log.Debug("no relayed conn allocated")
return nil // silently discard
}

relayedConn.HandleInbound(data, from)
case stun.MethodConnectionAttempt:
var peerAddr proto.PeerAddress
if err := peerAddr.GetFrom(msg); err != nil {
return err
}

from = &net.TCPAddr{
IP: peerAddr.IP,
Port: peerAddr.Port,
}

var cid proto.ConnectionID
if err := cid.GetFrom(msg); err != nil {
return err
}

c.log.Debugf("connection attempt from %s", from.String())

relayedConn := c.getRelayedConn().(*client.TCPConn)
if relayedConn == nil {
c.log.Debug("no relayed conn allocated")
return nil // silently discard
}

relayedConn.HandleConnectionAttempt(data, from, cid)
}
return nil
}
Expand Down Expand Up @@ -514,7 +584,7 @@ func (c *Client) handleChannelData(data []byte) error {
return err
}

relayedConn := c.relayedUDPConn()
relayedConn := c.getRelayedConn().(*client.UDPConn)
if relayedConn == nil {
c.log.Debug("no relayed conn allocated")
return nil // silently discard
Expand Down Expand Up @@ -566,14 +636,14 @@ func (c *Client) onRtxTimeout(trKey string, nRtx int) {
tr.StartRtxTimer(c.onRtxTimeout)
}

func (c *Client) setRelayedUDPConn(conn *client.UDPConn) {
func (c *Client) setRelayedConn(conn client.Conn) {
c.mutex.Lock()
defer c.mutex.Unlock()

c.relayedConn = conn
}

func (c *Client) relayedUDPConn() *client.UDPConn {
func (c *Client) getRelayedConn() client.Conn {
c.mutex.RLock()
defer c.mutex.RUnlock()

Expand Down
50 changes: 50 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,53 @@ func TestClientNonceExpiration(t *testing.T) {
assert.NoError(t, conn.Close())
assert.NoError(t, server.Close())
}

// Create a tcp-based allocation and verify allocation can be created
func TestTCPClient(t *testing.T) {
// Setup server
tcpListener, err := net.Listen("tcp4", "0.0.0.0:3478")
assert.NoError(t, err)

server, err := NewServer(ServerConfig{
AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) {
return GenerateAuthKey(username, realm, "pass"), true
},
ListenerConfigs: []ListenerConfig{
{
Listener: tcpListener,
RelayAddressGenerator: &RelayAddressGeneratorStatic{
RelayAddress: net.ParseIP("127.0.0.1"),
Address: "0.0.0.0",
},
},
},
Realm: "pion.ly",
})
assert.NoError(t, err)

// Setup clients
conn, err := net.Dial("tcp", "127.0.0.1:3478")
assert.NoError(t, err)

client, err := NewClient(&ClientConfig{
Conn: NewSTUNConn(conn),
STUNServerAddr: "127.0.0.1:3478",
TURNServerAddr: "127.0.0.1:3478",
Username: "foo",
Password: "pass",
})
assert.NoError(t, err)
assert.NoError(t, client.Listen())

allocation, err := client.AllocateTCP()
assert.NoError(t, err)

// TODO: Implement server side handling of Connect and ConnectionBind
// _, err = allocation.Dial(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080})
// assert.NoError(t, err)

// Shutdown
assert.NoError(t, allocation.Close())
assert.NoError(t, conn.Close())
assert.NoError(t, server.Close())
}
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ var (
errNonSTUNMessage = errors.New("non-STUN message from STUN server")
errFailedToDecodeSTUN = errors.New("failed to decode STUN message")
errUnexpectedSTUNRequestMessage = errors.New("unexpected STUN request message")
errInvalidProtocol = errors.New("unexpected STUN request message")
)
Loading

0 comments on commit 9bee828

Please sign in to comment.