diff --git a/config/config.go b/config/config.go index 8c588e6c..c742c41b 100644 --- a/config/config.go +++ b/config/config.go @@ -35,16 +35,16 @@ const ( type Config struct { LogLevel LogLevel `default:"info" split_words:"true"` - ExternalIP string `split_words:"true"` + ExternalIP []string `split_words:"true"` TLSCertFile string `split_words:"true"` TLSKeyFile string `split_words:"true"` ServerTLS bool `split_words:"true"` - ServerAddress string `default:"0.0.0.0:5050" split_words:"true"` + ServerAddress string `default:":5050" split_words:"true"` Secret []byte `split_words:"true"` - TurnAddress string `default:"0.0.0.0:3478" required:"true" split_words:"true"` + TurnAddress string `default:":3478" required:"true" split_words:"true"` TurnStrictAuth bool `default:"true" split_words:"true"` TurnPortRange string `split_words:"true"` @@ -54,7 +54,9 @@ type Config struct { UsersFile string `split_words:"true"` Prometheus bool `split_words:"true"` - CheckOrigin func(string) bool `ignored:"true" json:"-"` + CheckOrigin func(string) bool `ignored:"true" json:"-"` + ExternalIPV4 net.IP `ignored:"true"` + ExternalIPV6 net.IP `ignored:"true"` } func (c Config) parsePortRange() (uint16, uint16, error) { @@ -124,10 +126,6 @@ func Get() (Config, []FutureLog) { futureFatal(fmt.Sprintf("invalid SCREEGO_AUTH_MODE: %s", config.AuthMode))) } - if config.ExternalIP == "" { - logs = append(logs, futureFatal("SCREEGO_EXTERNAL_IP must be set")) - } - if config.ServerTLS { if config.TLSCertFile == "" { logs = append(logs, futureFatal("SCREEGO_TLS_CERT_FILE must be set if TLS is enabled")) @@ -170,9 +168,9 @@ func Get() (Config, []FutureLog) { } } - if net.ParseIP(config.ExternalIP) == nil || config.ExternalIP == "0.0.0.0" { - logs = append(logs, futureFatal(fmt.Sprintf("invalid SCREEGO_EXTERNAL_IP: %s", config.ExternalIP))) - } + var errs []FutureLog + config.ExternalIPV4, config.ExternalIPV6, errs = validateExternalIP(config.ExternalIP) + logs = append(logs, errs...) min, max, err := config.parsePortRange() if err != nil { @@ -192,6 +190,50 @@ func Get() (Config, []FutureLog) { return config, logs } +func validateExternalIP(ips []string) (net.IP, net.IP, []FutureLog) { + if len(ips) == 0 { + return nil, nil, []FutureLog{futureFatal("SCREEGO_EXTERNAL_IP must be set")} + } + + first := ips[0] + + firstParsed := net.ParseIP(first) + if firstParsed == nil || first == "0.0.0.0" { + return nil, nil, []FutureLog{futureFatal(fmt.Sprintf("invalid SCREEGO_EXTERNAL_IP: %s", first))} + } + firstIsIP4 := firstParsed.To4() != nil + + if len(ips) == 1 { + if firstIsIP4 { + return firstParsed, nil, nil + } + return nil, firstParsed, nil + } + + second := ips[1] + + secondParsed := net.ParseIP(second) + if secondParsed == nil || second == "0.0.0.0" { + return nil, nil, []FutureLog{futureFatal(fmt.Sprintf("invalid SCREEGO_EXTERNAL_IP: %s", second))} + } + + secondIsIP4 := secondParsed.To4() != nil + + if firstIsIP4 == secondIsIP4 { + return nil, nil, []FutureLog{futureFatal("invalid SCREEGO_EXTERNAL_IP: the ips must be of different type ipv4/ipv6")} + } + + if len(ips) > 2 { + return nil, nil, []FutureLog{futureFatal("invalid SCREEGO_EXTERNAL_IP: too many ips supplied")} + } + + if !firstIsIP4 { + return secondParsed, firstParsed, nil + } + + return firstParsed, secondParsed, nil +} + func getExecutableOrWorkDir() (string, *FutureLog) { dir, err := getExecutableDir() // when using `go run main.go` the executable lives in th temp directory therefore the env.development diff --git a/go.mod b/go.mod index 49307b86..f1368a04 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/kr/pretty v0.2.0 // indirect github.com/magiconair/properties v1.8.1 github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 + github.com/pion/randutil v0.1.0 github.com/pion/turn/v2 v2.0.5 github.com/prometheus/client_golang v1.7.1 github.com/rs/xid v1.2.1 diff --git a/screego.config.example b/screego.config.example index 21c66890..a03c0570 100644 --- a/screego.config.example +++ b/screego.config.example @@ -1,7 +1,10 @@ # The external ip of the server. +# When using a dual stack setup define both IPv4 & IPv6 separated by a comma. # Execute the following command on the server you want to host Screego # to find your external ip. # curl 'https://api.ipify.org' +# Example: +# 192.168.178.2,2a01:c22:a87c:e500:2d8:61ff:fec7:f92a SCREEGO_EXTERNAL_IP= # A secret which should be unique. Is used for cookie authentication. diff --git a/turn/none.go b/turn/none.go new file mode 100644 index 00000000..0b26a8bd --- /dev/null +++ b/turn/none.go @@ -0,0 +1,26 @@ +package turn + +import ( + "errors" + "net" + "strconv" +) + +type RelayAddressGeneratorNone struct {} + +func (r *RelayAddressGeneratorNone) Validate() error { + return nil +} + +func (r *RelayAddressGeneratorNone) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { + conn, err := net.ListenPacket("udp", ":"+strconv.Itoa(requestedPort)) + if err != nil { + return nil, nil, err + } + + return conn, conn.LocalAddr(), nil +} + +func (r *RelayAddressGeneratorNone) AllocateConn(network string, requestedPort int) (net.Conn, net.Addr, error) { + return nil, nil, errors.New("todo") +} diff --git a/turn/portrange.go b/turn/portrange.go new file mode 100644 index 00000000..8dcb9545 --- /dev/null +++ b/turn/portrange.go @@ -0,0 +1,51 @@ +package turn + +import ( + "errors" + "fmt" + "net" + + "github.com/pion/randutil" +) + +type RelayAddressGeneratorPortRange struct { + MinPort uint16 + MaxPort uint16 + Rand randutil.MathRandomGenerator +} + +func (r *RelayAddressGeneratorPortRange) Validate() error { + if r.Rand == nil { + r.Rand = randutil.NewMathRandomGenerator() + } + + return nil +} + +func (r *RelayAddressGeneratorPortRange) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { + if requestedPort != 0 { + conn, err := net.ListenPacket("udp", fmt.Sprintf(":%d", requestedPort)) + if err != nil { + return nil, nil, err + } + relayAddr := conn.LocalAddr().(*net.UDPAddr) + return conn, relayAddr, nil + } + + for try := 0; try < 10; try++ { + port := r.MinPort + uint16(r.Rand.Intn(int((r.MaxPort+1)-r.MinPort))) + conn, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port)) + if err != nil { + continue + } + + relayAddr := conn.LocalAddr().(*net.UDPAddr) + return conn, relayAddr, nil + } + + return nil, nil, errors.New("max retries exceeded") +} + +func (r *RelayAddressGeneratorPortRange) AllocateConn(network string, requestedPort int) (net.Conn, net.Addr, error) { + return nil, nil, errors.New("todo") +} diff --git a/turn/server.go b/turn/server.go index c92e969e..a7ff5213 100644 --- a/turn/server.go +++ b/turn/server.go @@ -12,11 +12,10 @@ import ( ) type Server struct { - TurnAddress string - StunAddress string - lock sync.RWMutex - strictAuth bool - lookup map[string]Entry + Port string + lock sync.RWMutex + strictAuth bool + lookup map[string]Entry } type Entry struct { @@ -26,46 +25,57 @@ type Entry struct { const Realm = "screego" -type LoggedGenerator struct { +type Generator struct { + ipv4 net.IP + ipv6 net.IP turn.RelayAddressGenerator } -func (r *LoggedGenerator) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { +func (r *Generator) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { conn, addr, err := r.RelayAddressGenerator.AllocatePacketConn(network, requestedPort) + relayAddr := *addr.(*net.UDPAddr) + if r.ipv6 == nil || (relayAddr.IP.To4() != nil && r.ipv4 != nil) { + relayAddr.IP = r.ipv4 + } else { + relayAddr.IP = r.ipv6 + } if err == nil { - log.Debug().Str("addr", addr.String()).Str("network", network).Msg("TURN allocated") + log.Debug().Str("addr", addr.String()).Str("relayaddr", relayAddr.String()).Msg("TURN allocated") } - return conn, addr, err + return conn, &relayAddr, err } func Start(conf config.Config) (*Server, error) { - udpListener, err := net.ListenPacket("udp4", conf.TurnAddress) + udpListener, err := net.ListenPacket("udp", conf.TurnAddress) if err != nil { return nil, fmt.Errorf("udp: could not listen on %s: %s", conf.TurnAddress, err) } - tcpListener, err := net.Listen("tcp4", conf.TurnAddress) + tcpListener, err := net.Listen("tcp", conf.TurnAddress) if err != nil { return nil, fmt.Errorf("tcp: could not listen on %s: %s", conf.TurnAddress, err) } - split := strings.SplitN(conf.TurnAddress, ":", 2) + split := strings.Split(conf.TurnAddress, ":") svr := &Server{ - TurnAddress: fmt.Sprintf("turn:%s:%s", conf.ExternalIP, split[1]), - StunAddress: fmt.Sprintf("stun:%s:%s", conf.ExternalIP, split[1]), - lookup: map[string]Entry{}, - strictAuth: conf.TurnStrictAuth, + Port: split[len(split) - 1], + lookup: map[string]Entry{}, + strictAuth: conf.TurnStrictAuth, } - loggedGenerator := &LoggedGenerator{RelayAddressGenerator: generator(conf)} + gen := &Generator{ + ipv4: conf.ExternalIPV4, + ipv6: conf.ExternalIPV6, + RelayAddressGenerator: generator(conf), + } _, err = turn.NewServer(turn.ServerConfig{ Realm: Realm, AuthHandler: svr.authenticate, ListenerConfigs: []turn.ListenerConfig{ - {Listener: tcpListener, RelayAddressGenerator: loggedGenerator}, + {Listener: tcpListener, RelayAddressGenerator: gen}, }, PacketConnConfigs: []turn.PacketConnConfig{ - {PacketConn: udpListener, RelayAddressGenerator: loggedGenerator}, + {PacketConn: udpListener, RelayAddressGenerator: gen}, }, }) if err != nil { @@ -80,17 +90,9 @@ func generator(conf config.Config) turn.RelayAddressGenerator { min, max, useRange := conf.PortRange() if useRange { log.Debug().Uint16("min", min).Uint16("max", max).Msg("Using Port Range") - return &turn.RelayAddressGeneratorPortRange{ - RelayAddress: net.ParseIP(conf.ExternalIP), - Address: "0.0.0.0", - MinPort: min, - MaxPort: max, - } - } - return &turn.RelayAddressGeneratorStatic{ - RelayAddress: net.ParseIP(conf.ExternalIP), - Address: "0.0.0.0", + return &RelayAddressGeneratorPortRange{MinPort: min, MaxPort: max} } + return &RelayAddressGeneratorNone{} } func (a *Server) Allow(username, password string, addr net.IP) { diff --git a/ws/room.go b/ws/room.go index cf846b22..e0f52924 100644 --- a/ws/room.go +++ b/ws/room.go @@ -1,6 +1,7 @@ package ws import ( + "fmt" "net" "sort" @@ -38,14 +39,16 @@ func (r *Room) newSession(host, client xid.ID, rooms *Rooms) { Client: client, } sessionCreatedTotal.Inc() + clientUser := r.Users[client] + hostUser := r.Users[host] iceHost := []outgoing.ICEServer{} iceClient := []outgoing.ICEServer{} switch r.Mode { case ConnectionLocal: case ConnectionSTUN: - iceHost = []outgoing.ICEServer{{URLs: []string{rooms.turnServer.StunAddress}}} - iceClient = []outgoing.ICEServer{{URLs: []string{rooms.turnServer.StunAddress}}} + iceHost = []outgoing.ICEServer{{URLs: []string{rooms.address(hostUser, "stun")}}} + iceClient = []outgoing.ICEServer{{URLs: []string{rooms.address(clientUser, "stun")}}} case ConnectionTURN: hostPW := util.RandString(20) clientPW := util.RandString(20) @@ -55,16 +58,16 @@ func (r *Room) newSession(host, client xid.ID, rooms *Rooms) { rooms.turnServer.Allow(clientName, clientPW, r.Users[client].Addr) iceHost = []outgoing.ICEServer{{ URLs: []string{ - rooms.turnServer.TurnAddress, - rooms.turnServer.TurnAddress + "?transport=tcp", + rooms.address(hostUser, "turn"), + rooms.address(hostUser, "turn") + "?transport=tcp", }, Credential: hostPW, Username: hostName, }} iceClient = []outgoing.ICEServer{{ URLs: []string{ - rooms.turnServer.TurnAddress, - rooms.turnServer.TurnAddress + "?transport=tcp", + rooms.address(clientUser, "turn"), + rooms.address(clientUser, "turn") + "?transport=tcp", }, Credential: clientPW, Username: clientName, @@ -75,6 +78,16 @@ func (r *Room) newSession(host, client xid.ID, rooms *Rooms) { r.Users[client].Write <- outgoing.ClientSession{Peer: host, ID: id, ICEServers: iceClient} } +func (r *Rooms) address(user *User, prefix string) string { + var ip string + if r.config.ExternalIPV6 == nil || (user.Addr.To4() != nil && r.config.ExternalIPV4 != nil) { + ip = r.config.ExternalIPV4.String() + } else { + ip = fmt.Sprintf("[%s]", r.config.ExternalIPV6) + } + return fmt.Sprintf("%s:%s:%s", prefix, ip, r.turnServer.Port) +} + func (r *Room) closeSession(id xid.ID) { delete(r.Sessions, id) sessionClosedTotal.Inc()