Skip to content

Commit

Permalink
[NET-562] Persistent Keep Alive from node to host (#2604)
Browse files Browse the repository at this point in the history
* Move PKA field from models node to host level

* Move PKA field from api models node to host level

* Adapt logic package to node->host PKA

* Adapt migration-related code to node->host PKA

* Adapt cli code to node->host PKA

* Change host PKA default to 20s

* On IfaceDelta, check for PKA on host

* On handleHostRegister, set default PKA

* Use a default PKA

* Use int64 for api host pka

* Reorder imports

* Don't use host pka in iface delta

* Fix ConvertAPIHostToNMHost

* Add swagger doc for host PKA field

* Fix swagger.yml

* Set default PKA only for new hosts

* Remove TODO comment

* Remove redundant check

* Have api-host pka be specified in seconds
  • Loading branch information
gabrielseibel1 authored Oct 6, 2023
1 parent 234f226 commit cb4b99f
Show file tree
Hide file tree
Showing 14 changed files with 179 additions and 119 deletions.
6 changes: 5 additions & 1 deletion cli/cmd/host/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ import (
"log"
"os"

"github.com/spf13/cobra"

"github.com/gravitl/netmaker/cli/functions"
"github.com/gravitl/netmaker/models"
"github.com/spf13/cobra"
)

var (
Expand All @@ -18,6 +19,7 @@ var (
mtu int
isStatic bool
isDefault bool
keepAlive int
)

var hostUpdateCmd = &cobra.Command{
Expand All @@ -43,6 +45,7 @@ var hostUpdateCmd = &cobra.Command{
apiHost.MTU = mtu
apiHost.IsStatic = isStatic
apiHost.IsDefault = isDefault
apiHost.PersistentKeepalive = keepAlive
}
functions.PrettyPrint(functions.UpdateHost(args[0], apiHost))
},
Expand All @@ -54,6 +57,7 @@ func init() {
hostUpdateCmd.Flags().StringVar(&name, "name", "", "Host name")
hostUpdateCmd.Flags().IntVar(&listenPort, "listen_port", 0, "Listen port of the host")
hostUpdateCmd.Flags().IntVar(&mtu, "mtu", 0, "Host MTU size")
hostUpdateCmd.Flags().IntVar(&keepAlive, "keep_alive", 0, "Interval (seconds) in which packets are sent to keep connections open with peers")
hostUpdateCmd.Flags().BoolVar(&isStatic, "static", false, "Make Host Static ?")
hostUpdateCmd.Flags().BoolVar(&isDefault, "default", false, "Make Host Default ?")
rootCmd.AddCommand(hostUpdateCmd)
Expand Down
1 change: 0 additions & 1 deletion cli/cmd/node/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ var (
name string
postUp string
postDown string
keepAlive int
relayedNodes string
egressGatewayRanges string
expirationDateTime int
Expand Down
2 changes: 0 additions & 2 deletions cli/cmd/node/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ var nodeUpdateCmd = &cobra.Command{
node.Address = address
node.Address6 = address6
node.LocalAddress = localAddress
node.PersistentKeepalive = int32(keepAlive)
if relayedNodes != "" {
node.RelayedNodes = strings.Split(relayedNodes, ",")
}
Expand All @@ -61,7 +60,6 @@ func init() {
nodeUpdateCmd.Flags().StringVar(&name, "name", "", "Node name")
nodeUpdateCmd.Flags().StringVar(&postUp, "post_up", "", "Commands to run after node is up `;` separated")
nodeUpdateCmd.Flags().StringVar(&postDown, "post_down", "", "Commands to run after node is down `;` separated")
nodeUpdateCmd.Flags().IntVar(&keepAlive, "keep_alive", 0, "Interval in which packets are sent to keep connections open with peers")
nodeUpdateCmd.Flags().StringVar(&relayedNodes, "relayed_nodes", "", "relayed nodes if node acts as a relay")
nodeUpdateCmd.Flags().StringVar(&egressGatewayRanges, "egress_addrs", "", "Addresses for egressing traffic if node acts as an egress")
nodeUpdateCmd.Flags().IntVar(&expirationDateTime, "expiry", 0, "UNIX timestamp after which node will lose access to the network")
Expand Down
60 changes: 46 additions & 14 deletions controllers/enrollmentkeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/gorilla/mux"

"github.com/gravitl/netmaker/auth"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
Expand All @@ -17,10 +18,14 @@ import (
)

func enrollmentKeyHandlers(r *mux.Router) {
r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey))).Methods(http.MethodPost)
r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys))).Methods(http.MethodGet)
r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey))).Methods(http.MethodDelete)
r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)).Methods(http.MethodPost)
r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey))).
Methods(http.MethodPost)
r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys))).
Methods(http.MethodGet)
r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey))).
Methods(http.MethodDelete)
r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)).
Methods(http.MethodPost)
}

// swagger:route GET /api/v1/enrollment-keys enrollmentKeys getEnrollmentKeys
Expand Down Expand Up @@ -70,7 +75,7 @@ func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) {
// Responses:
// 200: okResponse
func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
params := mux.Vars(r)
keyID := params["keyID"]
err := logic.DeleteEnrollmentKey(keyID)
if err != nil {
Expand All @@ -94,7 +99,6 @@ func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) {
// Responses:
// 200: EnrollmentKey
func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {

var enrollmentKeyBody models.APIEnrollmentKey

err := json.NewDecoder(r.Body).Decode(&enrollmentKeyBody)
Expand All @@ -109,7 +113,13 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
newTime = time.Unix(enrollmentKeyBody.Expiration, 0)
}

newEnrollmentKey, err := logic.CreateEnrollmentKey(enrollmentKeyBody.UsesRemaining, newTime, enrollmentKeyBody.Networks, enrollmentKeyBody.Tags, enrollmentKeyBody.Unlimited)
newEnrollmentKey, err := logic.CreateEnrollmentKey(
enrollmentKeyBody.UsesRemaining,
newTime,
enrollmentKeyBody.Networks,
enrollmentKeyBody.Tags,
enrollmentKeyBody.Unlimited,
)
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
Expand Down Expand Up @@ -138,7 +148,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
// Responses:
// 200: RegisterResponse
func handleHostRegister(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
params := mux.Vars(r)
token := params["token"]
logger.Log(0, "received registration attempt with token", token)
// check if token exists
Expand All @@ -156,7 +166,6 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
hostExists := false
// re-register host with turn just in case.
if servercfg.IsUsingTurn() {
err = logic.RegisterHostWithTurn(newHost.ID.String(), newHost.HostPass)
Expand All @@ -165,9 +174,20 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
}
}
// check if host already exists
hostExists := false
if hostExists = logic.HostExists(&newHost); hostExists && len(enrollmentKey.Networks) == 0 {
logger.Log(0, "host", newHost.ID.String(), newHost.Name, "attempted to re-register with no networks")
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("host already exists"), "badrequest"))
logger.Log(
0,
"host",
newHost.ID.String(),
newHost.Name,
"attempted to re-register with no networks",
)
logic.ReturnErrorResponse(
w,
r,
logic.FormatError(fmt.Errorf("host already exists"), "badrequest"),
)
return
}
// version check
Expand All @@ -190,11 +210,16 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
// use the token
if ok := logic.TryToUseEnrollmentKey(enrollmentKey); !ok {
logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration")
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid enrollment key"), "badrequest"))
logic.ReturnErrorResponse(
w,
r,
logic.FormatError(fmt.Errorf("invalid enrollment key"), "badrequest"),
)
return
}
hostPass := newHost.HostPass
if !hostExists {
newHost.PersistentKeepalive = models.DefaultPersistentKeepAlive
// register host
logic.CheckHostPorts(&newHost)
// create EMQX credentials and ACLs for host
Expand All @@ -209,14 +234,21 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
}
}
if err = logic.CreateHost(&newHost); err != nil {
logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration -", err.Error())
logger.Log(
0,
"host",
newHost.ID.String(),
newHost.Name,
"failed registration -",
err.Error(),
)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
} else {
// need to revise the list of networks from key
// based on the ones host currently has
var networksToAdd = []string{}
networksToAdd := []string{}
currentNets := logic.GetHostNetworks(newHost.ID.String())
for _, newNet := range enrollmentKey.Networks {
if !logic.StringSliceContains(currentNets, newNet) {
Expand Down
2 changes: 1 addition & 1 deletion controllers/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func migrate(w http.ResponseWriter, r *http.Request) {
host.Name = data.HostName
host.HostPass = data.Password
host.OS = data.OS
host.PersistentKeepalive = time.Duration(legacy.PersistentKeepalive)
if err := logic.CreateHost(&host); err != nil {
slog.Error("create host", "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
Expand Down Expand Up @@ -202,7 +203,6 @@ func convertLegacyNode(legacy models.LegacyNode, hostID uuid.UUID) models.Node {
node.IsRelay = false
node.RelayedNodes = []string{}
node.DNSOn = models.ParseBool(legacy.DNSOn)
node.PersistentKeepalive = time.Duration(int64(time.Second) * int64(legacy.PersistentKeepalive))
node.LastModified = time.Now()
node.ExpirationDateTime = time.Unix(legacy.ExpirationDateTime, 0)
node.EgressGatewayNatEnabled = models.ParseBool(legacy.EgressGatewayNatEnabled)
Expand Down
16 changes: 9 additions & 7 deletions logic/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ import (

"github.com/devilcove/httpclient"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"

"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/crypto/bcrypt"
)

var (
Expand Down Expand Up @@ -66,6 +67,7 @@ func deleteHostFromCache(hostID string) {
delete(hostsCacheMap, hostID)
hostCacheMutex.Unlock()
}

func loadHostsIntoCache(hMap map[string]models.Host) {
hostCacheMutex.Lock()
hostsCacheMap = hMap
Expand All @@ -79,7 +81,6 @@ const (

// GetAllHosts - returns all hosts in flat list or error
func GetAllHosts() ([]models.Host, error) {

currHosts := getHostsFromCache()
if len(currHosts) != 0 {
return currHosts, nil
Expand Down Expand Up @@ -139,7 +140,6 @@ func GetHostsMap() (map[string]models.Host, error) {

// GetHost - gets a host from db given id
func GetHost(hostid string) (*models.Host, error) {

if host, ok := getHostFromCache(hostid); ok {
return &host, nil
}
Expand Down Expand Up @@ -217,11 +217,13 @@ func UpdateHost(newHost, currentHost *models.Host) {
newHost.ListenPort = currentHost.ListenPort
}

if newHost.PersistentKeepalive == 0 {
newHost.PersistentKeepalive = currentHost.PersistentKeepalive
}
}

// UpdateHostFromClient - used for updating host on server with update recieved from client
func UpdateHostFromClient(newHost, currHost *models.Host) (sendPeerUpdate bool) {

if newHost.PublicKey != currHost.PublicKey {
currHost.PublicKey = newHost.PublicKey
sendPeerUpdate = true
Expand All @@ -230,7 +232,8 @@ func UpdateHostFromClient(newHost, currHost *models.Host) (sendPeerUpdate bool)
currHost.ListenPort = newHost.ListenPort
sendPeerUpdate = true
}
if newHost.WgPublicListenPort != 0 && currHost.WgPublicListenPort != newHost.WgPublicListenPort {
if newHost.WgPublicListenPort != 0 &&
currHost.WgPublicListenPort != newHost.WgPublicListenPort {
currHost.WgPublicListenPort = newHost.WgPublicListenPort
sendPeerUpdate = true
}
Expand Down Expand Up @@ -488,7 +491,7 @@ func CheckHostPorts(h *models.Host) {
}
for _, host := range hosts {
if host.ID.String() == h.ID.String() {
//skip self
// skip self
continue
}
if !host.EndpointIP.Equal(h.EndpointIP) {
Expand All @@ -503,7 +506,6 @@ func CheckHostPorts(h *models.Host) {
h.ListenPort = minPort
}
}

}

// HostExists - checks if given host already exists
Expand Down
3 changes: 0 additions & 3 deletions logic/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,6 @@ func SetNodeDefaults(node *models.Node) {
node.DefaultACL = parentNetwork.DefaultACL
}

if node.PersistentKeepalive == 0 {
node.PersistentKeepalive = time.Second * time.Duration(parentNetwork.DefaultKeepalive)
}
node.SetLastModified()
node.SetLastCheckIn()
node.SetDefaultConnected()
Expand Down
10 changes: 5 additions & 5 deletions logic/peers.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
}
relayPeer := wgtypes.PeerConfig{
PublicKey: relayHost.PublicKey,
PersistentKeepaliveInterval: &relayNode.PersistentKeepalive,
PersistentKeepaliveInterval: &relayHost.PersistentKeepalive,
ReplaceAllowedIPs: true,
AllowedIPs: GetAllowedIPs(&node, &relayNode, nil),
}
Expand Down Expand Up @@ -111,7 +111,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
peer := peer
if peer.ID.String() == node.ID.String() {
logger.Log(2, "peer update, skipping self")
//skip yourself
// skip yourself
continue
}

Expand All @@ -122,7 +122,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
}
peerConfig := wgtypes.PeerConfig{
PublicKey: peerHost.PublicKey,
PersistentKeepaliveInterval: &peer.PersistentKeepalive,
PersistentKeepaliveInterval: &peerHost.PersistentKeepalive,
ReplaceAllowedIPs: true,
}
if peer.IsEgressGateway {
Expand Down Expand Up @@ -390,7 +390,7 @@ func GetEgressIPs(peer *models.Node) []net.IPNet {
logger.Log(0, "error retrieving host for peer", peer.ID.String(), err.Error())
}

//check for internet gateway
// check for internet gateway
internetGateway := false
if slices.Contains(peer.EgressGatewayRanges, "0.0.0.0/0") || slices.Contains(peer.EgressGatewayRanges, "::/0") {
internetGateway = true
Expand Down Expand Up @@ -439,7 +439,7 @@ func getNodeAllowedIPs(peer, node *models.Node) []net.IPNet {
}
// handle egress gateway peers
if peer.IsEgressGateway {
//hasGateway = true
// hasGateway = true
egressIPs := GetEgressIPs(peer)
allowedips = append(allowedips, egressIPs...)
}
Expand Down
1 change: 0 additions & 1 deletion logic/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ func IfaceDelta(currentNode *models.Node, newNode *models.Node) bool {
newNode.IsEgressGateway != currentNode.IsEgressGateway ||
newNode.IsIngressGateway != currentNode.IsIngressGateway ||
newNode.IsRelay != currentNode.IsRelay ||
newNode.PersistentKeepalive != currentNode.PersistentKeepalive ||
newNode.DNSOn != currentNode.DNSOn ||
newNode.Connected != currentNode.Connected {
return true
Expand Down
Loading

0 comments on commit cb4b99f

Please sign in to comment.