Skip to content

Commit

Permalink
NodeID type and simplify policy compile call
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Feb 29, 2024
1 parent ea7f809 commit 3985af2
Show file tree
Hide file tree
Showing 14 changed files with 246 additions and 256 deletions.
6 changes: 1 addition & 5 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ type Headscale struct {

registrationCache *cache.Cache

shutdownChan chan struct{}
pollNetMapStreamWG sync.WaitGroup
}

Expand Down Expand Up @@ -504,7 +503,7 @@ func (h *Headscale) Serve() error {

// Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
h.mapper = mapper.NewMapper(h.DERPMap, h.cfg)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier.ConnectedMap())

if h.cfg.DERP.ServerEnabled {
// When embedded DERP is enabled we always need a STUN server
Expand Down Expand Up @@ -745,7 +744,6 @@ func (h *Headscale) Serve() error {
}

// Handle common process-killing signals so we can gracefully shut down:
h.shutdownChan = make(chan struct{})
sigc := make(chan os.Signal, 1)
signal.Notify(sigc,
syscall.SIGHUP,
Expand Down Expand Up @@ -788,8 +786,6 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()).
Msg("Received signal to stop, shutting down gracefully")

close(h.shutdownChan)

h.pollNetMapStreamWG.Wait()

// Gracefully shut down servers
Expand Down
29 changes: 12 additions & 17 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,22 @@ var (
)
)

func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) {
func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListPeers(rx, node)
return ListPeers(rx, nodeID)
})
}

// ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) {
log.Trace().
Caller().
Str("node", node.Hostname).
Msg("Finding direct peers")

func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Where("node_key <> ?",
node.NodeKey.String()).Find(&nodes).Error; err != nil {
Where("id <> ?",
nodeID).Find(&nodes).Error; err != nil {
return types.Nodes{}, err
}

Expand Down Expand Up @@ -119,14 +114,14 @@ func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
return nil, ErrNodeNotFound
}

func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) {
func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByID(rx, id)
})
}

// GetNodeByID finds a Node by ID and returns the Node struct.
func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) {
func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
mach := types.Node{}
if result := tx.
Preload("AuthKey").
Expand Down Expand Up @@ -197,7 +192,7 @@ func GetNodeByAnyKey(
}

func (hsdb *HSDatabase) SetTags(
nodeID uint64,
nodeID types.NodeID,
tags []string,
) error {
return hsdb.Write(func(tx *gorm.DB) error {
Expand All @@ -208,7 +203,7 @@ func (hsdb *HSDatabase) SetTags(
// SetTags takes a Node struct pointer and update the forced tags.
func SetTags(
tx *gorm.DB,
nodeID uint64,
nodeID types.NodeID,
tags []string,
) error {
if len(tags) == 0 {
Expand Down Expand Up @@ -256,15 +251,15 @@ func RenameNode(tx *gorm.DB,
return nil
}

func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error {
func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry time.Time) error {
return hsdb.Write(func(tx *gorm.DB) error {
return NodeSetExpiry(tx, nodeID, expiry)
})
}

// NodeSetExpiry takes a Node struct and a new expiry time.
func NodeSetExpiry(tx *gorm.DB,
nodeID uint64, expiry time.Time,
nodeID types.NodeID, expiry time.Time,
) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
}
Expand Down Expand Up @@ -296,7 +291,7 @@ func DeleteNode(tx *gorm.DB,

// UpdateLastSeen sets a node's last seen field indicating that we
// have recently communicating with this node.
func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error {
func UpdateLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
}

Expand Down
14 changes: 7 additions & 7 deletions hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (s *Suite) TestListPeers(c *check.C) {
machineKey := key.NewMachine()

node := types.Node{
ID: uint64(index),
ID: types.NodeID(index),
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode" + strconv.Itoa(index),
Expand All @@ -156,7 +156,7 @@ func (s *Suite) TestListPeers(c *check.C) {
node0ByID, err := db.GetNodeByID(0)
c.Assert(err, check.IsNil)

peersOfNode0, err := db.ListPeers(node0ByID)
peersOfNode0, err := db.ListPeers(node0ByID.ID)
c.Assert(err, check.IsNil)

c.Assert(len(peersOfNode0), check.Equals, 9)
Expand Down Expand Up @@ -189,7 +189,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
machineKey := key.NewMachine()

node := types.Node{
ID: uint64(index),
ID: types.NodeID(index),
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
IPAddresses: types.NodeAddresses{
Expand Down Expand Up @@ -232,16 +232,16 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User)
c.Assert(err, check.IsNil)

adminPeers, err := db.ListPeers(adminNode)
adminPeers, err := db.ListPeers(adminNode.ID)
c.Assert(err, check.IsNil)

testPeers, err := db.ListPeers(testNode)
testPeers, err := db.ListPeers(testNode.ID)
c.Assert(err, check.IsNil)

adminRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, adminNode, adminPeers)
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
c.Assert(err, check.IsNil)

testRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, testNode, testPeers)
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers)
c.Assert(err, check.IsNil)

peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
Expand Down
6 changes: 3 additions & 3 deletions hscontrol/db/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
for prefix, exists := range advertisedRoutes {
if !exists {
route := types.Route{
NodeID: node.ID,
NodeID: node.ID.Uint64(),
Prefix: types.IPPrefix(prefix),
Advertised: true,
Enabled: false,
Expand Down Expand Up @@ -641,7 +641,7 @@ func EnableAutoApprovedRoutes(
if err != nil {
log.Err(err).
Str("advertisedRoute", advertisedRoute.String()).
Uint64("nodeId", node.ID).
Uint64("nodeId", node.ID.Uint64()).
Msg("Failed to resolve autoApprovers for advertised route")

return nil, err
Expand Down Expand Up @@ -687,7 +687,7 @@ func EnableAutoApprovedRoutes(
if err != nil {
log.Err(err).
Str("approvedRoute", approvedRoute.String()).
Uint64("nodeId", node.ID).
Uint64("nodeId", node.ID.Uint64()).
Msg("Failed to enable approved route")

return nil, err
Expand Down
18 changes: 9 additions & 9 deletions hscontrol/grpcv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func (api headscaleV1APIServer) GetNode(
ctx context.Context,
request *v1.GetNodeRequest,
) (*v1.GetNodeResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId())
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil {
return nil, err
}
Expand All @@ -248,12 +248,12 @@ func (api headscaleV1APIServer) SetTags(
}

node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.SetTags(tx, request.GetNodeId(), request.GetTags())
err := db.SetTags(tx, types.NodeID(request.GetNodeId()), request.GetTags())
if err != nil {
return nil, err
}

return db.GetNodeByID(tx, request.GetNodeId())
return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
})
if err != nil {
return &v1.SetTagsResponse{
Expand Down Expand Up @@ -296,7 +296,7 @@ func (api headscaleV1APIServer) DeleteNode(
ctx context.Context,
request *v1.DeleteNodeRequest,
) (*v1.DeleteNodeResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId())
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -330,11 +330,11 @@ func (api headscaleV1APIServer) ExpireNode(
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
db.NodeSetExpiry(
tx,
request.GetNodeId(),
types.NodeID(request.GetNodeId()),
now,
)

return db.GetNodeByID(tx, request.GetNodeId())
return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -380,7 +380,7 @@ func (api headscaleV1APIServer) RenameNode(
return nil, err
}

return db.GetNodeByID(tx, request.GetNodeId())
return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -463,7 +463,7 @@ func (api headscaleV1APIServer) MoveNode(
ctx context.Context,
request *v1.MoveNodeRequest,
) (*v1.MoveNodeResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId())
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -536,7 +536,7 @@ func (api headscaleV1APIServer) GetNodeRoutes(
ctx context.Context,
request *v1.GetNodeRoutesRequest,
) (*v1.GetNodeRoutesResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId())
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 3985af2

Please sign in to comment.