From 157930ef8a92e64d1869ed1b46001a4ac00df0f4 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 12 Jul 2024 16:45:11 +0200 Subject: [PATCH] replace ephemeral deletion logic This commit replaces the way we remove ephemeral nodes. Currently they are deleted in a loop and we look at last seen time. This time is now only set when a node disconnects and there was a bug (#2006) where nodes that had never disconnected were deleted since they did not have a last seen time. The new logic will start an expiry timer when the node disconnects and delete the node from the database when the timer is up. If the node reconnects within the expiry, the timer is cancelled. Fixes #2006 Signed-off-by: Kristoffer Dalby --- .github/workflows/test-integration.yaml | 1 + hscontrol/app.go | 54 ++-------- hscontrol/db/node.go | 125 +++++++++++++++--------- hscontrol/db/node_test.go | 24 +++++ hscontrol/db/preauth_keys_test.go | 72 -------------- hscontrol/poll.go | 15 +++ integration/general_test.go | 93 ++++++++++++++++++ 7 files changed, 222 insertions(+), 162 deletions(-) diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 9581badaef..2c3d028849 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -40,6 +40,7 @@ jobs: - TestPingAllByIPPublicDERP - TestAuthKeyLogoutAndRelogin - TestEphemeral + - TestEphemeral2006DeletedTooQuickly - TestPingAllByHostname - TestTaildrop - TestResolveMagicDNS diff --git a/hscontrol/app.go b/hscontrol/app.go index 253c2671b1..2862c60e1b 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -90,6 +90,7 @@ type Headscale struct { db *db.HSDatabase ipAlloc *db.IPAllocator noisePrivateKey *key.MachinePrivate + ephemeralGC *db.EphemeralGarbageCollector DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer @@ -152,6 +153,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, err } + app.ephemeralGC = 
db.NewEphemeralGarbageCollector(func(ni types.NodeID) { + if err := app.db.DeleteEphemeralNode(ni); err != nil { + log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to delete ephemeral node") + } + }) + if cfg.OIDC.Issuer != "" { err = app.initOIDC() if err != nil { @@ -216,47 +223,6 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { http.Redirect(w, req, target, http.StatusFound) } -// deleteExpireEphemeralNodes deletes ephemeral node records that have not been -// seen for longer than h.cfg.EphemeralNodeInactivityTimeout. -func (h *Headscale) deleteExpireEphemeralNodes(ctx context.Context, every time.Duration) { - ticker := time.NewTicker(every) - - for { - select { - case <-ctx.Done(): - ticker.Stop() - return - case <-ticker.C: - var removed []types.NodeID - var changed []types.NodeID - if err := h.db.Write(func(tx *gorm.DB) error { - removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout) - - return nil - }); err != nil { - log.Error().Err(err).Msg("database error while expiring ephemeral nodes") - continue - } - - if removed != nil { - ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: removed, - }) - } - - if changed != nil { - ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: changed, - }) - } - } - } -} - // expireExpiredNodes expires nodes that have an explicit expiry set // after that expiry time has passed. 
func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration) { @@ -552,9 +518,7 @@ func (h *Headscale) Serve() error { return errEmptyInitialDERPMap } - expireEphemeralCtx, expireEphemeralCancel := context.WithCancel(context.Background()) - defer expireEphemeralCancel() - go h.deleteExpireEphemeralNodes(expireEphemeralCtx, updateInterval) + go h.ephemeralGC.Start() expireNodeCtx, expireNodeCancel := context.WithCancel(context.Background()) defer expireNodeCancel() @@ -810,7 +774,7 @@ func (h *Headscale) Serve() error { Msg("Received signal to stop, shutting down gracefully") expireNodeCancel() - expireEphemeralCancel() + h.ephemeralGC.Close() trace("waiting for netmap stream to close") h.pollNetMapStreamWG.Wait() diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index e36d6ed131..9d2ba83507 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -5,6 +5,7 @@ import ( "fmt" "net/netip" "sort" + "sync" "time" "github.com/juanfont/headscale/hscontrol/types" @@ -286,6 +287,20 @@ func DeleteNode(tx *gorm.DB, return changed, nil } +// DeleteEphemeralNode deletes a Node from the database, note that this method +// will remove it straight, and not notify any changes or consider any routes. +// It is intended for Ephemeral nodes. +func (hsdb *HSDatabase) DeleteEphemeralNode( + nodeID types.NodeID, +) error { + return hsdb.Write(func(tx *gorm.DB) error { + if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil { + return err + } + return nil + }) +} + // SetLastSeen sets a node's last seen field indicating that we // have recently communicating with this node. 
func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { @@ -660,51 +675,6 @@ func GenerateGivenName( return givenName, nil } -func DeleteExpiredEphemeralNodes(tx *gorm.DB, - inactivityThreshold time.Duration, -) ([]types.NodeID, []types.NodeID) { - users, err := ListUsers(tx) - if err != nil { - return nil, nil - } - - var expired []types.NodeID - var changedNodes []types.NodeID - for _, user := range users { - nodes, err := ListNodesByUser(tx, user.Name) - if err != nil { - return nil, nil - } - - for idx, node := range nodes { - if node.IsEphemeral() && node.LastSeen != nil && - time.Now(). - After(node.LastSeen.Add(inactivityThreshold)) { - expired = append(expired, node.ID) - - log.Info(). - Str("node", node.Hostname). - Msg("Ephemeral client removed from database") - - // empty isConnected map as ephemeral nodes are not routes - changed, err := DeleteNode(tx, nodes[idx], nil) - if err != nil { - log.Error(). - Err(err). - Str("node", node.Hostname). - Msg("🤮 Cannot delete ephemeral node from the database") - } - - changedNodes = append(changedNodes, changed...) 
- } - } - - // TODO(kradalby): needs to be moved out of transaction - } - - return expired, changedNodes -} - func ExpireExpiredNodes(tx *gorm.DB, lastCheck time.Time, ) (time.Time, types.StateUpdate, bool) { @@ -737,3 +707,68 @@ func ExpireExpiredNodes(tx *gorm.DB, return started, types.StateUpdate{}, false } + +type EphemeralGarbageCollector struct { + mu sync.Mutex + + deleteFunc func(types.NodeID) + toBeDeleted map[types.NodeID]*time.Timer + + deleteCh chan types.NodeID + cancelCh chan struct{} +} + +func NewEphemeralGarbageCollector(deleteFunc func(types.NodeID)) *EphemeralGarbageCollector { + return &EphemeralGarbageCollector{ + toBeDeleted: make(map[types.NodeID]*time.Timer), + deleteCh: make(chan types.NodeID, 10), + deleteFunc: deleteFunc, + } +} + +func (e *EphemeralGarbageCollector) Close() { + e.cancelCh <- struct{}{} +} + +func (e *EphemeralGarbageCollector) Schedule(nodeID types.NodeID, expiry time.Duration) { + e.mu.Lock() + defer e.mu.Unlock() + + timer := time.NewTimer(expiry) + e.toBeDeleted[nodeID] = timer + + go func() { + select { + case _, ok := <-timer.C: + if ok { + e.deleteCh <- nodeID + } + } + }() +} + +func (e *EphemeralGarbageCollector) Cancel(nodeID types.NodeID) { + e.mu.Lock() + defer e.mu.Unlock() + + if timer, ok := e.toBeDeleted[nodeID]; ok { + timer.Stop() + delete(e.toBeDeleted, nodeID) + } +} + +func (e *EphemeralGarbageCollector) Start() { + for { + select { + case <-e.cancelCh: + return + case nodeID := <-e.deleteCh: + // TODO(kradalby): deadlock here? 
+ e.mu.Lock() + delete(e.toBeDeleted, nodeID) + e.mu.Unlock() + + go e.deleteFunc(nodeID) + } + } +} diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index e95ee4ae33..cd8d04a4c8 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -599,3 +600,26 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { c.Assert(err, check.IsNil) c.Assert(enabledRoutes, check.HasLen, 4) } + +func TestEphemeralGarbageCollector(t *testing.T) { + want := []types.NodeID{1, 3} + got := []types.NodeID{} + + e := NewEphemeralGarbageCollector(func(ni types.NodeID) { + got = append(got, ni) + }) + go e.Start() + + e.Schedule(1, 1*time.Second) + e.Schedule(2, 2*time.Second) + e.Schedule(3, 3*time.Second) + e.Schedule(4, 4*time.Second) + e.Cancel(2) + e.Cancel(4) + + time.Sleep(10 * time.Second) + + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("wrong nodes deleted, unexpected result (-want +got):\n%s", diff) + } +} diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 9cdcba8044..88d5b213f5 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -6,7 +6,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" - "gorm.io/gorm" ) func (*Suite) TestCreatePreAuthKey(c *check.C) { @@ -127,77 +126,6 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { c.Assert(key.ID, check.Equals, pak.ID) } -func (*Suite) TestEphemeralKeyReusable(c *check.C) { - user, err := db.CreateUser("test7") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, true, true, nil, nil) - c.Assert(err, check.IsNil) - - now := time.Now().Add(-time.Second * 30) - pakID := uint(pak.ID) - node := 
types.Node{ - ID: 0, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - LastSeen: &now, - AuthKeyID: &pakID, - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - _, err = db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.IsNil) - - _, err = db.getNode("test7", "testest") - c.Assert(err, check.IsNil) - - db.Write(func(tx *gorm.DB) error { - DeleteExpiredEphemeralNodes(tx, time.Second*20) - return nil - }) - - // The machine record should have been deleted - _, err = db.getNode("test7", "testest") - c.Assert(err, check.NotNil) -} - -func (*Suite) TestEphemeralKeyNotReusable(c *check.C) { - user, err := db.CreateUser("test7") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) - c.Assert(err, check.IsNil) - - now := time.Now().Add(-time.Second * 30) - pakId := uint(pak.ID) - node := types.Node{ - ID: 0, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - LastSeen: &now, - AuthKeyID: &pakId, - } - db.DB.Save(&node) - - _, err = db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.NotNil) - - _, err = db.getNode("test7", "testest") - c.Assert(err, check.IsNil) - - db.Write(func(tx *gorm.DB) error { - DeleteExpiredEphemeralNodes(tx, time.Second*20) - return nil - }) - - // The machine record should have been deleted - _, err = db.getNode("test7", "testest") - c.Assert(err, check.NotNil) -} - func (*Suite) TestExpirePreauthKey(c *check.C) { user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) diff --git a/hscontrol/poll.go b/hscontrol/poll.go index d3c8211769..8122064b6d 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -135,6 +135,18 @@ func (m *mapSession) resetKeepAlive() { m.keepAliveTicker.Reset(m.keepAlive) } +func (m *mapSession) beforeServeLongPoll() { + if m.node.IsEphemeral() { + m.h.ephemeralGC.Cancel(m.node.ID) + } +} + +func (m *mapSession) afterServeLongPoll() { + if m.node.IsEphemeral() { + 
m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout) + } +} + // serve handles non-streaming requests. func (m *mapSession) serve() { // TODO(kradalby): A set todos to harden: @@ -180,6 +192,8 @@ func (m *mapSession) serve() { // //nolint:gocyclo func (m *mapSession) serveLongPoll() { + m.beforeServeLongPoll() + // Clean up the session when the client disconnects defer func() { m.cancelChMu.Lock() @@ -197,6 +211,7 @@ func (m *mapSession) serveLongPoll() { m.pollFailoverRoutes("node closing connection", m.node) } + m.afterServeLongPoll() m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) }() diff --git a/integration/general_test.go b/integration/general_test.go index 245e8f096a..6f7becf250 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -297,6 +297,99 @@ func TestEphemeral(t *testing.T) { } } +// TestEphemeral2006DeletedTooQuickly verifies that ephemeral nodes are not +// deleted by accident if they are still online and active. +func TestEphemeral2006DeletedTooQuickly(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + defer scenario.Shutdown() + + spec := map[string]int{ + "user1": len(MustTestVersions), + "user2": len(MustTestVersions), + } + + headscale, err := scenario.Headscale( + hsic.WithTestName("ephemeral2006"), + hsic.WithConfigEnv(map[string]string{ + "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "1m6s", + }), + ) + assertNoErrHeadscaleEnv(t, err) + + for userName, clientCount := range spec { + err = scenario.CreateUser(userName) + if err != nil { + t.Fatalf("failed to create user %s: %s", userName, err) + } + + err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...) 
+ if err != nil { + t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) + } + + key, err := scenario.CreatePreAuthKey(userName, true, true) + if err != nil { + t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) + } + + err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) + if err != nil { + t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) + } + } + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + assertNoErrListClientIPs(t, err) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + // Take a break and let the ephemeral cleaner run twice. Nodes + // should still be around after this as they have stayed connected + // to tailscale the whole time. + time.Sleep(3 * time.Minute) + + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + for _, client := range allClients { + err := client.Down() + if err != nil { + t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) + } + } + + // Ensure the cleanup is running again after we take the nodes down. + time.Sleep(3 * time.Minute) + + for userName := range spec { + nodes, err := headscale.ListNodesInUser(userName) + if err != nil { + log.Error(). + Err(err). + Str("user", userName). + Msg("Error listing nodes in user") + + return + } + + if len(nodes) != 0 { + t.Fatalf("expected no nodes, got %d in user %s", len(nodes), userName) + } + } +} + func TestPingAllByHostname(t *testing.T) { IntegrationSkip(t) t.Parallel()