Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into becca/patchelf
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany committed Nov 22, 2023
2 parents 50b30be + e21f39f commit fcedb84
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 67 deletions.
5 changes: 4 additions & 1 deletion cmd/launcher/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ func runLauncher(ctx context.Context, cancel func(), slogger, systemSlogger *mul
// and pass it to the various systems.
var logShipper *logshipper.LogShipper
if k.ControlServerURL() != "" {
// Set log shipping level to debug for the first X minutes of
// run time. This will also increase the sending frequency.
k.SetLogShippingLevelOverride("debug", 10*time.Minute)

logShipper = logshipper.New(k, logger)
logger = teelogger.New(logger, logShipper)
logger = log.With(logger, "caller", log.Caller(5))
Expand Down Expand Up @@ -357,7 +361,6 @@ func runLauncher(ctx context.Context, cancel func(), slogger, systemSlogger *mul
if logShipper != nil {
runGroup.Add("logShipper", logShipper.Run, logShipper.Stop)
controlService.RegisterSubscriber(authTokensSubsystemName, logShipper)
controlService.RegisterSubscriber(agentFlagsSubsystemName, logShipper)
}

if metadataWriter := internal.NewMetadataWriter(logger, k); metadataWriter == nil {
Expand Down
70 changes: 65 additions & 5 deletions ee/desktop/runner/runner_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (r *DesktopUsersProcessesRunner) userEnvVars(ctx context.Context, uid strin

sessionType := strings.Trim(string(typeOutput), "\n")
if sessionType == "x11" {
envVars["DISPLAY"] = r.displayFromX11(ctx, session)
envVars["DISPLAY"] = r.displayFromX11(ctx, session, int32(uidInt))
break
} else if sessionType == "wayland" {
envVars["DISPLAY"] = r.displayFromXwayland(ctx, int32(uidInt))
Expand All @@ -164,33 +164,93 @@ func (r *DesktopUsersProcessesRunner) userEnvVars(ctx context.Context, uid strin
return envVars
}

func (r *DesktopUsersProcessesRunner) displayFromX11(ctx context.Context, session string) string {
func (r *DesktopUsersProcessesRunner) displayFromX11(ctx context.Context, session string, uid int32) string {
// We can read $DISPLAY from the session properties
cmd, err := allowedcmd.Loginctl(ctx, "show-session", session, "--value", "--property=Display")
if err != nil {
level.Debug(r.logger).Log(
"msg", "could not create command to get Display from user session",
"err", err,
)
return defaultDisplay
return r.displayFromXDisplayServerProcess(ctx, uid)
}
xDisplayOutput, err := cmd.Output()
if err != nil {
level.Debug(r.logger).Log(
"msg", "could not get Display from user session",
"err", err,
)
return defaultDisplay
return r.displayFromXDisplayServerProcess(ctx, uid)
}

display := strings.Trim(string(xDisplayOutput), "\n")
if display == "" {
return defaultDisplay
return r.displayFromXDisplayServerProcess(ctx, uid)
}

return display
}

func (r *DesktopUsersProcessesRunner) displayFromXDisplayServerProcess(ctx context.Context, uid int32) string {
processes, err := process.ProcessesWithContext(ctx)
if err != nil {
level.Debug(r.logger).Log(
"msg", "could not query processes to find display server process",
"err", err,
)
return defaultDisplay
}

for _, p := range processes {
cmdline, err := p.CmdlineWithContext(ctx)
if err != nil {
level.Debug(r.logger).Log(
"msg", "could not get cmdline slice for process",
"err", err,
)
continue
}

if !strings.Contains(cmdline, "Xorg") && !strings.Contains(cmdline, "Xvfb") {
continue
}

// We have an Xorg or Xvfb process -- check to make sure it's for our running user
uids, err := p.UidsWithContext(ctx)
if err != nil {
level.Debug(r.logger).Log(
"msg", "could not get uids for process",
"err", err,
)
continue
}
uidMatch := false
for _, procUid := range uids {
if procUid == uid {
uidMatch = true
break
}
}

if uidMatch {
// We have a match! Grab the display value.
// The Xorg process looks like:
// /usr/lib/xorg/Xorg :20 -auth /home/<user>/.Xauthority -nolisten tcp -noreset -logfile /dev/null -verbose 3 -config /tmp/chrome_remote_desktop_j5rldjlk.conf
// The Xvfb process looks like:
// Xvfb :20 -auth /home/<user>/.Xauthority -nolisten tcp -noreset -screen 0 3840x2560x24
cmdlineArgs := strings.Split(cmdline, " ")
if len(cmdlineArgs) < 2 {
// Process is somehow malformed or not what we're looking for -- continue so we can evaluate the following process
continue
}

return cmdlineArgs[1]
}
}

return defaultDisplay
}

func (r *DesktopUsersProcessesRunner) displayFromXwayland(ctx context.Context, uid int32) string {
//For wayland, DISPLAY is not included in loginctl show-session output -- in GNOME,
// Mutter spawns Xwayland and sets $DISPLAY at the same time. Find $DISPLAY by finding
Expand Down
93 changes: 57 additions & 36 deletions pkg/agent/flags/flag_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,22 @@ import (
// FlagController is responsible for retrieving flag values from the appropriate sources,
// determining precedence, sanitizing flag values, and notifying observers of changes.
type FlagController struct {
logger log.Logger
cmdLineOpts *launcher.Options
agentFlagsStore types.KVStore
overrideMutex sync.RWMutex
controlRequestOverride FlagValueOverride
observers map[types.FlagsChangeObserver][]keys.FlagKey
observersMutex sync.RWMutex
logger log.Logger
cmdLineOpts *launcher.Options
agentFlagsStore types.KVStore
overrideMutex sync.RWMutex
overrides map[keys.FlagKey]*Override
observers map[types.FlagsChangeObserver][]keys.FlagKey
observersMutex sync.RWMutex
}

func NewFlagController(logger log.Logger, agentFlagsStore types.KVStore, opts ...Option) *FlagController {
fc := &FlagController{
logger: logger,
logger: log.With(logger, "component", "flag_controller"),
cmdLineOpts: &launcher.Options{},
agentFlagsStore: agentFlagsStore,
observers: make(map[types.FlagsChangeObserver][]keys.FlagKey),
overrides: make(map[keys.FlagKey]*Override),
}

for _, opt := range opts {
Expand Down Expand Up @@ -114,6 +115,44 @@ func (fc *FlagController) notifyObservers(flagKeys ...keys.FlagKey) {
}
}

func (fc *FlagController) overrideFlag(key keys.FlagKey, duration time.Duration, value any) {
// Always notify observers when overrides start, so they know to refresh.
// Defering this before defering unlocking the mutex so that notifications occur outside of the critical section.
defer fc.notifyObservers(key)

fc.overrideMutex.Lock()
defer fc.overrideMutex.Unlock()

level.Info(fc.logger).Log(
"msg", "overriding flag",
"key", key,
"value", value,
"duration", duration,
)

override, ok := fc.overrides[key]
if !ok || override.Value() == nil {
// Creating the override implicitly causes future flag value retrievals to use the override until expiration
override = &Override{}
fc.overrides[key] = override
}

overrideExpired := func(key keys.FlagKey) {
// Always notify observers when overrides expire, so they know to refresh.
// Defering this before defering unlocking the mutex so that notifications occur outside of the critical section.
defer fc.notifyObservers(key)

fc.overrideMutex.Lock()
defer fc.overrideMutex.Unlock()

// Deleting the override implictly allows the next value to take precedence
delete(fc.overrides, key)
}

// Start a new override, or re-start an existing one with a new value, duration, and expiration
fc.overrides[key].Start(key, value, duration, overrideExpired)
}

func (fc *FlagController) AutoloadedExtensions() []string {
return fc.cmdLineOpts.AutoloadedExtensions
}
Expand Down Expand Up @@ -250,40 +289,15 @@ func (fc *FlagController) ControlServerURL() string {
func (fc *FlagController) SetControlRequestInterval(interval time.Duration) error {
return fc.setControlServerValue(keys.ControlRequestInterval, durationToBytes(interval))
}
func (fc *FlagController) SetControlRequestIntervalOverride(interval, duration time.Duration) {
// Always notify observers when overrides start, so they know to refresh.
// Defering this before defering unlocking the mutex so that notifications occur outside of the critical section.
defer fc.notifyObservers(keys.ControlRequestInterval)

fc.overrideMutex.Lock()
defer fc.overrideMutex.Unlock()

if fc.controlRequestOverride == nil || fc.controlRequestOverride.Value() == nil {
// Creating the override implicitly causes future ControlRequestInterval retrievals to use the override until expiration
fc.controlRequestOverride = &Override{}
}

overrideExpired := func(key keys.FlagKey) {
// Always notify observers when overrides expire, so they know to refresh.
// Defering this before defering unlocking the mutex so that notifications occur outside of the critical section.
defer fc.notifyObservers(key)

fc.overrideMutex.Lock()
defer fc.overrideMutex.Unlock()

// Deleting the override implictly allows the next value to take precedence
fc.controlRequestOverride = nil
}

// Start a new override, or re-start an existing one with a new value, duration, and expiration
fc.controlRequestOverride.Start(keys.ControlRequestInterval, interval, duration, overrideExpired)
func (fc *FlagController) SetControlRequestIntervalOverride(value time.Duration, duration time.Duration) {
fc.overrideFlag(keys.ControlRequestInterval, duration, value)
}
func (fc *FlagController) ControlRequestInterval() time.Duration {
fc.overrideMutex.RLock()
defer fc.overrideMutex.RUnlock()

return NewDurationFlagValue(fc.logger, keys.ControlRequestInterval,
WithOverride(fc.controlRequestOverride),
WithOverride(fc.overrides[keys.ControlRequestInterval]),
WithDefault(fc.cmdLineOpts.ControlRequestInterval),
WithMin(5*time.Second),
WithMax(10*time.Minute),
Expand Down Expand Up @@ -490,10 +504,17 @@ func (fc *FlagController) LogIngestServerURL() string {
func (fc *FlagController) SetLogShippingLevel(level string) error {
return fc.setControlServerValue(keys.LogShippingLevel, []byte(level))
}
func (fc *FlagController) SetLogShippingLevelOverride(value string, duration time.Duration) {
fc.overrideFlag(keys.LogShippingLevel, duration, value)
}
func (fc *FlagController) LogShippingLevel() string {
fc.overrideMutex.RLock()
defer fc.overrideMutex.RUnlock()

const defaultLevel = "info"

return NewStringFlagValue(
WithOverrideString(fc.overrides[keys.LogShippingLevel]),
WithDefaultString(defaultLevel),
WithSanitizer(func(value string) string {
value = strings.ToLower(value)
Expand Down
16 changes: 16 additions & 0 deletions pkg/agent/flags/flag_value_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ package flags

type stringFlagValueOption func(*stringFlagValue)

func WithOverrideString(override FlagValueOverride) stringFlagValueOption {
return func(s *stringFlagValue) {
s.override = override
}
}

func WithDefaultString(defaultVal string) stringFlagValueOption {
return func(s *stringFlagValue) {
s.defaultVal = defaultVal
Expand All @@ -17,6 +23,7 @@ func WithSanitizer(sanitizer func(value string) string) stringFlagValueOption {
type stringFlagValue struct {
defaultVal string
sanitizer func(value string) string
override FlagValueOverride
}

func NewStringFlagValue(opts ...stringFlagValueOption) *stringFlagValue {
Expand All @@ -35,9 +42,18 @@ func (s *stringFlagValue) get(controlServerValue []byte) string {
stringValue = string(controlServerValue)
}

if s.override != nil && s.override.Value() != nil {
// An override was provided, if it's valid let it take precedence
value, ok := s.override.Value().(string)
if ok {
stringValue = value
}
}

// Run the string through a sanitizer, if one was provided
if s.sanitizer != nil {
stringValue = s.sanitizer(stringValue)
}

return stringValue
}
11 changes: 9 additions & 2 deletions pkg/agent/flags/flag_value_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ package flags
import (
"testing"

"github.com/kolide/launcher/pkg/agent/flags/mocks"
"github.com/stretchr/testify/assert"
)

func TestFlagValueString(t *testing.T) {
t.Parallel()

// mockOverride := mocks.NewFlagValueOverride(t)
// mockOverride.On("Value").Return(7 * time.Second)
mockOverride := mocks.NewFlagValueOverride(t)
mockOverride.On("Value").Return("override_value")

tests := []struct {
name string
Expand Down Expand Up @@ -52,6 +53,12 @@ func TestFlagValueString(t *testing.T) {
controlServerValue: []byte("control-server-says-this"),
expected: "SANITIZED control-server-says-this",
},
{
name: "control server with override",
options: []stringFlagValueOption{WithDefaultString("default_value"), WithOverrideString(mockOverride)},
controlServerValue: []byte("enabled"),
expected: mockOverride.Value().(string),
},
}
for _, tt := range tests {
tt := tt
Expand Down
3 changes: 3 additions & 0 deletions pkg/agent/knapsack/knapsack.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,9 @@ func (k *knapsack) LogIngestServerURL() string {
func (k *knapsack) SetLogShippingLevel(level string) error {
return k.flags.SetLogShippingLevel(level)
}
func (k *knapsack) SetLogShippingLevelOverride(value string, duration time.Duration) {
k.flags.SetLogShippingLevelOverride(value, duration)
}
func (k *knapsack) LogShippingLevel() string {
return k.flags.LogShippingLevel()
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/agent/types/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ type Flags interface {
// ControlRequestInterval is the interval at which control client will check for updates from the control server.
SetControlRequestInterval(interval time.Duration) error
// SetControlRequestIntervalOverride stores an interval to be temporarily used as an override of any other interval, until the duration has elapased.
SetControlRequestIntervalOverride(interval, duration time.Duration)
SetControlRequestIntervalOverride(value time.Duration, duration time.Duration)
ControlRequestInterval() time.Duration

// DisableControlTLS disables TLS transport with the control server.
Expand Down Expand Up @@ -183,6 +183,7 @@ type Flags interface {

// LogShippingLevel is the level at which logs should be shipped to the server
SetLogShippingLevel(level string) error
SetLogShippingLevelOverride(value string, duration time.Duration)
LogShippingLevel() string

// TraceIngestServerURL is the URL of the ingest server for traces
Expand Down
11 changes: 8 additions & 3 deletions pkg/agent/types/mocks/flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit fcedb84

Please sign in to comment.