Skip to content

Commit

Permalink
enforce unique names for ext client names (#2476)
Browse files Browse the repository at this point in the history
* enforce unique names for ext client names

* only check for unique id on creation

* check for unique id if changed
  • Loading branch information
mattkasun committed Aug 14, 2023
1 parent 495f721 commit 2ad4653
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 68 deletions.
128 changes: 69 additions & 59 deletions controllers/ext_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/gravitl/netmaker/models/promodels"
"github.com/gravitl/netmaker/mq"
"github.com/skip2/go-qrcode"
"golang.org/x/exp/slog"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

Expand Down Expand Up @@ -308,31 +309,28 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")

var params = mux.Vars(r)
networkName := params["network"]
nodeid := params["nodeid"]

ingressExists := checkIngressExists(nodeid)
if !ingressExists {
err := errors.New("ingress does not exist")
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create extclient on network [%s]: %v", networkName, err))
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
slog.Error("failed to create extclient", "user", r.Header.Get("user"), "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}

var extclient models.ExtClient
var customExtClient models.CustomExtClient

if err := json.NewDecoder(r.Body).Decode(&customExtClient); err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
if err := validateExtClient(&extclient, &customExtClient); err != nil {
if err := validateCustomExtClient(&customExtClient, true); err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
extclient := logic.UpdateExtClient(&models.ExtClient{}, &customExtClient)

extclient.Network = networkName
extclient.IngressGatewayID = nodeid
node, err := logic.GetNodeByID(nodeid)
if err != nil {
Expand All @@ -341,6 +339,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
extclient.Network = node.Network
host, err := logic.GetHost(node.HostID.String())
if err != nil {
logger.Log(0, r.Header.Get("user"),
Expand All @@ -351,21 +350,19 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
listenPort := logic.GetPeerListenPort(host)
extclient.IngressGatewayEndpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), listenPort)
extclient.Enabled = true
parentNetwork, err := logic.GetNetwork(networkName)
parentNetwork, err := logic.GetNetwork(node.Network)
if err == nil { // check if parent network default ACL is enabled (yes) or not (no)
extclient.Enabled = parentNetwork.DefaultACL == "yes"
}

if err := logic.SetClientDefaultACLs(&extclient); err != nil {
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to assign ACLs to new ext client on network [%s]: %v", networkName, err))
slog.Error("failed to set default acls for extclient", "user", r.Header.Get("user"), "network", node.Network, "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}

if err = logic.CreateExtClient(&extclient); err != nil {
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to create new ext client on network [%s]: %v", networkName, err))
slog.Error("failed to create extclient", "user", r.Header.Get("user"), "network", node.Network, "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
Expand All @@ -374,13 +371,13 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("ismaster") != "yes" {
userID := r.Header.Get("user")
if isAdmin, err = checkProClientAccess(userID, extclient.ClientID, &parentNetwork); err != nil {
logger.Log(0, userID, "attempted to create a client on network", networkName, "but they lack access")
logic.DeleteExtClient(networkName, extclient.ClientID)
slog.Error("pro client access check failed", "user", userID, "network", node.Network, "error", err)
logic.DeleteExtClient(node.Network, extclient.ClientID)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
if !isAdmin {
if err = pro.AssociateNetworkUserClient(userID, networkName, extclient.ClientID); err != nil {
if err = pro.AssociateNetworkUserClient(userID, node.Network, extclient.ClientID); err != nil {
logger.Log(0, "failed to associate client", extclient.ClientID, "to user", userID)
}
extclient.OwnerID = userID
Expand All @@ -390,7 +387,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
}
}

logger.Log(0, r.Header.Get("user"), "created new ext client on network", networkName)
slog.Info("created extclient", "user", r.Header.Get("user"), "network", node.Network, "clientid", extclient.ClientID)
w.WriteHeader(http.StatusOK)
go func() {
if err := mq.PublishPeerUpdate(); err != nil {
Expand Down Expand Up @@ -419,7 +416,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)

var update models.CustomExtClient
var oldExtClient models.ExtClient
//var oldExtClient models.ExtClient
var sendPeerUpdate bool
err := json.NewDecoder(r.Body).Decode(&update)
if err != nil {
Expand All @@ -429,50 +426,40 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
return
}
clientid := params["clientid"]
network := params["network"]
key, err := logic.GetRecordKey(clientid, network)
oldExtClient, err := logic.GetExtClientByName(clientid)
if err != nil {
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get record key for client [%s], network [%s]: %v",
clientid, network, err))
slog.Error("failed to retrieve extclient", "user", r.Header.Get("user"), "id", clientid, "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
if err := validateExtClient(&oldExtClient, &update); err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key)
if err != nil {
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to fetch ext client record key [%s] from db for client [%s], network [%s]: %v",
key, clientid, network, err))
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
if err = json.Unmarshal([]byte(data), &oldExtClient); err != nil {
logger.Log(0, "error unmarshalling extclient: ",
err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
if oldExtClient.ClientID == update.ClientID {
if err := validateCustomExtClient(&update, false); err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
} else {
if err := validateCustomExtClient(&update, true); err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
}

// == PRO ==
networkName := params["network"]
//networkName := params["network"]
var changedID = update.ClientID != oldExtClient.ClientID
if r.Header.Get("ismaster") != "yes" {
userID := r.Header.Get("user")
_, doesOwn := doesUserOwnClient(userID, params["clientid"], networkName)
_, doesOwn := doesUserOwnClient(userID, params["clientid"], oldExtClient.Network)
if !doesOwn {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "internal"))
return
}
}
if changedID && oldExtClient.OwnerID != "" {
if err := pro.DissociateNetworkUserClient(oldExtClient.OwnerID, networkName, oldExtClient.ClientID); err != nil {
if err := pro.DissociateNetworkUserClient(oldExtClient.OwnerID, oldExtClient.Network, oldExtClient.ClientID); err != nil {
logger.Log(0, "failed to dissociate client", oldExtClient.ClientID, "from user", oldExtClient.OwnerID)
}
if err := pro.AssociateNetworkUserClient(oldExtClient.OwnerID, networkName, update.ClientID); err != nil {
if err := pro.AssociateNetworkUserClient(oldExtClient.OwnerID, oldExtClient.Network, update.ClientID); err != nil {
logger.Log(0, "failed to associate client", update.ClientID, "to user", oldExtClient.OwnerID)
}
}
Expand All @@ -485,13 +472,15 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
if update.Enabled != oldExtClient.Enabled {
sendPeerUpdate = true
}
// extra var need as logic.Update changes oldExtClient
currentClient := oldExtClient
newclient, err := logic.UpdateExtClient(&oldExtClient, &update)
if err != nil {
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to update ext client [%s], network [%s]: %v",
clientid, network, err))
newclient := logic.UpdateExtClient(&oldExtClient, &update)
if err := logic.DeleteExtClient(oldExtClient.Network, oldExtClient.ClientID); err != nil {

slog.Error("failed to delete ext client", "user", r.Header.Get("user"), "id", oldExtClient.ClientID, "network", oldExtClient.Network, "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
if err := logic.SaveExtClient(&newclient); err != nil {
slog.Error("failed to save ext client", "user", r.Header.Get("user"), "id", newclient.ClientID, "network", newclient.Network, "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
Expand All @@ -507,7 +496,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(newclient)
if changedID {
go func() {
if err := mq.PublishExtClientDNSUpdate(currentClient, *newclient, networkName); err != nil {
if err := mq.PublishExtClientDNSUpdate(oldExtClient, newclient, oldExtClient.Network); err != nil {
logger.Log(1, "error pubishing dns update for extcient update", err.Error())
}
}()
Expand Down Expand Up @@ -647,18 +636,20 @@ func doesUserOwnClient(username, clientID, network string) (bool, bool) {
return false, logic.StringSliceContains(netUser.Clients, clientID)
}

// validateExtClient Validates the extclient object
func validateExtClient(extclient *models.ExtClient, customExtClient *models.CustomExtClient) error {
// validateCustomExtClient Validates the extclient object
func validateCustomExtClient(customExtClient *models.CustomExtClient, checkID bool) error {
//validate clientid
if customExtClient.ClientID != "" && !validName(customExtClient.ClientID) {
return errInvalidExtClientID
if customExtClient.ClientID != "" {
if err := isValid(customExtClient.ClientID, checkID); err != nil {
return fmt.Errorf("client validatation: %v", err)
}
}
extclient.ClientID = customExtClient.ClientID
//extclient.ClientID = customExtClient.ClientID
if len(customExtClient.PublicKey) > 0 {
if _, err := wgtypes.ParseKey(customExtClient.PublicKey); err != nil {
return errInvalidExtClientPubKey
}
extclient.PublicKey = customExtClient.PublicKey
//extclient.PublicKey = customExtClient.PublicKey
}
//validate extra ips
if len(customExtClient.ExtraAllowedIPs) > 0 {
Expand All @@ -667,14 +658,33 @@ func validateExtClient(extclient *models.ExtClient, customExtClient *models.Cust
return errInvalidExtClientExtraIP
}
}
extclient.ExtraAllowedIPs = customExtClient.ExtraAllowedIPs
//extclient.ExtraAllowedIPs = customExtClient.ExtraAllowedIPs
}
//validate DNS
if customExtClient.DNS != "" {
if ip := net.ParseIP(customExtClient.DNS); ip == nil {
return errInvalidExtClientDNS
}
extclient.DNS = customExtClient.DNS
//extclient.DNS = customExtClient.DNS
}
return nil
}

// isValid Checks if the clientid is valid
func isValid(clientid string, checkID bool) error {
if !validName(clientid) {
return errInvalidExtClientID
}
if checkID {
extclients, err := logic.GetAllExtClients()
if err != nil {
return fmt.Errorf("extclients isValid: %v", err)
}
for _, extclient := range extclients {
if clientid == extclient.ClientID {
return errDuplicateExtClientName
}
}
}
return nil
}
1 change: 1 addition & 0 deletions controllers/regex.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ var (
errInvalidExtClientID = errors.New("ext client ID must be alphanumderic and/or dashes and less that 15 chars")
errInvalidExtClientExtraIP = errors.New("ext client extra ip must be a valid cidr")
errInvalidExtClientDNS = errors.New("ext client dns must be a valid ip address")
errDuplicateExtClientName = errors.New("duplicate client name")
)

// allow only dashes and alphaneumeric for ext client and node names
Expand Down
15 changes: 15 additions & 0 deletions logic/clients.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package logic

import (
"errors"
"sort"

"github.com/gravitl/netmaker/models"
Expand Down Expand Up @@ -70,3 +71,17 @@ func SortExtClient(unsortedExtClient []models.ExtClient) {
return unsortedExtClient[i].ClientID < unsortedExtClient[j].ClientID
})
}

// GetExtClientByName - gets an ext client by name
func GetExtClientByName(ID string) (models.ExtClient, error) {
clients, err := GetAllExtClients()
if err != nil {
return models.ExtClient{}, err
}
for i := range clients {
if clients[i].ClientID == ID {
return clients[i], nil
}
}
return models.ExtClient{}, errors.New("client not found")
}
14 changes: 5 additions & 9 deletions logic/extpeers.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ func GetExtClientByPubKey(publicKey string, network string) (*models.ExtClient,
return nil, fmt.Errorf("no client found")
}

// CreateExtClient - creates an extclient
// CreateExtClient - creates and saves an extclient
func CreateExtClient(extclient *models.ExtClient) error {
// lock because we need unique IPs and having it concurrent makes parallel calls result in same "unique" IPs
// lock because we may need unique IPs and having it concurrent makes parallel calls result in same "unique" IPs
addressLock.Lock()
defer addressLock.Unlock()

Expand Down Expand Up @@ -219,12 +219,8 @@ func SaveExtClient(extclient *models.ExtClient) error {
}

// UpdateExtClient - updates an ext client with new values
func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) (*models.ExtClient, error) {
new := old
err := DeleteExtClient(old.Network, old.ClientID)
if err != nil {
return new, err
}
func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) models.ExtClient {
new := *old
new.ClientID = update.ClientID
if update.PublicKey != "" && old.PublicKey != update.PublicKey {
new.PublicKey = update.PublicKey
Expand All @@ -241,7 +237,7 @@ func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) (*mo
if update.DeniedACLs != nil && !reflect.DeepEqual(old.DeniedACLs, update.DeniedACLs) {
new.DeniedACLs = update.DeniedACLs
}
return new, CreateExtClient(new)
return new
}

// GetExtClientsByID - gets the clients of attached gateway
Expand Down

0 comments on commit 2ad4653

Please sign in to comment.