Skip to content

Commit

Permalink
- Added in the Config structure a OnConnectionAttempt callback
Browse files Browse the repository at this point in the history
- Added examples to show how it can be used in a server application.
  • Loading branch information
tonisole committed Jun 4, 2024
1 parent edc7ad0 commit 6713f34
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 0 deletions.
7 changes: 7 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"crypto/tls"
"crypto/x509"
"io"
"net"
"time"

"github.com/pion/dtls/v2/pkg/crypto/elliptic"
Expand Down Expand Up @@ -214,6 +215,12 @@ type Config struct {
// CertificateRequestMessageHook, if not nil, is called when a Certificate Request
// message is sent from a server. The returned handshake message replaces the original message.
CertificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message

// Whenever a connection attempt is made, the server or application can call this callback function.
// The callback function can then implement logic to handle the connection attempt, such as logging the attempt,
// checking against a list of blocked IPs, or counting the attempts to prevent brute force attacks.
// If the callback function returns an error, the connection attempt will be aborted.
OnConnectionAttempt func(net.Addr) error
}

func defaultConnectContextMaker() (context.Context, func()) {
Expand Down
5 changes: 5 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,11 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr,
if config == nil {
return nil, errNoConfigProvided
}
if config.OnConnectionAttempt != nil {
if err := config.OnConnectionAttempt(rAddr); err != nil {
return nil, err
}
}
dconn, err := createConn(conn, rAddr, config, false)
if err != nil {
return nil, err
Expand Down
47 changes: 47 additions & 0 deletions examples/listen/cid/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"fmt"
"net"
"sync"
"time"

"github.com/pion/dtls/v2"
Expand All @@ -26,12 +27,44 @@ func main() {
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
//

// *************** Variables only used to implement a basic Brute Force Attack protection ***************
var attempts = make(map[string]int) // Map of attempts for each IP address
var attemptsMutex sync.Mutex // Mutex for the map of attempts
var attemptsCleaner = time.Now() // Time to be able to clean the map of attempts every X minutes

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
PSK: func(hint []byte) ([]byte, error) {
fmt.Printf("Client's hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x23}, nil
},
OnConnectionAttempt: func(addr net.Addr) error {
// *************** Brute Force Attack protection ***************
// Check if the IP address is in the map, and the IP address has exceeded the limit
attemptsMutex.Lock()
defer attemptsMutex.Unlock()
// Here I implement a time cleaner for the map of attempts, every 5 minutes I will decrement by 1 the number of attempts for each IP address
if time.Now().After(attemptsCleaner.Add(time.Minute * 5)) {
attemptsCleaner = time.Now()
for k, v := range attempts {
if v > 0 {
attempts[k]--
}
if attempts[k] == 0 {
delete(attempts, k)
}
}
}
// Check if the IP address is in the map, and the IP address has exceeded the limit (Brute Force Attack protection)
attemptIP := addr.(*net.UDPAddr).IP.String()
if attempts[attemptIP] > 10 {
return fmt.Errorf("too many attempts from this IP address")
}
// Here I increment the number of attempts for this IP address (Brute Force Attack protection)
attempts[attemptIP]++
// *************** END Brute Force Attack protection END ***************
return nil
},
PSKIdentityHint: []byte("Pion DTLS Server"),
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8},
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
Expand Down Expand Up @@ -65,6 +98,20 @@ func main() {
// using `dtlsConn := conn.(*dtls.Conn)` in order to to expose
// functions like `ConnectionState` etc.

// *************** Brute Force Attack protection ***************
// Here I decrease the number of attempts for this IP address
attemptsMutex.Lock()
attemptIP := conn.(*dtls.Conn).RemoteAddr().(*net.UDPAddr).IP.String()
if attempts[attemptIP] > 0 {
attempts[attemptIP]--
// If the number of attempts for this IP address is 0, I delete the IP address from the map
if attempts[attemptIP] == 0 {
delete(attempts, attemptIP)
}
}
attemptsMutex.Unlock()
// *************** END Brute Force Attack protection END ***************

// Register the connection with the chat hub
if err == nil {
hub.Register(conn)
Expand Down
47 changes: 47 additions & 0 deletions examples/listen/psk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"fmt"
"net"
"sync"
"time"

"github.com/pion/dtls/v2"
Expand All @@ -26,12 +27,44 @@ func main() {
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
//

// *************** Variables only used to implement a basic Brute Force Attack protection ***************
var attempts = make(map[string]int) // Map of attempts for each IP address
var attemptsMutex sync.Mutex // Mutex for the map of attempts
var attemptsCleaner = time.Now() // Time to be able to clean the map of attempts every X minutes

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
PSK: func(hint []byte) ([]byte, error) {
fmt.Printf("Client's hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x23}, nil
},
OnConnectionAttempt: func(addr net.Addr) error {
// *************** Brute Force Attack protection ***************
// Check if the IP address is in the map, and the IP address has exceeded the limit
attemptsMutex.Lock()
defer attemptsMutex.Unlock()
// Here I implement a time cleaner for the map of attempts, every 5 minutes I will decrement by 1 the number of attempts for each IP address
if time.Now().After(attemptsCleaner.Add(time.Minute * 5)) {
attemptsCleaner = time.Now()
for k, v := range attempts {
if v > 0 {
attempts[k]--
}
if attempts[k] == 0 {
delete(attempts, k)
}
}
}
// Check if the IP address is in the map, and the IP address has exceeded the limit (Brute Force Attack protection)
attemptIP := addr.(*net.UDPAddr).IP.String()
if attempts[attemptIP] > 10 {
return fmt.Errorf("too many attempts from this IP address")
}
// Here I increment the number of attempts for this IP address (Brute Force Attack protection)
attempts[attemptIP]++
// *************** END Brute Force Attack protection END ***************
return nil
},
PSKIdentityHint: []byte("Pion DTLS Server"),
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8},
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
Expand Down Expand Up @@ -64,6 +97,20 @@ func main() {
// using `dtlsConn := conn.(*dtls.Conn)` in order to to expose
// functions like `ConnectionState` etc.

// *************** Brute Force Attack protection ***************
// Here I decrease the number of attempts for this IP address
attemptsMutex.Lock()
attemptIP := conn.(*dtls.Conn).RemoteAddr().(*net.UDPAddr).IP.String()
if attempts[attemptIP] > 0 {
attempts[attemptIP]--
// If the number of attempts for this IP address is 0, I delete the IP address from the map
if attempts[attemptIP] == 0 {
delete(attempts, attemptIP)
}
}
attemptsMutex.Unlock()
// *************** END Brute Force Attack protection END ***************

// Register the connection with the chat hub
if err == nil {
hub.Register(conn)
Expand Down
47 changes: 47 additions & 0 deletions examples/listen/verify/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"crypto/x509"
"fmt"
"net"
"sync"
"time"

"github.com/pion/dtls/v2"
Expand All @@ -28,6 +29,11 @@ func main() {
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
//

// *************** Variables only used to implement a basic Brute Force Attack protection ***************
var attempts = make(map[string]int) // Map of attempts for each IP address
var attemptsMutex sync.Mutex // Mutex for the map of attempts
var attemptsCleaner = time.Now() // Time to be able to clean the map of attempts every X minutes

certificate, err := util.LoadKeyAndCertificate("examples/certificates/server.pem",
"examples/certificates/server.pub.pem")
util.Check(err)
Expand All @@ -49,6 +55,33 @@ func main() {
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(ctx, 30*time.Second)
},
OnConnectionAttempt: func(addr net.Addr) error {
// *************** Brute Force Attack protection ***************
// Check if the IP address is in the map, and the IP address has exceeded the limit
attemptsMutex.Lock()
defer attemptsMutex.Unlock()
// Here I implement a time cleaner for the map of attempts, every 5 minutes I will decrement by 1 the number of attempts for each IP address
if time.Now().After(attemptsCleaner.Add(time.Minute * 5)) {
attemptsCleaner = time.Now()
for k, v := range attempts {
if v > 0 {
attempts[k]--
}
if attempts[k] == 0 {
delete(attempts, k)
}
}
}
// Check if the IP address is in the map, and the IP address has exceeded the limit (Brute Force Attack protection)
attemptIP := addr.(*net.UDPAddr).IP.String()
if attempts[attemptIP] > 10 {
return fmt.Errorf("too many attempts from this IP address")
}
// Here I increment the number of attempts for this IP address (Brute Force Attack protection)
attempts[attemptIP]++
// *************** END Brute Force Attack protection END ***************
return nil
},
}

// Connect to a DTLS server
Expand All @@ -74,6 +107,20 @@ func main() {
// using `dtlsConn := conn.(*dtls.Conn)` in order to to expose
// functions like `ConnectionState` etc.

// *************** Brute Force Attack protection ***************
// Here I decrease the number of attempts for this IP address
attemptsMutex.Lock()
attemptIP := conn.(*dtls.Conn).RemoteAddr().(*net.UDPAddr).IP.String()
if attempts[attemptIP] > 0 {
attempts[attemptIP]--
// If the number of attempts for this IP address is 0, I delete the IP address from the map
if attempts[attemptIP] == 0 {
delete(attempts, attemptIP)
}
}
attemptsMutex.Unlock()
// *************** END Brute Force Attack protection END ***************

// Register the connection with the chat hub
hub.Register(conn)
}
Expand Down

0 comments on commit 6713f34

Please sign in to comment.