Skip to content

Commit

Permalink
Make trace batch timeout a control server flag
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany committed Oct 25, 2023
1 parent 1b3249d commit ef8e404
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 11 deletions.
11 changes: 11 additions & 0 deletions pkg/agent/flags/flag_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,17 @@ func (fc *FlagController) TraceSamplingRate() float64 {
).get(fc.getControlServerValue(keys.TraceSamplingRate))
}

// SetTraceBatchTimeout encodes duration with durationToBytes and hands it to
// setControlServerValue under keys.TraceBatchTimeout, returning that call's error.
func (fc *FlagController) SetTraceBatchTimeout(duration time.Duration) error {
	encoded := durationToBytes(duration)
	return fc.setControlServerValue(keys.TraceBatchTimeout, encoded)
}
// TraceBatchTimeout resolves the trace batch timeout from the control server
// value stored under keys.TraceBatchTimeout. The duration flag value is built
// with the command-line option as its default, a 5-second minimum, and a
// 1-hour maximum.
func (fc *FlagController) TraceBatchTimeout() time.Duration {
	flagValue := NewDurationFlagValue(
		fc.logger,
		keys.TraceBatchTimeout,
		WithDefault(fc.cmdLineOpts.TraceBatchTimeout),
		WithMin(5*time.Second),
		WithMax(1*time.Hour),
	)

	// Fetch the control server value only after the flag value has been built,
	// matching the original evaluation order.
	controlServerValue := fc.getControlServerValue(keys.TraceBatchTimeout)
	return flagValue.get(controlServerValue)
}

// SetLogIngestServerURL passes the raw bytes of url to setControlServerValue
// under keys.LogIngestServerURL and returns that call's error.
func (fc *FlagController) SetLogIngestServerURL(url string) error {
	rawURL := []byte(url)
	return fc.setControlServerValue(keys.LogIngestServerURL, rawURL)
}
Expand Down
1 change: 1 addition & 0 deletions pkg/agent/flags/keys/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const (
UpdateDirectory FlagKey = "update_directory"
ExportTraces FlagKey = "export_traces"
TraceSamplingRate FlagKey = "trace_sampling_rate"
TraceBatchTimeout FlagKey = "trace_batch_timeout"
LogIngestServerURL FlagKey = "log_ingest_url"
TraceIngestServerURL FlagKey = "trace_ingest_url"
DisableTraceIngestTLS FlagKey = "disable_trace_ingest_tls"
Expand Down
7 changes: 7 additions & 0 deletions pkg/agent/knapsack/knapsack.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,13 @@ func (k *knapsack) DisableTraceIngestTLS() bool {
return k.flags.DisableTraceIngestTLS()
}

// SetTraceBatchTimeout delegates to the knapsack's underlying flags and
// returns whatever error they report.
func (k *knapsack) SetTraceBatchTimeout(duration time.Duration) error {
	err := k.flags.SetTraceBatchTimeout(duration)
	return err
}
// TraceBatchTimeout returns the trace batch timeout reported by the
// knapsack's underlying flags.
func (k *knapsack) TraceBatchTimeout() time.Duration {
	timeout := k.flags.TraceBatchTimeout()
	return timeout
}

// SetLogIngestServerURL delegates to the knapsack's underlying flags and
// returns whatever error they report.
func (k *knapsack) SetLogIngestServerURL(url string) error {
	err := k.flags.SetLogIngestServerURL(url)
	return err
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/agent/types/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ type Flags interface {
SetDisableTraceIngestTLS(enabled bool) error
DisableTraceIngestTLS() bool

// TraceBatchTimeout is the maximum amount of time before the trace exporter will export the next batch of spans
SetTraceBatchTimeout(duration time.Duration) error
TraceBatchTimeout() time.Duration

// InModernStandby indicates whether a Windows machine is awake or in modern standby
SetInModernStandby(enabled bool) error
InModernStandby() bool
Expand Down
28 changes: 28 additions & 0 deletions pkg/agent/types/mocks/flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions pkg/agent/types/mocks/knapsack.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions pkg/launcher/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ type Options struct {
ExportTraces bool
// TraceSamplingRate is a number between 0.0 and 1.0 that indicates what fraction of traces should be sampled.
TraceSamplingRate float64
// TraceBatchTimeout is the maximum amount of time before the trace exporter will export the next batch of spans
TraceBatchTimeout time.Duration
// LogIngestServerURL is the URL that logs and other observability data will be exported to
LogIngestServerURL string
// TraceIngestServerURL is the URL that traces will be exported to
Expand Down Expand Up @@ -221,6 +223,7 @@ func ParseOptions(subcommandName string, args []string) (*Options, error) {
flConfigFilePath = flagset.String("config", DefaultConfigFilePath, "config file to parse options from (optional)")
flExportTraces = flagset.Bool("export_traces", false, "Whether to export traces")
flTraceSamplingRate = flagset.Float64("trace_sampling_rate", 0.0, "What fraction of traces should be sampled")
flTraceBatchTimeout = flagset.Duration("trace_batch_timeout", 1*time.Minute, "Maximum amount of time before the trace exporter will export the next batch of spans")
flLogIngestServerURL = flagset.String("log_ingest_url", "", "Where to export logs")
flTraceIngestServerURL = flagset.String("trace_ingest_url", "", "Where to export traces")
flDisableIngestTLS = flagset.Bool("disable_trace_ingest_tls", false, "Disable TLS for observability ingest server communication")
Expand Down Expand Up @@ -407,6 +410,7 @@ func ParseOptions(subcommandName string, args []string) (*Options, error) {
RootDirectory: *flRootDirectory,
RootPEM: *flRootPEM,
TraceSamplingRate: *flTraceSamplingRate,
TraceBatchTimeout: *flTraceBatchTimeout,
Transport: *flTransport,
UpdateChannel: updateChannel,
UpdateDirectory: *flUpdateDirectory,
Expand Down
1 change: 1 addition & 0 deletions pkg/launcher/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ func getArgsAndResponse() (map[string]string, *Options) {
ControlRequestInterval: 60 * time.Second,
ExportTraces: false,
TraceSamplingRate: 0.0,
TraceBatchTimeout: 1 * time.Minute,
LogIngestServerURL: "",
DisableTraceIngestTLS: false,
KolideServerURL: randomHostname,
Expand Down
24 changes: 16 additions & 8 deletions pkg/traces/exporter/exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ import (
"google.golang.org/grpc"
)

const (
applicationName = "launcher"
batchTimeout = 1 * time.Minute // ensure traces are exported at least once per minute -- otel default is 5 sec
)
const applicationName = "launcher"

var archAttributeMap = map[string]attribute.KeyValue{
"amd64": semconv.HostArchAMD64,
Expand Down Expand Up @@ -58,6 +55,7 @@ type TraceExporter struct {
disableIngestTLS bool
enabled bool
traceSamplingRate float64
batchTimeout time.Duration
ctx context.Context
cancel context.CancelFunc
interrupted bool
Expand Down Expand Up @@ -93,13 +91,14 @@ func NewTraceExporter(ctx context.Context, k types.Knapsack, client osquery.Quer
disableIngestTLS: k.DisableTraceIngestTLS(),
enabled: k.ExportTraces(),
traceSamplingRate: k.TraceSamplingRate(),
batchTimeout: k.TraceBatchTimeout(),
ctx: ctx,
cancel: cancel,
}

// Observe ExportTraces and IngestServerURL changes to know when to start/stop exporting, and where
// to export to
t.knapsack.RegisterChangeObserver(t, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS)
// Observe changes to trace configuration to know when to start/stop exporting, and when
// to adjust exporting behavior
t.knapsack.RegisterChangeObserver(t, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS, keys.TraceBatchTimeout)

if !t.enabled {
return t, nil
Expand Down Expand Up @@ -254,7 +253,7 @@ func (t *TraceExporter) setNewGlobalProvider() {
parentBasedSampler := sdktrace.ParentBased(sdktrace.TraceIDRatioBased(t.traceSamplingRate))

newProvider := sdktrace.NewTracerProvider(
sdktrace.WithBatcher(exp, sdktrace.WithBatchTimeout(batchTimeout)),
sdktrace.WithBatcher(exp, sdktrace.WithBatchTimeout(t.batchTimeout)),
sdktrace.WithResource(r),
sdktrace.WithSampler(parentBasedSampler),
)
Expand Down Expand Up @@ -359,6 +358,15 @@ func (t *TraceExporter) FlagsChanged(flagKeys ...keys.FlagKey) {
}
}

// Handle trace_batch_timeout updates
if slices.Contains(flagKeys, keys.TraceBatchTimeout) {
if t.batchTimeout != t.knapsack.TraceBatchTimeout() {
t.batchTimeout = t.knapsack.TraceBatchTimeout()
needsNewProvider = true
level.Debug(t.logger).Log("msg", "updating trace batch timeout", "new_batch_timeout", t.batchTimeout)
}
}

if !t.enabled || !needsNewProvider {
return
}
Expand Down
78 changes: 75 additions & 3 deletions pkg/traces/exporter/exporter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ func TestNewTraceExporter(t *testing.T) { //nolint:paralleltest
mockKnapsack.On("DisableTraceIngestTLS").Return(false)
mockKnapsack.On("ExportTraces").Return(true)
mockKnapsack.On("TraceSamplingRate").Return(1.0)
mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS).Return(nil)
mockKnapsack.On("TraceBatchTimeout").Return(1 * time.Minute)
mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS, keys.TraceBatchTimeout).Return(nil)

osqueryClient := mocks.NewQuerier(t)
osqueryClient.On("Query", mock.Anything).Return([]map[string]string{
Expand Down Expand Up @@ -85,7 +86,8 @@ func TestNewTraceExporter_exportNotEnabled(t *testing.T) {
mockKnapsack.On("DisableTraceIngestTLS").Return(false)
mockKnapsack.On("ExportTraces").Return(false)
mockKnapsack.On("TraceSamplingRate").Return(0.0)
mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS).Return(nil)
mockKnapsack.On("TraceBatchTimeout").Return(1 * time.Minute)
mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS, keys.TraceBatchTimeout).Return(nil)

traceExporter, err := NewTraceExporter(context.Background(), mockKnapsack, mocks.NewQuerier(t), log.NewNopLogger())
require.NoError(t, err)
Expand Down Expand Up @@ -122,7 +124,8 @@ func TestInterrupt_Multiple(t *testing.T) {
mockKnapsack.On("DisableTraceIngestTLS").Return(false)
mockKnapsack.On("ExportTraces").Return(false)
mockKnapsack.On("TraceSamplingRate").Return(0.0)
mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS).Return(nil)
mockKnapsack.On("TraceBatchTimeout").Return(1 * time.Minute)
mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS, keys.TraceBatchTimeout).Return(nil)

traceExporter, err := NewTraceExporter(context.Background(), mockKnapsack, mocks.NewQuerier(t), log.NewNopLogger())
require.NoError(t, err)
Expand Down Expand Up @@ -603,6 +606,75 @@ func TestFlagsChanged_DisableTraceIngestTLS(t *testing.T) { //nolint:paralleltes
}
}

// TestFlagsChanged_TraceBatchTimeout checks how FlagsChanged reacts to
// keys.TraceBatchTimeout: the exporter's batchTimeout must end up equal to the
// value the knapsack reports, and the exporter's provider must be non-nil only
// in the cases marked as replacing the provider.
func TestFlagsChanged_TraceBatchTimeout(t *testing.T) { //nolint:paralleltest
	const (
		initialTimeout = time.Minute
		shorterTimeout = 5 * time.Second
	)

	type testCase struct {
		name            string
		startTimeout    time.Duration
		knapsackTimeout time.Duration
		enabled         bool
		wantNewProvider bool
	}

	testCases := []testCase{
		{
			name:            "update",
			startTimeout:    initialTimeout,
			knapsackTimeout: shorterTimeout,
			enabled:         true,
			wantNewProvider: true,
		},
		{
			name:            "update but tracing not enabled",
			startTimeout:    initialTimeout,
			knapsackTimeout: shorterTimeout,
			enabled:         false,
			wantNewProvider: false,
		},
		{
			name:            "no update",
			startTimeout:    initialTimeout,
			knapsackTimeout: initialTimeout,
			enabled:         true,
			wantNewProvider: false,
		},
	}

	for _, tc := range testCases { //nolint:paralleltest
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			// The knapsack mock always reports the "new" timeout for this case.
			k := typesmocks.NewKnapsack(t)
			k.On("TraceBatchTimeout").Return(tc.knapsackTimeout)
			querier := mocks.NewQuerier(t)

			exporterCtx, exporterCancel := context.WithCancel(context.Background())
			exporter := &TraceExporter{
				knapsack:                  k,
				osqueryClient:             querier,
				logger:                    log.NewNopLogger(),
				attrs:                     make([]attribute.KeyValue, 0),
				attrLock:                  sync.RWMutex{},
				ingestClientAuthenticator: newClientAuthenticator("test token", false),
				ingestAuthToken:           "test token",
				ingestUrl:                 "localhost:4317",
				disableIngestTLS:          false,
				enabled:                   tc.enabled,
				traceSamplingRate:         1.0,
				batchTimeout:              tc.startTimeout,
				ctx:                       exporterCtx,
				cancel:                    exporterCancel,
			}

			exporter.FlagsChanged(keys.TraceBatchTimeout)

			require.Equal(t, tc.knapsackTimeout, exporter.batchTimeout, "batch timeout value not updated")

			if !tc.wantNewProvider {
				require.Nil(t, exporter.provider)
				return
			}
			require.NotNil(t, exporter.provider)
		})
	}
}

func testServerProvidedDataStore(t *testing.T) types.KVStore {
s, err := storageci.NewStore(t, log.NewNopLogger(), storage.ServerProvidedDataStore.String())
require.NoError(t, err)
Expand Down

0 comments on commit ef8e404

Please sign in to comment.