Skip to content

Commit

Permalink
Implement client side for TCP allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
AndyLc committed Mar 14, 2023
1 parent 7abfa3b commit dd6aabc
Show file tree
Hide file tree
Showing 14 changed files with 766 additions and 230 deletions.
239 changes: 178 additions & 61 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type ClientConfig struct {
Password string
Realm string
Software string
Protocol string
RTO time.Duration
Conn net.PacketConn // Listening socket (net.PacketConn)
LoggerFactory logging.LoggerFactory
Expand All @@ -49,25 +50,27 @@ type ClientConfig struct {

// Client is a STUN server client
type Client struct {
conn net.PacketConn // read-only
stunServ net.Addr // read-only
turnServ net.Addr // read-only
stunServStr string // read-only, used for de-multiplexing
turnServStr string // read-only, used for de-multiplexing
username stun.Username // read-only
password string // read-only
realm stun.Realm // read-only
integrity stun.MessageIntegrity // read-only
software stun.Software // read-only
trMap *client.TransactionMap // thread-safe
rto time.Duration // read-only
relayedConn *client.UDPConn // protected by mutex ***
allocTryLock client.TryLock // thread-safe
listenTryLock client.TryLock // thread-safe
net transport.Net // read-only
mutex sync.RWMutex // thread-safe
mutexTrMap sync.Mutex // thread-safe
log logging.LeveledLogger // read-only
conn net.PacketConn // read-only // connection to the regular turn server.
stunServ net.Addr // read-only
turnServ net.Addr // read-only
stunServStr string // read-only, used for de-multiplexing
turnServStr string // read-only, used for de-multiplexing
username stun.Username // read-only
password string // read-only
realm stun.Realm // read-only
integrity stun.MessageIntegrity // read-only
software stun.Software // read-only
trMap *client.TransactionMap // thread-safe
rto time.Duration // read-only
relayedConn *client.UDPConn // protected by mutex ***
relayedTCPConn *client.TCPConn // protected by mutex ***
allocTryLock client.TryLock // thread-safe
listenTryLock client.TryLock // thread-safe
net transport.Net // read-only
mutex sync.RWMutex // thread-safe
mutexTrMap sync.Mutex // thread-safe
log logging.LeveledLogger // read-only
protocol proto.Protocol // read-only
}

// NewClient returns a new Client instance. listeningAddress is the address and port to listen on, default "0.0.0.0:0"
Expand All @@ -93,24 +96,47 @@ func NewClient(config *ClientConfig) (*Client, error) {
log.Warn("Virtual network is enabled")
}

protocol := proto.ProtoUDP
if config.Protocol == "tcp" {
protocol = proto.ProtoTCP
}

var stunServ, turnServ net.Addr
var stunServStr, turnServStr string
if len(config.STUNServerAddr) > 0 {
log.Debugf("resolving %s", config.STUNServerAddr)
stunServ, err = config.Net.ResolveUDPAddr("udp4", config.STUNServerAddr)
if err != nil {
return nil, err
switch protocol {
case proto.ProtoUDP:
stunServ, err = config.Net.ResolveUDPAddr("udp4", config.STUNServerAddr)
if err != nil {
return nil, err
}
stunServStr = stunServ.String()
case proto.ProtoTCP:
stunServ, err = config.Net.ResolveTCPAddr("tcp4", config.STUNServerAddr)
if err != nil {
return nil, err
}
stunServStr = stunServ.String()
}
stunServStr = stunServ.String()
log.Debugf("stunServ: %s", stunServStr)
}
if len(config.TURNServerAddr) > 0 {
log.Debugf("resolving %s", config.TURNServerAddr)
turnServ, err = config.Net.ResolveUDPAddr("udp4", config.TURNServerAddr)
if err != nil {
return nil, err
switch protocol {
case proto.ProtoUDP:
turnServ, err = config.Net.ResolveUDPAddr("udp4", config.TURNServerAddr)
if err != nil {
return nil, err
}
turnServStr = turnServ.String()
case proto.ProtoTCP:
turnServ, err = config.Net.ResolveTCPAddr("tcp4", config.TURNServerAddr)
if err != nil {
return nil, err
}
turnServStr = turnServ.String()
}
turnServStr = turnServ.String()
log.Debugf("turnServ: %s", turnServStr)
}

Expand All @@ -133,6 +159,7 @@ func NewClient(config *ClientConfig) (*Client, error) {
trMap: client.NewTransactionMap(),
rto: rto,
log: log,
protocol: protocol,
}

return c, nil
Expand Down Expand Up @@ -238,42 +265,34 @@ func (c *Client) SendBindingRequest() (net.Addr, error) {
return c.SendBindingRequestTo(c.stunServ)
}

// Allocate sends a TURN allocation request to the given transport address
func (c *Client) Allocate() (net.PacketConn, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
}
defer c.allocTryLock.Unlock()

relayedConn := c.relayedUDPConn()
if relayedConn != nil {
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String())
}
func (c *Client) sendAllocateRequest() (proto.RelayedAddress, proto.Lifetime, stun.Nonce, error) {
var relayed proto.RelayedAddress
var lifetime proto.Lifetime
var nonce stun.Nonce

msg, err := stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodAllocate, stun.ClassRequest),
proto.RequestedTransport{Protocol: proto.ProtoUDP},
proto.RequestedTransport{Protocol: c.protocol},
stun.Fingerprint,
)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}

trRes, err := c.PerformTransaction(msg, c.turnServ, false)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}

res := trRes.Msg

// Anonymous allocate failed, trying to authenticate.
var nonce stun.Nonce
if err = nonce.GetFrom(res); err != nil {
return nil, err
return relayed, lifetime, nonce, err
}
if err = c.realm.GetFrom(res); err != nil {
return nil, err
return relayed, lifetime, nonce, err
}
c.realm = append([]byte(nil), c.realm...)
c.integrity = stun.NewLongTermIntegrity(
Expand All @@ -283,48 +302,101 @@ func (c *Client) Allocate() (net.PacketConn, error) {
msg, err = stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodAllocate, stun.ClassRequest),
proto.RequestedTransport{Protocol: proto.ProtoUDP},
proto.RequestedTransport{Protocol: c.protocol},
&c.username,
&c.realm,
&nonce,
&c.integrity,
stun.Fingerprint,
)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}

trRes, err = c.PerformTransaction(msg, c.turnServ, false)
if err != nil {
return nil, err
return relayed, lifetime, nonce, err
}
res = trRes.Msg

if res.Type.Class == stun.ClassErrorResponse {
var code stun.ErrorCodeAttribute
if err = code.GetFrom(res); err == nil {
return nil, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
return relayed, lifetime, nonce, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
}
return nil, fmt.Errorf("%s", res.Type) //nolint:goerr113
return relayed, lifetime, nonce, fmt.Errorf("%s", res.Type) //nolint:goerr113
}

// Getting relayed addresses from response.
var relayed proto.RelayedAddress
if err := relayed.GetFrom(res); err != nil {
return relayed, lifetime, nonce, err
}

// Getting lifetime from response
if err := lifetime.GetFrom(res); err != nil {
return relayed, lifetime, nonce, err
}
return relayed, lifetime, nonce, nil
}

// Allocate sends a TURN allocation request to the given transport address
func (c *Client) Allocate() (net.PacketConn, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
}
defer c.allocTryLock.Unlock()

relayedConn := c.getRelayedUDPConn()
if relayedConn != nil {
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String())
}

relayed, lifetime, nonce, err := c.sendAllocateRequest()
if err != nil {
return nil, err
}

relayedAddr := &net.UDPAddr{
IP: relayed.IP,
Port: relayed.Port,
}

// Getting lifetime from response
var lifetime proto.Lifetime
if err := lifetime.GetFrom(res); err != nil {
relayedConn = client.NewUDPConn(&client.ConnConfig{
Observer: c,
RelayedAddr: relayedAddr,
Integrity: c.integrity,
Nonce: nonce,
Lifetime: lifetime.Duration,
Log: c.log,
})
c.setRelayedUDPConn(relayedConn)

return relayedConn, nil
}

// Allocate TCP
func (c *Client) AllocateTCP() (*client.TCPConn, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
}
defer c.allocTryLock.Unlock()

relayedConn := c.getRelayedTCPConn()
if relayedConn != nil {
return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String())
}

relayed, lifetime, nonce, err := c.sendAllocateRequest()
if err != nil {
return nil, err
}

relayedConn = client.NewUDPConn(&client.UDPConnConfig{
relayedAddr := &net.TCPAddr{
IP: relayed.IP,
Port: relayed.Port,
}

relayedConn = client.NewTCPConn(&client.ConnConfig{
Observer: c,
RelayedAddr: relayedAddr,
Integrity: c.integrity,
Expand All @@ -333,15 +405,21 @@ func (c *Client) Allocate() (net.PacketConn, error) {
Log: c.log,
})

c.setRelayedUDPConn(relayedConn)
c.setRelayedTCPConn(relayedConn)

return relayedConn, nil
}

// CreatePermission Issues a CreatePermission request for the supplied addresses
// as described in https://datatracker.ietf.org/doc/html/rfc5766#section-9
func (c *Client) CreatePermission(addrs ...net.Addr) error {
return c.relayedUDPConn().CreatePermissions(addrs...)
switch c.protocol {
case proto.ProtoUDP:
return c.getRelayedUDPConn().CreatePermissions(addrs...)
case proto.ProtoTCP:
return c.getRelayedTCPConn().CreatePermissions(addrs...)
}
return nil
}

// PerformTransaction performs STUN transaction
Expand Down Expand Up @@ -445,7 +523,8 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
}

if msg.Type.Class == stun.ClassIndication {
if msg.Type.Method == stun.MethodData {
switch msg.Type.Method {
case stun.MethodData:
var peerAddr proto.PeerAddress
if err := peerAddr.GetFrom(msg); err != nil {
return err
Expand All @@ -462,13 +541,37 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {

c.log.Debugf("data indication received from %s", from.String())

relayedConn := c.relayedUDPConn()
relayedConn := c.getRelayedUDPConn()
if relayedConn == nil {
c.log.Debug("no relayed conn allocated")
return nil // silently discard
}

relayedConn.HandleInbound(data, from)
case stun.MethodConnectionAttempt:
var peerAddr proto.PeerAddress
if err := peerAddr.GetFrom(msg); err != nil {
return err
}

from = &net.TCPAddr{
IP: peerAddr.IP,
Port: peerAddr.Port,
}

var cid proto.ConnectionID
if err := cid.GetFrom(msg); err != nil {
return err
}

c.log.Debugf("connection attempt from %s", from.String())

relayedConn := c.getRelayedTCPConn()
if relayedConn == nil {
c.log.Debug("no relayed conn allocated")
return nil // silently discard
}

relayedConn.HandleConnectionAttempt(data, from, cid)
}
return nil
}
Expand Down Expand Up @@ -514,7 +617,7 @@ func (c *Client) handleChannelData(data []byte) error {
return err
}

relayedConn := c.relayedUDPConn()
relayedConn := c.getRelayedUDPConn()
if relayedConn == nil {
c.log.Debug("no relayed conn allocated")
return nil // silently discard
Expand Down Expand Up @@ -573,9 +676,23 @@ func (c *Client) setRelayedUDPConn(conn *client.UDPConn) {
c.relayedConn = conn
}

func (c *Client) relayedUDPConn() *client.UDPConn {
func (c *Client) getRelayedUDPConn() *client.UDPConn {
c.mutex.RLock()
defer c.mutex.RUnlock()

return c.relayedConn
}

func (c *Client) setRelayedTCPConn(conn *client.TCPConn) {
c.mutex.Lock()
defer c.mutex.Unlock()

c.relayedTCPConn = conn
}

func (c *Client) getRelayedTCPConn() *client.TCPConn {
c.mutex.RLock()
defer c.mutex.RUnlock()

return c.relayedTCPConn
}
Loading

0 comments on commit dd6aabc

Please sign in to comment.