diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 43dfd2b08e..78db626b23 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -309,7 +309,7 @@ func (h *Headscale) handleAuthKey( machine.NodeKey = nodeKey machine.AuthKeyID = uint(pak.ID) - err := h.db.RefreshMachine(machine, registerRequest.Expiry) + err := h.db.MachineSetExpiry(machine, registerRequest.Expiry) if err != nil { log.Error(). Caller(). @@ -510,7 +510,8 @@ func (h *Headscale) handleMachineLogOut( Str("machine", machine.Hostname). Msg("Client requested logout") - err := h.db.ExpireMachine(&machine) + now := time.Now() + err := h.db.MachineSetExpiry(&machine, now) if err != nil { log.Error(). Caller(). @@ -552,7 +553,7 @@ func (h *Headscale) handleMachineLogOut( } if machine.IsEphemeral() { - err = h.db.HardDeleteMachine(&machine) + err = h.db.DeleteMachine(&machine) if err != nil { log.Error(). Err(err). diff --git a/hscontrol/db/machine.go b/hscontrol/db/machine.go index 8a3e22815a..574f53ed45 100644 --- a/hscontrol/db/machine.go +++ b/hscontrol/db/machine.go @@ -39,6 +39,10 @@ func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.listPeers(machine) +} + +func (hsdb *HSDatabase) listPeers(machine *types.Machine) (types.Machines, error) { log.Trace(). Caller(). Str("machine", machine.Hostname). @@ -69,6 +73,10 @@ func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.listMachines() +} + +func (hsdb *HSDatabase) listMachines() ([]types.Machine, error) { machines := []types.Machine{} if err := hsdb.db. Preload("AuthKey"). @@ -86,6 +94,10 @@ func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machine hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.listMachinesByGivenName(givenName) +} + +func (hsdb *HSDatabase) listMachinesByGivenName(givenName string) (types.Machines, error) { machines := types.Machines{} if err := hsdb.db. Preload("AuthKey"). @@ -126,17 +138,16 @@ func (hsdb *HSDatabase) GetMachineByGivenName( hsdb.mu.RLock() defer hsdb.mu.RUnlock() - machines, err := hsdb.ListMachinesByUser(user) - if err != nil { + machine := types.Machine{} + if err := hsdb.db. + Preload("AuthKey"). + Preload("AuthKey.User"). + Preload("User"). + Preload("Routes"). + Where("given_name = ?", givenName).First(&machine).Error; err != nil { return nil, err } - for _, m := range machines { - if m.GivenName == givenName { - return &m, nil - } - } - return nil, ErrMachineNotFound } @@ -222,10 +233,7 @@ func (hsdb *HSDatabase) GetMachineByAnyKey( return &machine, nil } -// TODO(kradalby): rename this, it sounds like a mix of getting and setting to db -// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database -// and updates it with the latest data from the database. -func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *types.Machine) error { +func (hsdb *HSDatabase) MachineReloadFromDatabase(machine *types.Machine) error { hsdb.mu.RLock() defer hsdb.mu.RUnlock() @@ -250,37 +258,18 @@ func (hsdb *HSDatabase) SetTags( newTags = append(newTags, tag) } } - machine.ForcedTags = newTags - hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ - Type: types.StatePeerChanged, - Changed: []uint64{machine.ID}, - }, machine.MachineKey) - - if err := hsdb.db.Save(machine).Error; err != nil { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + ForcedTags: newTags, + }).Error; err != nil { return fmt.Errorf("failed to update tags for machine in the database: %w", err) } - return nil -} - -// ExpireMachine takes a Machine struct and sets the expire field to now. -func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - now := time.Now() - machine.Expiry = &now - hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ Type: types.StatePeerChanged, Changed: []uint64{machine.ID}, }, machine.MachineKey) - if err := hsdb.db.Save(machine).Error; err != nil { - return fmt.Errorf("failed to expire machine in the database: %w", err) - } - return nil } @@ -306,57 +295,73 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er } machine.GivenName = newName + if err := hsdb.db.Model(machine).Updates(types.Machine{ + GivenName: newName, + }).Error; err != nil { + return fmt.Errorf("failed to rename machine in the database: %w", err) + } + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ Type: types.StatePeerChanged, Changed: []uint64{machine.ID}, }, machine.MachineKey) - if err := hsdb.db.Save(machine).Error; err != nil { - return fmt.Errorf("failed to rename machine in the database: %w", err) - } - return nil } -// RefreshMachine takes a Machine struct and a new expiry time. -func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) error { +// MachineSetExpiry takes a Machine struct and a new expiry time. +func (hsdb *HSDatabase) MachineSetExpiry(machine *types.Machine, expiry time.Time) error { hsdb.mu.Lock() defer hsdb.mu.Unlock() - now := time.Now() - - machine.LastSuccessfulUpdate = &now - machine.Expiry = &expiry + return hsdb.machineSetExpiry(machine, expiry) +} - hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ - Type: types.StatePeerChanged, - Changed: []uint64{machine.ID}, - }, machine.MachineKey) +func (hsdb *HSDatabase) machineSetExpiry(machine *types.Machine, expiry time.Time) error { + now := time.Now() - if err := hsdb.db.Save(machine).Error; err != nil { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + LastSuccessfulUpdate: &now, + Expiry: &expiry, + }).Error; err != nil { return fmt.Errorf( "failed to refresh machine (update expiration) in the database: %w", err, ) } + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: []uint64{machine.ID}, + }, machine.MachineKey) + return nil } -// DeleteMachine softs deletes a Machine from the database. +// DeleteMachine deletes a Machine from the database. func (hsdb *HSDatabase) DeleteMachine(machine *types.Machine) error { hsdb.mu.Lock() defer hsdb.mu.Unlock() - err := hsdb.DeleteMachineRoutes(machine) + return hsdb.deleteMachine(machine) +} + +func (hsdb *HSDatabase) deleteMachine(machine *types.Machine) error { + err := hsdb.deleteMachineRoutes(machine) if err != nil { return err } - if err := hsdb.db.Delete(&machine).Error; err != nil { + // Unscoped causes the machine to be fully removed from the database. + if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil { return err } + hsdb.notifier.NotifyAll(types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: []tailcfg.NodeID{tailcfg.NodeID(machine.ID)}, + }) + return nil } @@ -364,30 +369,12 @@ func (hsdb *HSDatabase) TouchMachine(machine *types.Machine) error { hsdb.mu.Lock() defer hsdb.mu.Unlock() - return hsdb.db.Updates(types.Machine{ - ID: machine.ID, + return hsdb.db.Model(machine).Updates(types.Machine{ LastSeen: machine.LastSeen, LastSuccessfulUpdate: machine.LastSuccessfulUpdate, }).Error } -// HardDeleteMachine hard deletes a Machine from the database. -func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - err := hsdb.DeleteMachineRoutes(machine) - if err != nil { - return err - } - - if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil { - return err - } - - return nil -} - func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( cache *cache.Cache, nodeKeyStr string, @@ -517,9 +504,9 @@ func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.No hsdb.mu.Lock() defer hsdb.mu.Unlock() - machine.NodeKey = util.NodePublicKeyStripPrefix(nodeKey) - - if err := hsdb.db.Save(machine).Error; err != nil { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + NodeKey: util.NodePublicKeyStripPrefix(nodeKey), + }).Error; err != nil { return err } @@ -529,14 +516,14 @@ func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.No // MachineSetMachineKey sets the machine key of a machine and saves it to the database. func (hsdb *HSDatabase) MachineSetMachineKey( machine *types.Machine, - nodeKey key.MachinePublic, + machineKey key.MachinePublic, ) error { hsdb.mu.Lock() defer hsdb.mu.Unlock() - machine.MachineKey = util.MachinePublicKeyStripPrefix(nodeKey) - - if err := hsdb.db.Save(machine).Error; err != nil { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + MachineKey: util.MachinePublicKeyStripPrefix(machineKey), + }).Error; err != nil { return err } @@ -561,6 +548,10 @@ func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Pre hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.getAdvertisedRoutes(machine) +} + +func (hsdb *HSDatabase) getAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { routes := types.Routes{} err := hsdb.db. @@ -589,6 +580,10 @@ func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.getEnabledRoutes(machine) +} + +func (hsdb *HSDatabase) getEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { routes := types.Routes{} err := hsdb.db. @@ -622,7 +617,7 @@ func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) return false } - enabledRoutes, err := hsdb.GetEnabledRoutes(machine) + enabledRoutes, err := hsdb.getEnabledRoutes(machine) if err != nil { log.Error().Err(err).Msg("Could not get enabled routes") @@ -654,7 +649,7 @@ func (hsdb *HSDatabase) ListOnlineMachines( hsdb.mu.RLock() defer hsdb.mu.RUnlock() - peers, err := hsdb.ListPeers(machine) + peers, err := hsdb.listPeers(machine) if err != nil { return nil, err } @@ -664,10 +659,6 @@ func (hsdb *HSDatabase) ListOnlineMachines( // enableRoutes enables new routes based on a list of new routes. func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() - newRoutes := make([]netip.Prefix, len(routeStrs)) for index, routeStr := range routeStrs { route, err := netip.ParsePrefix(routeStr) @@ -678,7 +669,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string newRoutes[index] = route } - advertisedRoutes, err := hsdb.GetAdvertisedRoutes(machine) + advertisedRoutes, err := hsdb.getAdvertisedRoutes(machine) if err != nil { return err } @@ -761,7 +752,7 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string } // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - machines, err := hsdb.ListMachinesByGivenName(givenName) + machines, err := hsdb.listMachinesByGivenName(givenName) if err != nil { return "", err } @@ -781,11 +772,10 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string } func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Duration) { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() - users, err := hsdb.ListUsers() + users, err := hsdb.listUsers() if err != nil { log.Error().Err(err).Msg("Error listing users") @@ -793,7 +783,7 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati } for _, user := range users { - machines, err := hsdb.ListMachinesByUser(user.Name) + machines, err := hsdb.listMachinesByUser(user.Name) if err != nil { log.Error(). Err(err). @@ -814,7 +804,7 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati Str("machine", machine.Hostname). Msg("Ephemeral client removed from database") - err = hsdb.HardDeleteMachine(&machines[idx]) + err = hsdb.deleteMachine(&machines[idx]) if err != nil { log.Error(). Err(err). @@ -834,16 +824,15 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati } func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() // use the time of the start of the function to ensure we // dont miss some machines by returning it _after_ we have // checked everything. started := time.Now() - users, err := hsdb.ListUsers() + users, err := hsdb.listUsers() if err != nil { log.Error().Err(err).Msg("Error listing users") @@ -851,7 +840,7 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { } for _, user := range users { - machines, err := hsdb.ListMachinesByUser(user.Name) + machines, err := hsdb.listMachinesByUser(user.Name) if err != nil { log.Error(). Err(err). @@ -867,7 +856,8 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { machine.Expiry.After(lastCheck) { expired = append(expired, tailcfg.NodeID(machine.ID)) - err := hsdb.ExpireMachine(&machines[index]) + now := time.Now() + err := hsdb.machineSetExpiry(&machines[index], now) if err != nil { log.Error(). Err(err). diff --git a/hscontrol/db/machine_test.go b/hscontrol/db/machine_test.go index 319415f3ed..7f837e06db 100644 --- a/hscontrol/db/machine_test.go +++ b/hscontrol/db/machine_test.go @@ -127,28 +127,6 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { c.Assert(err, check.IsNil) } -func (s *Suite) TestDeleteMachine(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - machine := types.Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(1), - } - db.db.Save(&machine) - - err = db.DeleteMachine(&machine) - c.Assert(err, check.IsNil) - - _, err = db.GetMachine(user.Name, "testmachine") - c.Assert(err, check.NotNil) -} - func (s *Suite) TestHardDeleteMachine(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) @@ -164,7 +142,7 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) { } db.db.Save(&machine) - err = db.HardDeleteMachine(&machine) + err = db.DeleteMachine(&machine) c.Assert(err, check.IsNil) _, err = db.GetMachine(user.Name, "testmachine3") @@ -329,7 +307,8 @@ func (s *Suite) TestExpireMachine(c *check.C) { c.Assert(machineFromDB.IsExpired(), check.Equals, false) - err = db.ExpireMachine(machineFromDB) + now := time.Now() + err = db.MachineSetExpiry(machineFromDB, now) c.Assert(err, check.IsNil) c.Assert(machineFromDB.IsExpired(), check.Equals, true) diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 65285d3f80..ec7ab232f4 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -99,7 +99,11 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, er hsdb.mu.RLock() defer hsdb.mu.RUnlock() - user, err := hsdb.GetUser(userName) + return hsdb.listPreAuthKeys(userName) +} + +func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) { + user, err := hsdb.getUser(userName) if err != nil { return nil, err } @@ -132,10 +136,13 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.destroyPreAuthKey(pak) +} +func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error { return hsdb.db.Transaction(func(db *gorm.DB) error { if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { return result.Error @@ -197,7 +204,10 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) } machines := types.Machines{} - if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { + if err := hsdb.db. + Preload("AuthKey"). + Where(&types.Machine{AuthKeyID: uint(pak.ID)}). + Find(&machines).Error; err != nil { return nil, err } diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index a62985d93c..26a08f3702 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -16,6 +16,10 @@ func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.getRoutes() +} + +func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { var routes types.Routes err := hsdb.db.Preload("Machine").Find(&routes).Error if err != nil { @@ -29,6 +33,10 @@ func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (type hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.getMachineAdvertisedRoutes(machine) +} + +func (hsdb *HSDatabase) getMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { var routes types.Routes err := hsdb.db. Preload("Machine"). @@ -42,10 +50,13 @@ func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (type } func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) { - // TODO(kradalby): figure out this lock - // hsdb.mu.RLock() - // defer hsdb.mu.RUnlock() + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + return hsdb.getMachineRoutes(m) +} + +func (hsdb *HSDatabase) getMachineRoutes(m *types.Machine) (types.Routes, error) { var routes types.Routes err := hsdb.db. Preload("Machine"). @@ -62,6 +73,10 @@ func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.getRoute(id) +} + +func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { var route types.Route err := hsdb.db.Preload("Machine").First(&route, id).Error if err != nil { @@ -72,11 +87,14 @@ func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { } func (hsdb *HSDatabase) EnableRoute(id uint64) error { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.enableRoute(id) +} - route, err := hsdb.GetRoute(id) +func (hsdb *HSDatabase) enableRoute(id uint64) error { + route, err := hsdb.getRoute(id) if err != nil { return err } @@ -96,11 +114,10 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error { } func (hsdb *HSDatabase) DisableRoute(id uint64) error { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() - route, err := hsdb.GetRoute(id) + route, err := hsdb.getRoute(id) if err != nil { return err } @@ -116,10 +133,10 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { return err } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } - routes, err := hsdb.GetMachineRoutes(&route.Machine) + routes, err := hsdb.getMachineRoutes(&route.Machine) if err != nil { return err } @@ -135,15 +152,14 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { } } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } func (hsdb *HSDatabase) DeleteRoute(id uint64) error { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() - route, err := hsdb.GetRoute(id) + route, err := hsdb.getRoute(id) if err != nil { return err } @@ -156,10 +172,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { return err } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } - routes, err := hsdb.GetMachineRoutes(&route.Machine) + routes, err := hsdb.getMachineRoutes(&route.Machine) if err != nil { return err } @@ -175,15 +191,11 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { return err } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } -func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() - - routes, err := hsdb.GetMachineRoutes(m) +func (hsdb *HSDatabase) deleteMachineRoutes(m *types.Machine) error { + routes, err := hsdb.getMachineRoutes(m) if err != nil { return err } @@ -194,14 +206,11 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { } } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } // isUniquePrefix returns if there is another machine providing the same route already. func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - var count int64 hsdb.db. Model(&types.Route{}). @@ -214,9 +223,6 @@ func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { } func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - var route types.Route err := hsdb.db. Preload("Machine"). @@ -252,10 +258,13 @@ func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, } func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.processMachineRoutes(machine) +} +func (hsdb *HSDatabase) processMachineRoutes(machine *types.Machine) error { currentRoutes := types.Routes{} err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error if err != nil { @@ -306,10 +315,13 @@ func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { } func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.handlePrimarySubnetFailover() +} +func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { // first, get all the enabled routes var routes types.Routes err := hsdb.db. @@ -434,15 +446,14 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( aclPolicy *policy.ACLPolicy, machine *types.Machine, ) error { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() if len(machine.IPAddresses) == 0 { return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs } - routes, err := hsdb.GetMachineAdvertisedRoutes(machine) + routes, err := hsdb.getMachineAdvertisedRoutes(machine) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { log.Error(). Caller(). @@ -495,7 +506,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( } for _, approvedRoute := range approvedRoutes { - err := hsdb.EnableRoute(uint64(approvedRoute.ID)) + err := hsdb.enableRoute(uint64(approvedRoute.ID)) if err != nil { log.Err(err). Str("approvedRoute", approvedRoute.String()). diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 883b856b98..5af4660b7c 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -45,16 +45,15 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { // DestroyUser destroys a User. Returns error if the User does // not exist or if there are machines associated with it. func (hsdb *HSDatabase) DestroyUser(name string) error { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() - user, err := hsdb.GetUser(name) + user, err := hsdb.getUser(name) if err != nil { return ErrUserNotFound } - machines, err := hsdb.ListMachinesByUser(name) + machines, err := hsdb.listMachinesByUser(name) if err != nil { return err } @@ -62,12 +61,12 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { return ErrUserStillHasNodes } - keys, err := hsdb.ListPreAuthKeys(name) + keys, err := hsdb.listPreAuthKeys(name) if err != nil { return err } for _, key := range keys { - err = hsdb.DestroyPreAuthKey(key) + err = hsdb.destroyPreAuthKey(key) if err != nil { return err } @@ -83,12 +82,11 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() var err error - oldUser, err := hsdb.GetUser(oldName) + oldUser, err := hsdb.getUser(oldName) if err != nil { return err } @@ -96,7 +94,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { if err != nil { return err } - _, err = hsdb.GetUser(newName) + _, err = hsdb.getUser(newName) if err == nil { return ErrUserExists } @@ -118,6 +116,10 @@ func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.getUser(name) +} + +func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { user := types.User{} if result := hsdb.db.First(&user, "name = ?", name); errors.Is( result.Error, @@ -134,6 +136,10 @@ func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.listUsers() +} + +func (hsdb *HSDatabase) listUsers() ([]types.User, error) { users := []types.User{} if err := hsdb.db.Find(&users).Error; err != nil { return nil, err @@ -147,11 +153,15 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.listMachinesByUser(name) +} + +func (hsdb *HSDatabase) listMachinesByUser(name string) (types.Machines, error) { err := util.CheckForFQDNRules(name) if err != nil { return nil, err } - user, err := hsdb.GetUser(name) + user, err := hsdb.getUser(name) if err != nil { return nil, err } @@ -164,17 +174,16 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) return machines, nil } -// SetMachineUser assigns a Machine to a user. -func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() +// AssignMachineToUser assigns a Machine to a user. +func (hsdb *HSDatabase) AssignMachineToUser(machine *types.Machine, username string) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() err := util.CheckForFQDNRules(username) if err != nil { return err } - user, err := hsdb.GetUser(username) + user, err := hsdb.getUser(username) if err != nil { return err } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index bc468b2361..97b3e6d7f6 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -114,15 +114,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) { db.db.Save(&machine) c.Assert(machine.UserID, check.Equals, oldUser.ID) - err = db.SetMachineUser(&machine, newUser.Name) + err = db.AssignMachineToUser(&machine, newUser.Name) c.Assert(err, check.IsNil) c.Assert(machine.UserID, check.Equals, newUser.ID) c.Assert(machine.User.Name, check.Equals, newUser.Name) - err = db.SetMachineUser(&machine, "non-existing-user") + err = db.AssignMachineToUser(&machine, "non-existing-user") c.Assert(err, check.Equals, ErrUserNotFound) - err = db.SetMachineUser(&machine, newUser.Name) + err = db.AssignMachineToUser(&machine, newUser.Name) c.Assert(err, check.IsNil) c.Assert(machine.UserID, check.Equals, newUser.ID) c.Assert(machine.User.Name, check.Equals, newUser.Name) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 74950c207f..292c8d84a4 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -275,8 +275,11 @@ func (api headscaleV1APIServer) ExpireMachine( return nil, err } - api.h.db.ExpireMachine( + now := time.Now() + + api.h.db.MachineSetExpiry( machine, + now, ) log.Trace(). @@ -358,7 +361,7 @@ func (api headscaleV1APIServer) MoveMachine( return nil, err } - err = api.h.db.SetMachineUser(machine, request.GetUser()) + err = api.h.db.AssignMachineToUser(machine, request.GetUser()) if err != nil { return nil, err } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 663838381b..010bcb15c5 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -523,7 +523,7 @@ func (h *Headscale) validateMachineForOIDCCallback( Str("machine", machine.Hostname). Msg("machine already registered, reauthenticating") - err := h.db.RefreshMachine(machine, expiry) + err := h.db.MachineSetExpiry(machine, expiry) if err != nil { util.LogErr(err, "Failed to refresh machine") http.Error( diff --git a/hscontrol/poll.go b/hscontrol/poll.go index bf7a0f49fb..77161fce0d 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -107,6 +107,7 @@ func (h *Headscale) handlePoll( machine.LastSeen = &now } + // TODO(kradalby): Save specific stuff, not whole object. if err := h.db.MachineSave(machine); err != nil { logErr(err, "Failed to persist/update machine in the database") http.Error(writer, "", http.StatusInternalServerError)