Skip to content

Commit

Permalink
Implement active TCP candidate type
Browse files Browse the repository at this point in the history
  • Loading branch information
ashellunts committed May 12, 2023
1 parent 9b4e7d9 commit 012cd28
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 44 deletions.
48 changes: 47 additions & 1 deletion agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"context"
"fmt"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -579,6 +580,44 @@ func (a *Agent) getBestValidCandidatePair() *CandidatePair {
}

func (a *Agent) addPair(local, remote Candidate) *CandidatePair {
if local.TCPType() == TCPTypeActive && remote.TCPType() == TCPTypeActive {
return nil
}

if local.TCPType() == TCPTypeActive && remote.TCPType() == TCPTypePassive {
addressToConnect := net.JoinHostPort(remote.Address(), strconv.Itoa(remote.Port()))

conn, err := net.Dial("tcp", addressToConnect)
if err != nil {
a.log.Errorf("Failed to dial TCP address %s: %v", addressToConnect, err)
return nil
}

packetConn := newTCPPacketConn(tcpPacketParams{
ReadBuffer: tcpReadBufferSize,
LocalAddr: conn.LocalAddr(),
Logger: a.log,
})

if err = packetConn.AddConn(conn, nil); err != nil {
a.log.Errorf("Failed to add TCP connection: %v", err)
return nil
}

localAddress, ok := conn.LocalAddr().(*net.TCPAddr)
if !ok {
a.log.Errorf("Failed to cast local address to TCP address")
return nil
}

localCandidateHost, ok := local.(*CandidateHost)
if !ok {
a.log.Errorf("Failed to cast local candidate to CandidateHost")
return nil
}
localCandidateHost.port = localAddress.Port // this causes a data race with candidateBase.Port()
local.start(a, packetConn, a.startedCh)
}
p := newCandidatePair(local, remote, a.isControlling)
a.checklist = append(a.checklist, p)
return p
Expand Down Expand Up @@ -755,7 +794,9 @@ func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net
}
}

c.start(a, candidateConn, a.startedCh)
if c.TCPType() != TCPTypeActive {
c.start(a, candidateConn, a.startedCh)
}

set = append(set, c)
a.localCandidates[c.NetworkType()] = set
Expand Down Expand Up @@ -1023,13 +1064,18 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
return
}

tcpType := TCPTypeUnspecified
if networkType == NetworkTypeTCP4 && local.NetworkType() == NetworkTypeTCP4 && local.TCPType() == TCPTypePassive {
tcpType = TCPTypeActive
}
prflxCandidateConfig := CandidatePeerReflexiveConfig{
Network: networkType.String(),
Address: ip.String(),
Port: port,
Component: local.Component(),
RelAddr: "",
RelPort: 0,
TCPType: tcpType,
}

prflxCandidate, err := NewCandidatePeerReflexive(&prflxCandidateConfig)
Expand Down
91 changes: 91 additions & 0 deletions agent_active_tcp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

//go:build !js
// +build !js

package ice

import (
"net"
"testing"

"github.com/pion/logging"
"github.com/stretchr/testify/require"
)

func TestAgentActiveTCP(t *testing.T) {
r := require.New(t)

const port = 7686

listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: port,
})
r.NoError(err)
defer func() {
_ = listener.Close()
}()

loggerFactory := logging.NewDefaultLoggerFactory()
loggerFactory.DefaultLogLevel.Set(logging.LogLevelTrace)

tcpMux := NewTCPMuxDefault(TCPMuxParams{
Listener: listener,
Logger: loggerFactory.NewLogger("passive-ice-tcp-mux"),
ReadBufferSize: 20,
})

defer func() {
_ = tcpMux.Close()
}()

r.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")

passiveAgent, err := NewAgent(&AgentConfig{
TCPMux: tcpMux,
CandidateTypes: []CandidateType{CandidateTypeHost},
NetworkTypes: []NetworkType{NetworkTypeTCP4},
LoggerFactory: loggerFactory,
IncludeLoopback: true,
})
r.NoError(err)
r.NotNil(passiveAgent)

activeAgent, err := NewAgent(&AgentConfig{
CandidateTypes: []CandidateType{CandidateTypeHost},
NetworkTypes: []NetworkType{NetworkTypeTCP4},
LoggerFactory: loggerFactory,
})
r.NoError(err)
r.NotNil(activeAgent)

passiveAgentConn, activeAgenConn := connect(passiveAgent, activeAgent)
r.NotNil(passiveAgentConn)
r.NotNil(activeAgenConn)

pair := passiveAgent.getSelectedPair()
r.NotNil(pair)
r.Equal(port, pair.Local.Port())

data := []byte("hello world")
_, err = passiveAgentConn.Write(data)
r.NoError(err)

buffer := make([]byte, 1024)
n, err := activeAgenConn.Read(buffer)
r.NoError(err)
r.Equal(data, buffer[:n])

data2 := []byte("hello world 2")
_, err = activeAgenConn.Write(data2)
r.NoError(err)

n, err = passiveAgentConn.Read(buffer)
r.NoError(err)
r.Equal(data2, buffer[:n])

r.NoError(activeAgenConn.Close())
r.NoError(passiveAgentConn.Close())
}
3 changes: 3 additions & 0 deletions agent_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ const (

// maxBindingRequestTimeout is the wait time before binding requests can be deleted
maxBindingRequestTimeout = 4000 * time.Millisecond

// tcpReadBufferSize is the size of the read buffer of tcpPacketConn used by active tcp candidate
tcpReadBufferSize = 8
)

func defaultCandidateTypes() []CandidateType {
Expand Down
4 changes: 2 additions & 2 deletions agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1637,7 +1637,7 @@ func TestAcceptAggressiveNomination(t *testing.T) {

KeepaliveInterval := time.Hour
cfg0 := &AgentConfig{
NetworkTypes: supportedNetworkTypes(),
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
MulticastDNSMode: MulticastDNSModeDisabled,
Net: net0,

Expand All @@ -1652,7 +1652,7 @@ func TestAcceptAggressiveNomination(t *testing.T) {
require.NoError(t, aAgent.OnConnectionStateChange(aNotifier))

cfg1 := &AgentConfig{
NetworkTypes: supportedNetworkTypes(),
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
MulticastDNSMode: MulticastDNSModeDisabled,
Net: net1,
KeepaliveInterval: &KeepaliveInterval,
Expand Down
2 changes: 1 addition & 1 deletion candidate_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ func UnmarshalCandidate(raw string) (Candidate, error) {
case "srflx":
return NewCandidateServerReflexive(&CandidateServerReflexiveConfig{"", protocol, address, port, component, priority, foundation, relatedAddress, relatedPort})
case "prflx":
return NewCandidatePeerReflexive(&CandidatePeerReflexiveConfig{"", protocol, address, port, component, priority, foundation, relatedAddress, relatedPort})
return NewCandidatePeerReflexive(&CandidatePeerReflexiveConfig{"", protocol, address, port, component, priority, foundation, relatedAddress, relatedPort, tcpType})
case "relay":
return NewCandidateRelay(&CandidateRelayConfig{"", protocol, address, port, component, priority, foundation, relatedAddress, relatedPort, "", nil})
default:
Expand Down
2 changes: 2 additions & 0 deletions candidate_peer_reflexive.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type CandidatePeerReflexiveConfig struct {
Foundation string
RelAddr string
RelPort int
TCPType TCPType
}

// NewCandidatePeerReflexive creates a new peer reflective candidate
Expand All @@ -49,6 +50,7 @@ func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*Candidate
id: candidateID,
networkType: networkType,
candidateType: CandidateTypePeerReflexive,
tcpType: config.TCPType,
address: config.Address,
port: config.Port,
resolvedAddr: createAddr(networkType, ip, config.Port),
Expand Down
83 changes: 44 additions & 39 deletions gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ const (
stunGatherTimeout = time.Second * 5
)

type connAndPort struct {
conn net.PacketConn
port int
tcpType TCPType
}

// Close a net.Conn and log if we have a failure
func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ...interface{}) {
if c == nil || (reflect.ValueOf(c).Kind() == reflect.Ptr && reflect.ValueOf(c).IsNil()) {
Expand Down Expand Up @@ -155,53 +161,21 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
}

for network := range networks {
type connAndPort struct {
conn net.PacketConn
port int
}
var (
conns []connAndPort
tcpType TCPType
)
var conns []connAndPort

switch network {
case tcp:
if a.tcpMux == nil {
continue
}
// Handle ICE TCP active mode
conns = append(conns, connAndPort{nil, 0, TCPTypeActive})

// Handle ICE TCP passive mode
var muxConns []net.PacketConn
if multi, ok := a.tcpMux.(AllConnsGetter); ok {
a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag)
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil, ip)
if err != nil {
a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
continue
}
} else {
a.log.Debugf("GetConn by ufrag: %s", a.localUfrag)
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil, ip)
if err != nil {
a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
continue
}
muxConns = []net.PacketConn{conn}
}

// Extract the port for each PacketConn we got.
for _, conn := range muxConns {
if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok {
conns = append(conns, connAndPort{conn, tcpConn.Port})
} else {
a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, ip, a.localUfrag)
}
if a.tcpMux != nil {
conns = a.getTCPMuxConnections(mappedIP, ip, network, conns)
}
if len(conns) == 0 {
// Didn't succeed with any, try the next network.
continue
}
tcpType = TCPTypePassive
// Is there a way to verify that the listen address is even
// accessible from the current interface.
case udp:
Expand All @@ -212,7 +186,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
}

if udpConn, ok := conn.LocalAddr().(*net.UDPAddr); ok {
conns = append(conns, connAndPort{conn, udpConn.Port})
conns = append(conns, connAndPort{conn, udpConn.Port, TCPTypeUnspecified})
} else {
a.log.Warnf("Failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, ip, a.localUfrag)
continue
Expand All @@ -225,7 +199,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
Address: address,
Port: connAndPort.port,
Component: ComponentRTP,
TCPType: tcpType,
TCPType: connAndPort.tcpType,
}

c, err := NewCandidateHost(&hostConfig)
Expand All @@ -252,6 +226,37 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
}
}

func (a *Agent) getTCPMuxConnections(mappedIP net.IP, ip net.IP, network string, conns []connAndPort) []connAndPort {
var muxConns []net.PacketConn
if multi, ok := a.tcpMux.(AllConnsGetter); ok {
a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag)
var err error
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil, ip)
if err != nil {
a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
return conns
}
} else {
a.log.Debugf("GetConn by ufrag: %s", a.localUfrag)
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil, ip)
if err != nil {
a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
return conns
}
muxConns = []net.PacketConn{conn}
}

// Extract the port for each PacketConn we got.
for _, conn := range muxConns {
if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok {
conns = append(conns, connAndPort{conn, tcpConn.Port, TCPTypePassive})
} else {
a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, ip, a.localUfrag)
}
}
return conns
}

func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit
if a.udpMux == nil {
return errUDPMuxDisabled
Expand Down
2 changes: 1 addition & 1 deletion gather_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ func TestMultiUDPMuxUsage(t *testing.T) {
}

a, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(),
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
CandidateTypes: []CandidateType{CandidateTypeHost},
UDPMux: NewMultiUDPMuxDefault(udpMuxInstances...),
})
Expand Down

0 comments on commit 012cd28

Please sign in to comment.