From 80ca0950128916f7d9654bee462bb7fa30495f05 Mon Sep 17 00:00:00 2001 From: Christophe Tafani-Dereeper Date: Thu, 1 Aug 2024 17:19:39 +0200 Subject: [PATCH] Allow specifying context when instantiating Stratus Red Team runners and propagate them down to Terraform (#546) --- v2/pkg/stratus/runner/runner.go | 17 +++++++++++++++-- v2/pkg/stratus/runner/terraform.go | 16 +++++++++++----- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/v2/pkg/stratus/runner/runner.go b/v2/pkg/stratus/runner/runner.go index ecaaa2b28..cde7191e2 100644 --- a/v2/pkg/stratus/runner/runner.go +++ b/v2/pkg/stratus/runner/runner.go @@ -1,6 +1,7 @@ package runner import ( + "context" "errors" "github.com/datadog/stratus-red-team/v2/internal/providers" "github.com/datadog/stratus-red-team/v2/internal/state" @@ -23,9 +24,14 @@ type Runner struct { StateManager state.StateManager ProviderFactory stratus.CloudProviders UniqueCorrelationID uuid.UUID + Context context.Context } func NewRunner(technique *stratus.AttackTechnique, force bool) Runner { + return NewRunnerWithContext(context.Background(), technique, force) +} + +func NewRunnerWithContext(ctx context.Context, technique *stratus.AttackTechnique, force bool) Runner { stateManager := state.NewFileSystemStateManager(technique) uuid := uuid.New() runner := Runner{ @@ -33,9 +39,10 @@ func NewRunner(technique *stratus.AttackTechnique, force bool) Runner { ShouldForce: force, StateManager: stateManager, UniqueCorrelationID: uuid, - TerraformManager: NewTerraformManager( - filepath.Join(stateManager.GetRootDirectory(), "terraform"), providers.GetStratusUserAgentForUUID(uuid), + TerraformManager: NewTerraformManagerWithContext( + ctx, filepath.Join(stateManager.GetRootDirectory(), "terraform"), providers.GetStratusUserAgentForUUID(uuid), ), + Context: ctx, } runner.initialize() @@ -86,6 +93,9 @@ func (m *Runner) WarmUp() (map[string]string, error) { if err != nil { log.Println("Error during warm up. Cleaning up technique prerequisites with terraform destroy") _ = m.TerraformManager.TerraformDestroy(m.TerraformDir) + if errors.Is(err, context.Canceled) { + return nil, err + } return nil, errors.New("unable to run terraform apply on prerequisite: " + errorMessageFromTerraformError(err)) } @@ -186,6 +196,9 @@ func (m *Runner) CleanUp() error { log.Println("Cleaning up technique prerequisites with terraform destroy") err := m.TerraformManager.TerraformDestroy(m.TerraformDir) if err != nil { + if errors.Is(err, context.Canceled) { + return err + } return errors.New("unable to cleanup TTP prerequisites: " + errorMessageFromTerraformError(err)) } } diff --git a/v2/pkg/stratus/runner/terraform.go b/v2/pkg/stratus/runner/terraform.go index e3935a1d9..6bb61cbcf 100644 --- a/v2/pkg/stratus/runner/terraform.go +++ b/v2/pkg/stratus/runner/terraform.go @@ -26,13 +26,19 @@ type TerraformManagerImpl struct { terraformBinaryPath string terraformVersion string terraformUserAgent string + context context.Context } func NewTerraformManager(terraformBinaryPath string, userAgent string) TerraformManager { + return NewTerraformManagerWithContext(context.Background(), terraformBinaryPath, userAgent) +} + +func NewTerraformManagerWithContext(ctx context.Context, terraformBinaryPath string, userAgent string) TerraformManager { manager := TerraformManagerImpl{ terraformVersion: TerraformVersion, terraformBinaryPath: terraformBinaryPath, terraformUserAgent: userAgent, + context: ctx, } manager.Initialize() return &manager @@ -48,7 +54,7 @@ func (m *TerraformManagerImpl) Initialize() { SkipChecksumVerification: false, } log.Println("Installing Terraform in " + m.terraformBinaryPath) - _, err := terraformInstaller.Install(context.Background()) + _, err := terraformInstaller.Install(m.context) if err != nil { log.Fatalf("error installing Terraform: %s", err) } @@ -69,7 +75,7 @@ func (m *TerraformManagerImpl) TerraformInitAndApply(directory string) (map[stri terraformInitializedFile := path.Join(directory, ".terraform-initialized") if !utils.FileExists(terraformInitializedFile) { log.Println("Initializing Terraform to spin up technique prerequisites") - err = terraform.Init(context.Background()) + err = terraform.Init(m.context) if err != nil { return nil, errors.New("unable to Initialize Terraform: " + err.Error()) } @@ -82,12 +88,12 @@ func (m *TerraformManagerImpl) TerraformInitAndApply(directory string) (map[stri } log.Println("Applying Terraform to spin up technique prerequisites") - err = terraform.Apply(context.Background(), tfexec.Refresh(false)) + err = terraform.Apply(m.context, tfexec.Refresh(false)) if err != nil { return nil, errors.New("unable to apply Terraform: " + err.Error()) } - rawOutputs, _ := terraform.Output(context.Background()) + rawOutputs, _ := terraform.Output(m.context) outputs := make(map[string]string, len(rawOutputs)) for outputName, outputRawValue := range rawOutputs { outputValue := string(outputRawValue.Value) @@ -104,5 +110,5 @@ func (m *TerraformManagerImpl) TerraformDestroy(directory string) error { return err } - return terraform.Destroy(context.Background()) + return terraform.Destroy(m.context) }