From 6713f3454f1a64d54384deeceee7e863c9d97f18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Toni=20Sol=C3=A9?= Date: Tue, 4 Jun 2024 11:18:32 +0200 Subject: [PATCH] - Added in the Config structure a OnConnectionAttempt callback - Added examples to show how it can be used in a server application. --- config.go | 7 +++++ conn.go | 5 ++++ examples/listen/cid/main.go | 47 ++++++++++++++++++++++++++++++++++ examples/listen/psk/main.go | 47 ++++++++++++++++++++++++++++++++++ examples/listen/verify/main.go | 47 ++++++++++++++++++++++++++++++++++ 5 files changed, 153 insertions(+) diff --git a/config.go b/config.go index d765ecd91..3ecbf3fe0 100644 --- a/config.go +++ b/config.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "crypto/x509" "io" + "net" "time" "github.com/pion/dtls/v2/pkg/crypto/elliptic" @@ -214,6 +215,12 @@ type Config struct { // CertificateRequestMessageHook, if not nil, is called when a Certificate Request // message is sent from a server. The returned handshake message replaces the original message. CertificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message + + // Whenever a connection attempt is made, the server or application can call this callback function. + // The callback function can then implement logic to handle the connection attempt, such as logging the attempt, + // checking against a list of blocked IPs, or counting the attempts to prevent brute force attacks. + // If the callback function returns an error, the connection attempt will be aborted. + OnConnectionAttempt func(net.Addr) error } func defaultConnectContextMaker() (context.Context, func()) { diff --git a/conn.go b/conn.go index e65163cf7..ee1061d40 100644 --- a/conn.go +++ b/conn.go @@ -324,6 +324,11 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, if config == nil { return nil, errNoConfigProvided } + if config.OnConnectionAttempt != nil { + if err := config.OnConnectionAttempt(rAddr); err != nil { + return nil, err + } + } dconn, err := createConn(conn, rAddr, config, false) if err != nil { return nil, err diff --git a/examples/listen/cid/main.go b/examples/listen/cid/main.go index 770bbcfa4..fb5887fc9 100644 --- a/examples/listen/cid/main.go +++ b/examples/listen/cid/main.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "net" + "sync" "time" "github.com/pion/dtls/v2" @@ -26,12 +27,44 @@ func main() { // Everything below is the pion-DTLS API! Thanks for using it ❤️. // + // *************** Variables only used to implement a basic Brute Force Attack protection *************** + var attempts = make(map[string]int) // Map of attempts for each IP address + var attemptsMutex sync.Mutex // Mutex for the map of attempts + var attemptsCleaner = time.Now() // Time to be able to clean the map of attempts every X minutes + // Prepare the configuration of the DTLS connection config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Client's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, + OnConnectionAttempt: func(addr net.Addr) error { + // *************** Brute Force Attack protection *************** + // Check if the IP address is in the map, and the IP address has exceeded the limit + attemptsMutex.Lock() + defer attemptsMutex.Unlock() + // Here I implement a time cleaner for the map of attempts, every 5 minutes I will decrement by 1 the number of attempts for each IP address + if time.Now().After(attemptsCleaner.Add(time.Minute * 5)) { + attemptsCleaner = time.Now() + for k, v := range attempts { + if v > 0 { + attempts[k]-- + } + if attempts[k] == 0 { + delete(attempts, k) + } + } + } + // Check if the IP address is in the map, and the IP address has exceeded the limit (Brute Force Attack protection) + attemptIP := addr.(*net.UDPAddr).IP.String() + if attempts[attemptIP] > 10 { + return fmt.Errorf("too many attempts from this IP address") + } + // Here I increment the number of attempts for this IP address (Brute Force Attack protection) + attempts[attemptIP]++ + // *************** END Brute Force Attack protection END *************** + return nil + }, PSKIdentityHint: []byte("Pion DTLS Server"), CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, @@ -65,6 +98,20 @@ func main() { // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. + // *************** Brute Force Attack protection *************** + // Here I decrease the number of attempts for this IP address + attemptsMutex.Lock() + attemptIP := conn.(*dtls.Conn).RemoteAddr().(*net.UDPAddr).IP.String() + if attempts[attemptIP] > 0 { + attempts[attemptIP]-- + // If the number of attempts for this IP address is 0, I delete the IP address from the map + if attempts[attemptIP] == 0 { + delete(attempts, attemptIP) + } + } + attemptsMutex.Unlock() + // *************** END Brute Force Attack protection END *************** + // Register the connection with the chat hub if err == nil { hub.Register(conn) diff --git a/examples/listen/psk/main.go b/examples/listen/psk/main.go index 66f099693..3472cbfe8 100644 --- a/examples/listen/psk/main.go +++ b/examples/listen/psk/main.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "net" + "sync" "time" "github.com/pion/dtls/v2" @@ -26,12 +27,44 @@ func main() { // Everything below is the pion-DTLS API! Thanks for using it ❤️. // + // *************** Variables only used to implement a basic Brute Force Attack protection *************** + var attempts = make(map[string]int) // Map of attempts for each IP address + var attemptsMutex sync.Mutex // Mutex for the map of attempts + var attemptsCleaner = time.Now() // Time to be able to clean the map of attempts every X minutes + // Prepare the configuration of the DTLS connection config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Client's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, + OnConnectionAttempt: func(addr net.Addr) error { + // *************** Brute Force Attack protection *************** + // Check if the IP address is in the map, and the IP address has exceeded the limit + attemptsMutex.Lock() + defer attemptsMutex.Unlock() + // Here I implement a time cleaner for the map of attempts, every 5 minutes I will decrement by 1 the number of attempts for each IP address + if time.Now().After(attemptsCleaner.Add(time.Minute * 5)) { + attemptsCleaner = time.Now() + for k, v := range attempts { + if v > 0 { + attempts[k]-- + } + if attempts[k] == 0 { + delete(attempts, k) + } + } + } + // Check if the IP address is in the map, and the IP address has exceeded the limit (Brute Force Attack protection) + attemptIP := addr.(*net.UDPAddr).IP.String() + if attempts[attemptIP] > 10 { + return fmt.Errorf("too many attempts from this IP address") + } + // Here I increment the number of attempts for this IP address (Brute Force Attack protection) + attempts[attemptIP]++ + // *************** END Brute Force Attack protection END *************** + return nil + }, PSKIdentityHint: []byte("Pion DTLS Server"), CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, @@ -64,6 +97,20 @@ func main() { // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. + // *************** Brute Force Attack protection *************** + // Here I decrease the number of attempts for this IP address + attemptsMutex.Lock() + attemptIP := conn.(*dtls.Conn).RemoteAddr().(*net.UDPAddr).IP.String() + if attempts[attemptIP] > 0 { + attempts[attemptIP]-- + // If the number of attempts for this IP address is 0, I delete the IP address from the map + if attempts[attemptIP] == 0 { + delete(attempts, attemptIP) + } + } + attemptsMutex.Unlock() + // *************** END Brute Force Attack protection END *************** + // Register the connection with the chat hub if err == nil { hub.Register(conn) diff --git a/examples/listen/verify/main.go b/examples/listen/verify/main.go index a02211e15..7eb928c67 100644 --- a/examples/listen/verify/main.go +++ b/examples/listen/verify/main.go @@ -10,6 +10,7 @@ import ( "crypto/x509" "fmt" "net" + "sync" "time" "github.com/pion/dtls/v2" @@ -28,6 +29,11 @@ func main() { // Everything below is the pion-DTLS API! Thanks for using it ❤️. // + // *************** Variables only used to implement a basic Brute Force Attack protection *************** + var attempts = make(map[string]int) // Map of attempts for each IP address + var attemptsMutex sync.Mutex // Mutex for the map of attempts + var attemptsCleaner = time.Now() // Time to be able to clean the map of attempts every X minutes + certificate, err := util.LoadKeyAndCertificate("examples/certificates/server.pem", "examples/certificates/server.pub.pem") util.Check(err) @@ -49,6 +55,33 @@ func main() { ConnectContextMaker: func() (context.Context, func()) { return context.WithTimeout(ctx, 30*time.Second) }, + OnConnectionAttempt: func(addr net.Addr) error { + // *************** Brute Force Attack protection *************** + // Check if the IP address is in the map, and the IP address has exceeded the limit + attemptsMutex.Lock() + defer attemptsMutex.Unlock() + // Here I implement a time cleaner for the map of attempts, every 5 minutes I will decrement by 1 the number of attempts for each IP address + if time.Now().After(attemptsCleaner.Add(time.Minute * 5)) { + attemptsCleaner = time.Now() + for k, v := range attempts { + if v > 0 { + attempts[k]-- + } + if attempts[k] == 0 { + delete(attempts, k) + } + } + } + // Check if the IP address is in the map, and the IP address has exceeded the limit (Brute Force Attack protection) + attemptIP := addr.(*net.UDPAddr).IP.String() + if attempts[attemptIP] > 10 { + return fmt.Errorf("too many attempts from this IP address") + } + // Here I increment the number of attempts for this IP address (Brute Force Attack protection) + attempts[attemptIP]++ + // *************** END Brute Force Attack protection END *************** + return nil + }, } // Connect to a DTLS server @@ -74,6 +107,20 @@ func main() { // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. + // *************** Brute Force Attack protection *************** + // Here I decrease the number of attempts for this IP address + attemptsMutex.Lock() + attemptIP := conn.(*dtls.Conn).RemoteAddr().(*net.UDPAddr).IP.String() + if attempts[attemptIP] > 0 { + attempts[attemptIP]-- + // If the number of attempts for this IP address is 0, I delete the IP address from the map + if attempts[attemptIP] == 0 { + delete(attempts, attemptIP) + } + } + attemptsMutex.Unlock() + // *************** END Brute Force Attack protection END *************** + // Register the connection with the chat hub hub.Register(conn) }