Skip to content

Commit

Permalink
Fix issue with instances sharing config dir.
Browse files Browse the repository at this point in the history
For legacy installations (where a tsnet-tsnsrv directory has been
created) a new "machine-name" fill will be created in the tsnet-tsnsrv
directory if it doesn't already exist. The machine-name will contain the
name that was passed with the -name command line argument.

If the machine-name file already exists and it matches what was passed
with the -name command line argument then that directory will be used as
the configuration directory.

If the tsnet-tsnsrv directory doesn't exist or if the machine-name
doesn't match then a new tsnet-tsnsrv-<name> directory will be used to
store tsnet configuration.

This allows for more that one tsnsrv instance to be started without
having to specify a config directory without needing to set the
TS_STATE_DIR env var or pass the -stateDir flag.

Fixes: boinkor-net#62
  • Loading branch information
Evan Farrer authored and Evan Farrer committed Jan 14, 2024
1 parent 3c51734 commit 86bfe5a
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 1 deletion.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/boinkor-net/tsnsrv
go 1.21

require (
github.com/gofrs/flock v0.8.1
github.com/peterbourgon/ff/v3 v3.4.0
github.com/prometheus/client_golang v1.17.0
github.com/stretchr/testify v1.8.4
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/E
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg=
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU=
github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw=
github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
Expand Down
10 changes: 9 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func tailnetSrvFromArgs(args []string) (*validTailnetSrv, *ffcli.Command, error)
fs.DurationVar(&s.Timeout, "timeout", 1*time.Minute, "Timeout connecting to the tailnet")
fs.Var(&s.AllowedPrefixes, "prefix", "Allowed URL prefixes; if none is set, all prefixes are allowed")
fs.BoolVar(&s.StripPrefix, "stripPrefix", true, "Strip prefixes that matched; best set to false if allowing multiple prefixes")
fs.StringVar(&s.StateDir, "stateDir", os.Getenv("TS_STATE_DIR"), "Directory containing the persistent tailscale status files. Can also be set by $TS_STATE_DIR; this option takes precedence.")
fs.StringVar(&s.StateDir, "stateDir", "", "Directory containing the persistent tailscale status files. Can also be set by $TS_STATE_DIR; this option takes precedence.")
fs.StringVar(&s.AuthkeyPath, "authkeyPath", "", "File containing a tailscale auth key. Key is assumed to be in $TS_AUTHKEY in absence of this option.")
fs.BoolVar(&s.InsecureHTTPS, "insecureHTTPS", false, "Disable TLS certificate validation on upstream")
fs.DurationVar(&s.WhoisTimeout, "whoisTimeout", 1*time.Second, "Maximum amount of time to spend looking up client identities")
Expand All @@ -123,6 +123,14 @@ func tailnetSrvFromArgs(args []string) (*validTailnetSrv, *ffcli.Command, error)
if err := root.Parse(args); err != nil {
return nil, root, fmt.Errorf("could not parse args: %w", err)
}

// Figure out the state directory
stateDir, err := NewStateDir(s.Name, s.StateDir).Compute()
if err != nil {
return nil, nil, fmt.Errorf("unable to compute state dir: %w", err)
}
s.StateDir = stateDir

valid, err := s.validate(root.FlagSet.Args())
if err != nil {
return nil, root, fmt.Errorf("failed to validate args: %w", err)
Expand Down
154 changes: 154 additions & 0 deletions state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package main

import (
"context"
"errors"
"fmt"
"io/fs"
"os"
"path"
"strings"
"time"

"github.com/gofrs/flock"
)

type StateDir struct {
machineName string
stateDirFlag string
getEnv func(string) string
userConfigDir func() (string, error)
dirExists func(string) (bool, error)
readFileString func(string) (string, error)
writeFileString func(string, string) error
}

func NewStateDir(machineName, stateDirFlag string) StateDir {
return StateDir{
machineName: machineName,
stateDirFlag: stateDirFlag,
getEnv: os.Getenv,
userConfigDir: os.UserConfigDir,
dirExists: dirExists,
readFileString: readFileString,
writeFileString: writeFileString,
}
}

func (sd StateDir) Compute() (string, error) {
// Set command line flag
if sd.stateDirFlag != "" {
return sd.stateDirFlag, nil
}

// Set TS_STATE_DIR env var
tsStateDirEnv := sd.getEnv("TS_STATE_DIR")
if tsStateDirEnv != "" {
return tsStateDirEnv, nil
}

// Looking for legacy tsnet-tsnsrv configuration directory
userConfigDir, err := sd.userConfigDir()
if err != nil {
return "", fmt.Errorf("unable to find user config directory. %w", err)
}
legacyTsnetConfigDir := path.Join(userConfigDir, "tsnet-tsnsrv")
legacyTsnetDirExists, err := sd.dirExists(legacyTsnetConfigDir)
if err != nil {
return "", fmt.Errorf("unable to determine existence of legacy tsnet config directory. %w", err)
}

// The tsnet-tsnet directory doesn't exist so we can just create a unique configuration directory for the given
// machine name.
if !legacyTsnetDirExists {
return path.Join(userConfigDir, fmt.Sprintf("tsnet-tsnsrv-%s", sd.machineName)), nil
}

// The tsnet-tsnet directory does exist reach the machine name file and see if they match
machineNamePath := path.Join(legacyTsnetConfigDir, "machine-name")
readName, err := sd.readFileString(machineNamePath)
if errors.Is(err, fs.ErrNotExist) {
err = sd.writeFileString(machineNamePath, sd.machineName)
if err != nil {
return "", fmt.Errorf("unable to write machine name to legacy config dir. %w", err)
}

return legacyTsnetConfigDir, nil
}
if err != nil {
return "", fmt.Errorf("unable to read legacy machine-name file. %w", err)
}

if strings.TrimSpace(readName) == sd.machineName {
return legacyTsnetConfigDir, nil
}

return path.Join(userConfigDir, fmt.Sprintf("tsnet-tsnsrv-%s", sd.machineName)), nil
}

func lockFilePath() string {
return path.Join(os.TempDir(), "tsnsrv.lock")
}

var tryLockTimeoutErr = errors.New("timeout trying to get the file lock")

func lockContext(ctx context.Context) context.Context {
ctx, _ = context.WithTimeoutCause(ctx, time.Second*5, tryLockTimeoutErr)
return ctx
}

func tryLock(ctx context.Context, readLock bool) (func() error, error) {
lockFile := lockFilePath()
lock := flock.New(lockFile)
ctx = lockContext(ctx)
lockFn := lock.TryLockContext
if readLock {
lockFn = lock.TryRLockContext
}

locked, err := lockFn(ctx, time.Millisecond*100)
if errors.Is(err, tryLockTimeoutErr) {
return nil, fmt.Errorf("timeout trying to get lock %s another process is using it", lockFile)
}
if err != nil {
return nil, fmt.Errorf("trying to lock %s. %w", lockFile, err)
}
if !locked {
return nil, fmt.Errorf("unable to get lock %s", lockFile)
}

return lock.Unlock, nil
}

func readFileString(file string) (string, error) {
unlocker, err := tryLock(context.Background(), true)
if err != nil {
return "", err
}
defer unlocker()

bytes, err := os.ReadFile(file)
return string(bytes), err
}

func writeFileString(file, contents string) error {
unlocker, err := tryLock(context.Background(), false)
if err != nil {
return err
}
defer unlocker()

return os.WriteFile(file, []byte(contents), 0644)
}

func dirExists(dir string) (bool, error) {
_, err := os.Stat(dir)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}

return false, err
}
131 changes: 131 additions & 0 deletions state_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package main

import (
"fmt"
"io/fs"
"path"
"testing"

"github.com/stretchr/testify/require"
)

func initialState() StateDir {
sd := NewStateDir("machine-name", "")
sd.getEnv = func(string) string { return "" }
sd.userConfigDir = func() (string, error) { return "", nil }
sd.dirExists = func(string) (bool, error) { return false, nil }
sd.readFileString = func(string) (string, error) { return "", nil }
sd.writeFileString = func(string, string) error { return nil }

return sd
}

// Ensure that the -stateDir flag is used for selecting the state directory.
func TestStateDirFlag_IsUsedIfSet(t *testing.T) {
t.Parallel()

const stateDirFlag = "some path"

sd := initialState()
sd.stateDirFlag = stateDirFlag

stateDir, err := sd.Compute()

require.NoError(t, err)
require.Equal(t, stateDirFlag, stateDir)
}

// Ensure that the TS_STATE_DIR environment variable is used for selecting the state directory.
func TestTSSTATEDIREnvVarIsUsedIfSet(t *testing.T) {
t.Parallel()

const stateDirEnv = "some path"

sd := initialState()
sd.getEnv = func(string) string { return stateDirEnv }

stateDir, err := sd.Compute()

require.NoError(t, err)
require.Equal(t, stateDirEnv, stateDir)
}

// Ensure that the tsnet-tsnsrv is used if it exists and the machine_name file contents match the -name argument.
func TestTsnetTsnsrvDirIsUsedIfExistsAndMachineNameMatches(t *testing.T) {
t.Parallel()

const userConfigDir = "/home/somedir/.config/"
const legacyTsnetConfigDir = "/home/somedir/.config/tsnet-tsnsrv"

sd := initialState()
sd.userConfigDir = func() (string, error) { return userConfigDir, nil }
sd.dirExists = func(dir string) (bool, error) { return true, nil }
sd.readFileString = func(file string) (string, error) { return sd.machineName, nil }

stateDir, err := sd.Compute()

require.NoError(t, err)
require.Equal(t, legacyTsnetConfigDir, stateDir)
}

// Ensure that the machine_name file is created in tsnet-tsnsrv if it doesn't exist.
func TestMachineNameFileIsCreatedIfNeeded(t *testing.T) {
t.Parallel()

const userConfigDir = "/home/somedir/.config/"
const legacyTsnetConfigDir = "/home/somedir/.config/tsnet-tsnsrv"
machineNameFile := path.Join(legacyTsnetConfigDir, "machine-name")
writeFileStringCalled := false

sd := initialState()
sd.userConfigDir = func() (string, error) { return userConfigDir, nil }
sd.dirExists = func(dir string) (bool, error) { return true, nil }
sd.readFileString = func(file string) (string, error) { return "", fs.ErrNotExist }
sd.writeFileString = func(file, contents string) error {
require.Equal(t, machineNameFile, file)
require.Equal(t, sd.machineName, contents)
writeFileStringCalled = true
return nil
}

stateDir, err := sd.Compute()

require.True(t, writeFileStringCalled)
require.NoError(t, err)
require.Equal(t, legacyTsnetConfigDir, stateDir)
}

// Ensure that tsnet-tsnsrv-<name> is used if a tsnet-tsnsrv directory doesn't exist
func TestTsnetTsnsrvNameIsUsedIfLegacyDirDoesntExist(t *testing.T) {
t.Parallel()

sd := initialState()
const userConfigDir = "/home/somedir/.config/"
newTsnetConfigDir := fmt.Sprintf("/home/somedir/.config/tsnet-tsnsrv-%s", sd.machineName)

sd.userConfigDir = func() (string, error) { return userConfigDir, nil }
sd.dirExists = func(dir string) (bool, error) { return false, nil }

stateDir, err := sd.Compute()

require.NoError(t, err)
require.Equal(t, newTsnetConfigDir, stateDir)
}

// Ensure that tsnet-tsnsrv-<name> is used if the machine_name doesn't match.
func TestTsnetTsnsrvNameIsUsedIfMachineNameDoesntMatch(t *testing.T) {
t.Parallel()

sd := initialState()
const userConfigDir = "/home/somedir/.config/"
newTsnetConfigDir := fmt.Sprintf("/home/somedir/.config/tsnet-tsnsrv-%s", sd.machineName)

sd.userConfigDir = func() (string, error) { return userConfigDir, nil }
sd.dirExists = func(dir string) (bool, error) { return true, nil }
sd.readFileString = func(file string) (string, error) { return "not-a-match", nil }

stateDir, err := sd.Compute()

require.NoError(t, err)
require.Equal(t, newTsnetConfigDir, stateDir)
}

0 comments on commit 86bfe5a

Please sign in to comment.