From 3985af23ac6724713481b24455709cddfcb8e646 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 29 Feb 2024 12:32:08 +0100 Subject: [PATCH] NodeID type and simplify policy compile call Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 6 +- hscontrol/db/node.go | 29 +++--- hscontrol/db/node_test.go | 14 +-- hscontrol/db/routes.go | 6 +- hscontrol/grpcv1.go | 18 ++-- hscontrol/mapper/mapper.go | 100 +++++++++++++-------- hscontrol/mapper/mapper_test.go | 4 +- hscontrol/mapper/tail.go | 30 +++---- hscontrol/mapper/tail_test.go | 9 +- hscontrol/policy/acls.go | 41 ++++----- hscontrol/policy/acls_test.go | 65 +++++++------- hscontrol/poll.go | 152 +++++++++++++------------------- hscontrol/types/common.go | 4 +- hscontrol/types/node.go | 24 ++++- 14 files changed, 246 insertions(+), 256 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index 5313077bb1..77d2a86506 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -98,7 +98,6 @@ type Headscale struct { registrationCache *cache.Cache - shutdownChan chan struct{} pollNetMapStreamWG sync.WaitGroup } @@ -504,7 +503,7 @@ func (h *Headscale) Serve() error { // Fetch an initial DERP Map before we start serving h.DERPMap = derp.GetDERPMap(h.cfg.DERP) - h.mapper = mapper.NewMapper(h.DERPMap, h.cfg) + h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier.ConnectedMap()) if h.cfg.DERP.ServerEnabled { // When embedded DERP is enabled we always need a STUN server @@ -745,7 +744,6 @@ func (h *Headscale) Serve() error { } // Handle common process-killing signals so we can gracefully shut down: - h.shutdownChan = make(chan struct{}) sigc := make(chan os.Signal, 1) signal.Notify(sigc, syscall.SIGHUP, @@ -788,8 +786,6 @@ func (h *Headscale) Serve() error { Str("signal", sig.String()). Msg("Received signal to stop, shutting down gracefully") - close(h.shutdownChan) - h.pollNetMapStreamWG.Wait() // Gracefully shut down servers diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index d02c2d3944..2dd039990b 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -34,27 +34,22 @@ var ( ) ) -func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) { +func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID) (types.Nodes, error) { return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { - return ListPeers(rx, node) + return ListPeers(rx, nodeID) }) } // ListPeers returns all peers of node, regardless of any Policy or if the node is expired. -func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) { - log.Trace(). - Caller(). - Str("node", node.Hostname). - Msg("Finding direct peers") - +func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) { nodes := types.Nodes{} if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). Preload("Routes"). - Where("node_key <> ?", - node.NodeKey.String()).Find(&nodes).Error; err != nil { + Where("id <> ?", + nodeID).Find(&nodes).Error; err != nil { return types.Nodes{}, err } @@ -119,14 +114,14 @@ func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) { return nil, ErrNodeNotFound } -func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { +func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { return GetNodeByID(rx, id) }) } // GetNodeByID finds a Node by ID and returns the Node struct. 
-func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) { +func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) { mach := types.Node{} if result := tx. Preload("AuthKey"). @@ -197,7 +192,7 @@ func GetNodeByAnyKey( } func (hsdb *HSDatabase) SetTags( - nodeID uint64, + nodeID types.NodeID, tags []string, ) error { return hsdb.Write(func(tx *gorm.DB) error { @@ -208,7 +203,7 @@ func (hsdb *HSDatabase) SetTags( // SetTags takes a Node struct pointer and update the forced tags. func SetTags( tx *gorm.DB, - nodeID uint64, + nodeID types.NodeID, tags []string, ) error { if len(tags) == 0 { @@ -256,7 +251,7 @@ func RenameNode(tx *gorm.DB, return nil } -func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error { +func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry time.Time) error { return hsdb.Write(func(tx *gorm.DB) error { return NodeSetExpiry(tx, nodeID, expiry) }) @@ -264,7 +259,7 @@ func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error { // NodeSetExpiry takes a Node struct and a new expiry time. func NodeSetExpiry(tx *gorm.DB, - nodeID uint64, expiry time.Time, + nodeID types.NodeID, expiry time.Time, ) error { return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error } @@ -296,7 +291,7 @@ func DeleteNode(tx *gorm.DB, // UpdateLastSeen sets a node's last seen field indicating that we // have recently communicating with this node. -func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error { +func UpdateLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 5e8eb29435..3622cc6067 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -142,7 +142,7 @@ func (s *Suite) TestListPeers(c *check.C) { machineKey := key.NewMachine() node := types.Node{ - ID: uint64(index), + ID: types.NodeID(index), MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode" + strconv.Itoa(index), @@ -156,7 +156,7 @@ func (s *Suite) TestListPeers(c *check.C) { node0ByID, err := db.GetNodeByID(0) c.Assert(err, check.IsNil) - peersOfNode0, err := db.ListPeers(node0ByID) + peersOfNode0, err := db.ListPeers(node0ByID.ID) c.Assert(err, check.IsNil) c.Assert(len(peersOfNode0), check.Equals, 9) @@ -189,7 +189,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { machineKey := key.NewMachine() node := types.Node{ - ID: uint64(index), + ID: types.NodeID(index), MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), IPAddresses: types.NodeAddresses{ @@ -232,16 +232,16 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User) c.Assert(err, check.IsNil) - adminPeers, err := db.ListPeers(adminNode) + adminPeers, err := db.ListPeers(adminNode.ID) c.Assert(err, check.IsNil) - testPeers, err := db.ListPeers(testNode) + testPeers, err := db.ListPeers(testNode.ID) c.Assert(err, check.IsNil) - adminRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, adminNode, adminPeers) + adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) c.Assert(err, check.IsNil) - testRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, testNode, testPeers) + testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) c.Assert(err, check.IsNil) peersOfAdminNode 
:= policy.FilterNodesByACL(adminNode, adminPeers, adminRules) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 1ee144a719..377473b40f 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -400,7 +400,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) { for prefix, exists := range advertisedRoutes { if !exists { route := types.Route{ - NodeID: node.ID, + NodeID: node.ID.Uint64(), Prefix: types.IPPrefix(prefix), Advertised: true, Enabled: false, @@ -641,7 +641,7 @@ func EnableAutoApprovedRoutes( if err != nil { log.Err(err). Str("advertisedRoute", advertisedRoute.String()). - Uint64("nodeId", node.ID). + Uint64("nodeId", node.ID.Uint64()). Msg("Failed to resolve autoApprovers for advertised route") return nil, err @@ -687,7 +687,7 @@ func EnableAutoApprovedRoutes( if err != nil { log.Err(err). Str("approvedRoute", approvedRoute.String()). - Uint64("nodeId", node.ID). + Uint64("nodeId", node.ID.Uint64()). Msg("Failed to enable approved route") return nil, err diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 379502c725..755f83eec6 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -222,7 +222,7 @@ func (api headscaleV1APIServer) GetNode( ctx context.Context, request *v1.GetNodeRequest, ) (*v1.GetNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId())) if err != nil { return nil, err } @@ -248,12 +248,12 @@ func (api headscaleV1APIServer) SetTags( } node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { - err := db.SetTags(tx, request.GetNodeId(), request.GetTags()) + err := db.SetTags(tx, types.NodeID(request.GetNodeId()), request.GetTags()) if err != nil { return nil, err } - return db.GetNodeByID(tx, request.GetNodeId()) + return db.GetNodeByID(tx, types.NodeID(request.GetNodeId())) }) if err != nil { return &v1.SetTagsResponse{ @@ -296,7 +296,7 @@ func (api headscaleV1APIServer) DeleteNode( ctx context.Context, request *v1.DeleteNodeRequest, ) (*v1.DeleteNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId())) if err != nil { return nil, err } @@ -330,11 +330,11 @@ func (api headscaleV1APIServer) ExpireNode( node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { db.NodeSetExpiry( tx, - request.GetNodeId(), + types.NodeID(request.GetNodeId()), now, ) - return db.GetNodeByID(tx, request.GetNodeId()) + return db.GetNodeByID(tx, types.NodeID(request.GetNodeId())) }) if err != nil { return nil, err @@ -380,7 +380,7 @@ func (api headscaleV1APIServer) RenameNode( return nil, err } - return db.GetNodeByID(tx, request.GetNodeId()) + return db.GetNodeByID(tx, types.NodeID(request.GetNodeId())) }) if err != nil { return nil, err @@ -463,7 +463,7 @@ func (api headscaleV1APIServer) MoveNode( ctx context.Context, request *v1.MoveNodeRequest, ) (*v1.MoveNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId())) if err != nil { return nil, err } @@ -536,7 +536,7 @@ func (api headscaleV1APIServer) GetNodeRoutes( ctx context.Context, request *v1.GetNodeRoutesRequest, ) (*v1.GetNodeRoutesResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId())) if err != nil { return nil, err } diff --git a/hscontrol/mapper/mapper.go 
b/hscontrol/mapper/mapper.go index e879b8517a..cf6320850a 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -15,6 +15,7 @@ import ( "time" mapset "github.com/deckarep/golang-set/v2" + "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -24,6 +25,7 @@ import ( "tailscale.com/smallzstd" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + "tailscale.com/types/key" ) const ( @@ -46,14 +48,18 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_ // - Create a "minifier" that removes info not needed for the node // - some sort of batching, wait for 5 or 60 seconds before sending +type nodeReq struct { + nodeID types.NodeID + done chan<- struct{} +} + type Mapper struct { // Configuration // TODO(kradalby): figure out if this is the format we want this in - derpMap *tailcfg.DERPMap - baseDomain string - dnsCfg *tailcfg.DNSConfig - logtail bool - randomClientPort bool + db *db.HSDatabase + cfg *types.Config + derpMap *tailcfg.DERPMap + isMostlyConnected map[key.MachinePublic]bool uid string created time.Time @@ -66,17 +72,18 @@ type patch struct { } func NewMapper( - derpMap *tailcfg.DERPMap, + db *db.HSDatabase, cfg *types.Config, + derpMap *tailcfg.DERPMap, + isMostlyConnected map[key.MachinePublic]bool, ) *Mapper { uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) return &Mapper{ - derpMap: derpMap, - baseDomain: cfg.BaseDomain, - dnsCfg: cfg.DNSConfig, - logtail: cfg.LogTail.Enabled, - randomClientPort: cfg.RandomizeClientPort, + db: db, + cfg: cfg, + derpMap: derpMap, + isMostlyConnected: isMostlyConnected, uid: uid, created: time.Now(), @@ -201,9 +208,7 @@ func (m *Mapper) fullMapResponse( capVer, peers, peers, - m.baseDomain, - m.dnsCfg, - m.randomClientPort, + m.cfg, ) if err != nil { return nil, err @@ -216,9 +221,13 @@ func (m *Mapper) fullMapResponse( func (m *Mapper) FullMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, - peers types.Nodes, pol *policy.ACLPolicy, ) ([]byte, error) { + peers, err := m.ListPeers(node.ID) + if err != nil { + return nil, err + } + resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version) if err != nil { return nil, err @@ -270,23 +279,25 @@ func (m *Mapper) DERPMapResponse( func (m *Mapper) PeerChangedResponse( mapRequest tailcfg.MapRequest, node *types.Node, - peers types.Nodes, changed types.Nodes, pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { resp := m.baseMapResponse() - err := appendPeerChanges( + peers, err := m.ListPeers(node.ID) + if err != nil { + return nil, err + } + + err = appendPeerChanges( &resp, pol, node, mapRequest.Version, peers, changed, - m.baseDomain, - m.dnsCfg, - m.randomClientPort, + m.cfg, ) if err != nil { return nil, err @@ -457,7 +468,7 @@ func (m *Mapper) baseWithConfigMapResponse( ) (*tailcfg.MapResponse, error) { resp := m.baseMapResponse() - tailnode, err := tailNode(node, capVer, pol, m.dnsCfg, m.baseDomain, m.randomClientPort) + tailnode, err := tailNode(node, capVer, pol, m.cfg) if err != nil { return nil, err } @@ -465,7 +476,7 @@ func (m *Mapper) baseWithConfigMapResponse( resp.DERPMap = m.derpMap - resp.Domain = m.baseDomain + resp.Domain = m.cfg.BaseDomain // Do not instruct clients to collect services we do not // support or do anything with them @@ -474,12 +485,26 @@ func (m *Mapper) baseWithConfigMapResponse( resp.KeepAlive = false resp.Debug = &tailcfg.Debug{ - 
DisableLogTail: !m.logtail, + DisableLogTail: !m.cfg.LogTail.Enabled, } return &resp, nil } +func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) { + peers, err := m.db.ListPeers(nodeID) + if err != nil { + return nil, err + } + + for _, peer := range peers { + online := m.isMostlyConnected[peer.MachineKey] + peer.IsOnline = &online + } + + return peers, nil +} + func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes { ret := make(types.Nodes, 0) @@ -500,37 +525,36 @@ func appendPeerChanges( capVer tailcfg.CapabilityVersion, peers types.Nodes, changed types.Nodes, - baseDomain string, - dnsCfg *tailcfg.DNSConfig, - randomClientPort bool, + cfg *types.Config, ) error { fullChange := len(peers) == len(changed) - rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( - pol, - node, - peers, - ) + packetFilter, err := pol.CompileFilterRules(append(peers, node)) + if err != nil { + return err + } + + sshPolicy, err := pol.CompileSSHPolicy(node, peers) if err != nil { return err } // If there are filter rules present, see if there are any nodes that cannot // access eachother at all and remove them from the peers. - if len(rules) > 0 { - changed = policy.FilterNodesByACL(node, changed, rules) + if len(packetFilter) > 0 { + changed = policy.FilterNodesByACL(node, changed, packetFilter) } - profiles := generateUserProfiles(node, changed, baseDomain) + profiles := generateUserProfiles(node, changed, cfg.BaseDomain) dnsConfig := generateDNSConfig( - dnsCfg, - baseDomain, + cfg.DNSConfig, + cfg.BaseDomain, node, peers, ) - tailPeers, err := tailNodes(changed, capVer, pol, dnsCfg, baseDomain, randomClientPort) + tailPeers, err := tailNodes(changed, capVer, pol, cfg) if err != nil { return err } @@ -546,7 +570,7 @@ func appendPeerChanges( resp.PeersChanged = tailPeers } resp.DNSConfig = dnsConfig - resp.PacketFilter = policy.ReduceFilterRules(node, rules) + resp.PacketFilter = policy.ReduceFilterRules(node, packetFilter) resp.UserProfiles = profiles resp.SSHPolicy = sshPolicy diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 51dbd2d703..3f4d6892e1 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -466,8 +466,10 @@ func Test_fullMapResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mappy := NewMapper( - tt.derpMap, + nil, tt.cfg, + tt.derpMap, + nil, ) got, err := mappy.fullMapResponse( diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index c10da4debd..97d12e862c 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -3,12 +3,10 @@ package mapper import ( "fmt" "net/netip" - "strconv" "time" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" "github.com/samber/lo" "tailscale.com/tailcfg" ) @@ -17,9 +15,7 @@ func tailNodes( nodes types.Nodes, capVer tailcfg.CapabilityVersion, pol *policy.ACLPolicy, - dnsConfig *tailcfg.DNSConfig, - baseDomain string, - randomClientPort bool, + cfg *types.Config, ) ([]*tailcfg.Node, error) { tNodes := make([]*tailcfg.Node, len(nodes)) @@ -28,9 +24,7 @@ func tailNodes( node, capVer, pol, - dnsConfig, - baseDomain, - randomClientPort, + cfg, ) if err != nil { return nil, err @@ -48,9 +42,7 @@ func tailNode( node *types.Node, capVer tailcfg.CapabilityVersion, pol *policy.ACLPolicy, - dnsConfig *tailcfg.DNSConfig, - baseDomain string, - randomClientPort bool, + cfg *types.Config, ) (*tailcfg.Node, error) { addrs := 
node.IPAddresses.Prefixes() @@ -85,7 +77,7 @@ func tailNode( keyExpiry = time.Time{} } - hostname, err := node.GetFQDN(dnsConfig, baseDomain) + hostname, err := node.GetFQDN(cfg.DNSConfig, cfg.BaseDomain) if err != nil { return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) } @@ -94,12 +86,10 @@ func tailNode( tags = lo.Uniq(append(tags, node.ForcedTags...)) tNode := tailcfg.Node{ - ID: tailcfg.NodeID(node.ID), // this is the actual ID - StableID: tailcfg.StableNodeID( - strconv.FormatUint(node.ID, util.Base10), - ), // in headscale, unlike tailcontrol server, IDs are permanent - Name: hostname, - Cap: capVer, + ID: tailcfg.NodeID(node.ID), // this is the actual ID + StableID: node.ID.StableID(), + Name: hostname, + Cap: capVer, User: tailcfg.UserID(node.UserID), @@ -133,7 +123,7 @@ func tailNode( tailcfg.CapabilitySSH: []tailcfg.RawMessage{}, } - if randomClientPort { + if cfg.RandomizeClientPort { tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{} } } else { @@ -143,7 +133,7 @@ func tailNode( tailcfg.CapabilitySSH, } - if randomClientPort { + if cfg.RandomizeClientPort { tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrRandomizeClientPort) } } diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index f6e370c4a5..e79d9dc567 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -182,13 +182,16 @@ func TestTailNode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + cfg := &types.Config{ + BaseDomain: tt.baseDomain, + DNSConfig: tt.dnsConfig, + RandomizeClientPort: false, + } got, err := tailNode( tt.node, 0, tt.pol, - tt.dnsConfig, - tt.baseDomain, - false, + cfg, ) if (err != nil) != tt.wantErr { diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index 2ccc56b4a4..b409578188 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -114,7 +114,7 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) { return &policy, nil } -func GenerateFilterAndSSHRules( +func GenerateFilterAndSSHRulesForTests( policy *ACLPolicy, node *types.Node, peers types.Nodes, @@ -124,40 +124,31 @@ func GenerateFilterAndSSHRules( return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil } - rules, err := policy.generateFilterRules(node, peers) + rules, err := policy.CompileFilterRules(append(peers, node)) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") - var sshPolicy *tailcfg.SSHPolicy - sshRules, err := policy.generateSSHRules(node, peers) + sshPolicy, err := policy.CompileSSHPolicy(node, peers) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } - log.Trace(). - Interface("SSH", sshRules). - Str("node", node.GivenName). - Msg("SSH rules") - - if sshPolicy == nil { - sshPolicy = &tailcfg.SSHPolicy{} - } - sshPolicy.Rules = sshRules - return rules, sshPolicy, nil } -// generateFilterRules takes a set of nodes and an ACLPolicy and generates a +// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. 
-func (pol *ACLPolicy) generateFilterRules( - node *types.Node, - peers types.Nodes, +func (pol *ACLPolicy) CompileFilterRules( + nodes types.Nodes, ) ([]tailcfg.FilterRule, error) { + if pol == nil { + return tailcfg.FilterAllowAll, nil + } + rules := []tailcfg.FilterRule{} - nodes := append(peers, node) for index, acl := range pol.ACLs { if acl.Action != "accept" { @@ -279,10 +270,14 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F return ret } -func (pol *ACLPolicy) generateSSHRules( +func (pol *ACLPolicy) CompileSSHPolicy( node *types.Node, peers types.Nodes, -) ([]*tailcfg.SSHRule, error) { +) (*tailcfg.SSHPolicy, error) { + if pol == nil { + return nil, nil + } + rules := []*tailcfg.SSHRule{} acceptAction := tailcfg.SSHAction{ @@ -393,7 +388,9 @@ func (pol *ACLPolicy) generateSSHRules( }) } - return rules, nil + return &tailcfg.SSHPolicy{ + Rules: rules, + }, nil } func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index ff18dd0573..706588cadd 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -385,11 +385,12 @@ acls: return } - rules, err := pol.generateFilterRules(&types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.100.100.100"), + rules, err := pol.CompileFilterRules(types.Nodes{ + &types.Node{ + IPAddresses: types.NodeAddresses{ + netip.MustParseAddr("100.100.100.100"), + }, }, - }, types.Nodes{ &types.Node{ IPAddresses: types.NodeAddresses{ netip.MustParseAddr("200.200.200.200"), @@ -546,7 +547,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(err, check.IsNil) - rules, err := pol.generateFilterRules(&types.Node{}, types.Nodes{}) + rules, err := pol.CompileFilterRules(types.Nodes{}) c.Assert(err, check.NotNil) c.Assert(rules, check.IsNil) } @@ -562,7 +563,7 @@ func (s *Suite) TestInvalidAction(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) } @@ -581,7 +582,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) } @@ -597,7 +598,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { }, } - _, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) } @@ -1724,8 +1725,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { pol ACLPolicy } type args struct { - node *types.Node - peers types.Nodes + nodes types.Nodes } tests := []struct { name string @@ -1755,13 +1755,14 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { }, }, args: args{ - node: &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), + nodes: types.Nodes{ + &types.Node{ + IPAddresses: types.NodeAddresses{ + netip.MustParseAddr("100.64.0.1"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), + }, }, }, - peers: types.Nodes{}, }, want: []tailcfg.FilterRule{ { @@ -1800,14 +1801,14 @@ 
func TestACLPolicy_generateFilterRules(t *testing.T) { }, }, args: args{ - node: &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), + nodes: types.Nodes{ + &types.Node{ + IPAddresses: types.NodeAddresses{ + netip.MustParseAddr("100.64.0.1"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), + }, + User: types.User{Name: "mickael"}, }, - User: types.User{Name: "mickael"}, - }, - peers: types.Nodes{ &types.Node{ IPAddresses: types.NodeAddresses{ netip.MustParseAddr("100.64.0.2"), @@ -1846,9 +1847,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.field.pol.generateFilterRules( - tt.args.node, - tt.args.peers, + got, err := tt.field.pol.CompileFilterRules( + tt.args.nodes, ) if (err != nil) != tt.wantErr { t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr) @@ -1980,9 +1980,8 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rules, _ := tt.pol.generateFilterRules( - tt.node, - tt.peers, + rules, _ := tt.pol.CompileFilterRules( + append(tt.peers, tt.node), ) got := ReduceFilterRules(tt.node, rules) @@ -3048,7 +3047,7 @@ func TestSSHRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.pol.generateSSHRules(&tt.node, tt.peers) + got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) assert.NoError(t, err) if diff := cmp.Diff(tt.want, got); diff != "" { @@ -3155,7 +3154,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3206,7 +3205,7 @@ func TestInvalidTagValidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3265,7 +3264,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { // c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3335,7 +3334,7 @@ func TestValidTagInvalidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{nodes2}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) assert.NoError(t, err) want := []tailcfg.FilterRule{ diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 7a04123e18..0120513f40 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -6,6 +6,7 @@ import ( "math/rand/v2" "net/http" "strings" + "sync" "time" "github.com/juanfont/headscale/hscontrol/db" @@ -26,7 +27,10 @@ type contextKey string const nodeNameContextKey = contextKey("nodeName") -type UpdateNode func() +type sessionManager struct { + mu sync.RWMutex + sess map[types.NodeID]*mapSession +} type mapSession struct { h *Headscale @@ -35,6 +39,9 @@ type mapSession struct { capVer tailcfg.CapabilityVersion mapper *mapper.Mapper + ch chan types.StateUpdate + cancelCh chan struct{} + node *types.Node w http.ResponseWriter @@ -51,6 +58,17 @@ 
func (h *Headscale) newMapSession( node *types.Node, ) *mapSession { warnf, tracef, infof, errf := logPollFunc(req, node) + + // Use a buffered channel in case a node is not fully ready + // to receive a message to make sure we dont block the entire + // notifier. + // 12 is arbitrarily chosen. + chanSize := 3 + if size, ok := envknob.LookupInt("HEADSCALE_TUNING_POLL_QUEUE_SIZE"); ok { + chanSize = size + } + updateChan := make(chan types.StateUpdate, chanSize) + return &mapSession{ h: h, ctx: ctx, @@ -60,6 +78,9 @@ func (h *Headscale) newMapSession( capVer: req.Version, mapper: h.mapper, + ch: updateChan, + cancelCh: make(chan struct{}), + // Loggers warnf: warnf, infof: infof, @@ -68,6 +89,10 @@ func (h *Headscale) newMapSession( } } +func (m *mapSession) close() { + m.cancelCh <- struct{}{} +} + func (m *mapSession) isStreaming() bool { return m.req.Stream && !m.req.ReadOnly } @@ -142,21 +167,6 @@ func (m *mapSession) serve() { m.h.pollNetMapStreamWG.Add(1) defer m.h.pollNetMapStreamWG.Done() - // Use a buffered channel in case a node is not fully ready - // to receive a message to make sure we dont block the entire - // notifier. - // 12 is arbitrarily chosen. - chanSize := 3 - if size, ok := envknob.LookupInt("HEADSCALE_TUNING_POLL_QUEUE_SIZE"); ok { - chanSize = size - } - updateChan := make(chan types.StateUpdate, chanSize) - defer closeChanWithLog(updateChan, m.node.Hostname, "updateChan") - - // Register the node's update channel - m.h.nodeNotifier.AddNode(m.node.MachineKey, updateChan) - defer m.h.nodeNotifier.RemoveNode(m.node.MachineKey) - // update ACLRules with peer informations (to update server tags if necessary) if m.h.ACLPolicy != nil { // update routes with peer information @@ -171,21 +181,7 @@ func (m *mapSession) serve() { m.tracef("Sending initial map") - peers, err := m.h.db.ListPeers(m.node) - if err != nil { - m.errf(err, "Failed to list peers when opening poller") - http.Error(m.w, "", http.StatusInternalServerError) - - return - } - - isConnected := m.h.nodeNotifier.ConnectedMap() - for _, peer := range peers { - online := isConnected[peer.MachineKey] - peer.IsOnline = &online - } - - mapResp, err := m.mapper.FullMapResponse(m.req, m.node, peers, m.h.ACLPolicy) + mapResp, err := m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy) if err != nil { m.errf(err, "Failed to create MapResponse") http.Error(m.w, "", http.StatusInternalServerError) @@ -232,35 +228,7 @@ func (m *mapSession) serve() { for { m.tracef("Waiting for update on stream channel") select { - case <-keepAliveTicker.C: - data, err := m.mapper.KeepAliveResponse(m.req, m.node) - if err != nil { - m.errf(err, "Error generating the keep alive msg") - - return - } - _, err = m.w.Write(data) - if err != nil { - m.errf(err, "Cannot write keep alive message") - - return - } - if flusher, ok := m.w.(http.Flusher); ok { - flusher.Flush() - } else { - log.Error().Msg("Failed to create http flusher") - - return - } - - // This goroutine is not ideal, but we have a potential issue here - // where it blocks too long and that holds up updates. - // One alternative is to split these different channels into - // goroutines, but then you might have a problem without a lock - // if a keepalive is written at the same time as an update. 
- go m.h.updateNodeOnlineStatus(true, m.node) - - case update := <-updateChan: + case update := <-m.ch: m.tracef("Received update") var data []byte var err error @@ -275,43 +243,16 @@ func (m *mapSession) serve() { return } - peers, err := m.h.db.ListPeers(m.node) - if err != nil { - m.errf(err, "Failed to list peers when opening poller") - http.Error(m.w, "", http.StatusInternalServerError) - - return - } - - isConnected := m.h.nodeNotifier.ConnectedMap() - for _, peer := range peers { - online := isConnected[peer.MachineKey] - peer.IsOnline = &online - } - startMapResp := time.Now() switch update.Type { case types.StateFullUpdate: m.tracef("Sending Full MapResponse") - data, err = m.mapper.FullMapResponse(m.req, m.node, peers, m.h.ACLPolicy) + data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy) case types.StatePeerChanged: m.tracef(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message)) - isConnectedMap := m.h.nodeNotifier.ConnectedMap() - for _, node := range update.ChangeNodes { - // If a node is not reported to be online, it might be - // because the value is outdated, check with the notifier. - // However, if it is set to Online, and not in the notifier, - // this might be because it has announced itself, but not - // reached the stage to actually create the notifier channel. - if node.IsOnline != nil && !*node.IsOnline { - isOnline := isConnectedMap[node.MachineKey] - node.IsOnline = &isOnline - } - } - - data, err = m.mapper.PeerChangedResponse(m.req, m.node, peers, update.ChangeNodes, m.h.ACLPolicy, update.Message) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, update.ChangeNodes, m.h.ACLPolicy, update.Message) case types.StatePeerChangedPatch: m.tracef("Sending PeerChangedPatch MapResponse") data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy) @@ -364,6 +305,34 @@ func (m *mapSession) serve() { m.infof("update sent") } + case <-keepAliveTicker.C: + data, err := m.mapper.KeepAliveResponse(m.req, m.node) + if err != nil { + m.errf(err, "Error generating the keep alive msg") + + return + } + _, err = m.w.Write(data) + if err != nil { + m.errf(err, "Cannot write keep alive message") + + return + } + if flusher, ok := m.w.(http.Flusher); ok { + flusher.Flush() + } else { + log.Error().Msg("Failed to create http flusher") + + return + } + + // This goroutine is not ideal, but we have a potential issue here + // where it blocks too long and that holds up updates. + // One alternative is to split these different channels into + // goroutines, but then you might have a problem without a lock + // if a keepalive is written at the same time as an update. + // go m.h.updateNodeOnlineStatus(true, m.node) + case <-ctx.Done(): m.tracef("The client has closed the connection") @@ -375,11 +344,10 @@ func (m *mapSession) serve() { // The connection has been closed, so we can stop polling. 
return - case <-m.h.shutdownChan: - m.tracef("The long-poll handler is shutting down") - + case <-m.cancelCh: return } + } } diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index ceeceea004..eb91b73d7a 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -185,12 +185,12 @@ func (su *StateUpdate) Empty() bool { return false } -func StateUpdateExpire(nodeID uint64, expiry time.Time) StateUpdate { +func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { return StateUpdate{ Type: StatePeerChangedPatch, ChangePatches: []*tailcfg.PeerChange{ { - NodeID: tailcfg.NodeID(nodeID), + NodeID: nodeID.NodeID(), KeyExpiry: &expiry, }, }, diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 69004bfdd2..8389f421f1 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -7,11 +7,13 @@ import ( "fmt" "net/netip" "sort" + "strconv" "strings" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "go4.org/netipx" "google.golang.org/protobuf/types/known/timestamppb" @@ -27,9 +29,23 @@ var ( ErrNodeUserHasNoName = errors.New("node user has no name") ) +type NodeID uint64 + +func (id NodeID) StableID() tailcfg.StableNodeID { + return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10)) +} + +func (id NodeID) NodeID() tailcfg.NodeID { + return tailcfg.NodeID(id) +} + +func (id NodeID) Uint64() uint64 { + return uint64(id) +} + // Node is a Headscale client. type Node struct { - ID uint64 `gorm:"primary_key"` + ID NodeID `gorm:"primary_key"` // MachineKeyDatabaseField is the string representation of MachineKey // it is _only_ used for reading and writing the key to the @@ -319,7 +335,7 @@ func (node *Node) AfterFind(tx *gorm.DB) error { func (node *Node) Proto() *v1.Node { nodeProto := &v1.Node{ - Id: node.ID, + Id: uint64(node.ID), MachineKey: node.MachineKey.String(), NodeKey: node.NodeKey.String(), @@ -486,8 +502,8 @@ func (nodes Nodes) String() string { return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) } -func (nodes Nodes) IDMap() map[uint64]*Node { - ret := map[uint64]*Node{} +func (nodes Nodes) IDMap() map[NodeID]*Node { + ret := map[NodeID]*Node{} for _, node := range nodes { ret[node.ID] = node
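
To make the shape of the refactor concrete, a minimal, illustrative sketch (not part of the patch) of how the new types.NodeID helpers and the single-entry policy compile calls fit together. It assumes only the signatures introduced in the hunks above; the main wrapper, IDs, and hostnames are made up for the example.

    package main

    import (
        "fmt"

        "github.com/juanfont/headscale/hscontrol/policy"
        "github.com/juanfont/headscale/hscontrol/types"
    )

    func main() {
        // types.NodeID wraps uint64 and converts explicitly to the tailcfg
        // identifiers, as added in hscontrol/types/node.go above.
        id := types.NodeID(1)
        fmt.Println(id.Uint64(), id.NodeID(), id.StableID())

        node := &types.Node{ID: id, Hostname: "node-a"} // illustrative values
        peers := types.Nodes{&types.Node{ID: types.NodeID(2), Hostname: "node-b"}}

        // The simplified compile call builds the packet filter once for the
        // whole node set instead of a per-node (node, peers) pair; a nil
        // policy compiles to tailcfg.FilterAllowAll.
        var pol *policy.ACLPolicy
        rules, err := pol.CompileFilterRules(append(peers, node))
        if err != nil {
            panic(err)
        }

        // Per-node reduction of the filter still happens when the map
        // response is assembled (see appendPeerChanges above).
        fmt.Println(policy.ReduceFilterRules(node, rules))
    }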