diff --git a/.gitignore b/.gitignore index 5749533e..b9f2c499 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # No exe or temp dir /temporal-features* /cloned-repo-* +/program-* # Build Java stuff build diff --git a/cmd/run.go b/cmd/run.go index e98ced13..79b6b380 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -10,6 +10,7 @@ import ( "io" "net" "os" + "os/exec" "os/signal" "path/filepath" "runtime" @@ -30,8 +31,9 @@ import ( ) const ( - summaryListenAddr = "127.0.0.1:0" - FeaturePassed = "PASSED" + proxyExecutableAuto = "auto" + freePortListenAddr = "127.0.0.1:0" + FeaturePassed = "PASSED" ) func runCmd() *cli.Command { @@ -57,14 +59,45 @@ type Summary []SummaryEntry // RunConfig is configuration for NewRunner. type RunConfig struct { PrepareConfig - Server string - Namespace string - ClientCertPath string - ClientKeyPath string - GenerateHistory bool - DisableHistoryCheck bool - RetainTempDir bool - SummaryURI string + Server string + DirectServer string + Namespace string + ClientCertPath string + ClientKeyPath string + GenerateHistory bool + DisableHistoryCheck bool + RetainTempDir bool + SummaryURI string + ProxyExecutablePath string + ProxyControlHostPort string + ProxyListenHostPort string +} + +func (config RunConfig) appendFlags(out []string) ([]string, error) { + out = append(out, "--server="+config.Server) + out = append(out, "--direct-server="+config.DirectServer) + out = append(out, "--namespace="+config.Namespace) + if config.ClientCertPath != "" { + clientCertPath, err := filepath.Abs(config.ClientCertPath) + if err != nil { + return nil, err + } + out = append(out, "--client-cert-path="+clientCertPath) + } + if config.ClientKeyPath != "" { + clientKeyPath, err := filepath.Abs(config.ClientKeyPath) + if err != nil { + return nil, err + } + out = append(out, "--client-key-path="+clientKeyPath) + } + if config.SummaryURI != "" { + out = append(out, "--summary-uri="+config.SummaryURI) + } + if config.ProxyControlHostPort != "" { + out = append(out, 
"--proxy-control-uri=http://"+config.ProxyControlHostPort+"/") + } + return out, nil } // dockerRunFlags are a subset of flags that apply when running in a docker container @@ -122,6 +155,22 @@ func (r *RunConfig) flags() []cli.Flag { Usage: "Relative directory already prepared. Cannot include version with this.", Destination: &r.DirName, }, + &cli.StringFlag{ + Name: "proxy-executable-path", + Usage: "Path of the temporal-features-test-proxy executable for connectivity/retry tests (optional)", + Value: proxyExecutableAuto, + Destination: &r.ProxyExecutablePath, + }, + &cli.StringFlag{ + Name: "proxy-control-hostport", + Usage: "explicit host:port for controlling the temporal-features-test-proxy (optional)", + Destination: &r.ProxyControlHostPort, + }, + &cli.StringFlag{ + Name: "proxy-listen-hostport", + Usage: "explicit host:port for using the temporal-features-test-proxy (optional)", + Destination: &r.ProxyListenHostPort, + }, }, r.dockerRunFlags()...) } @@ -133,6 +182,7 @@ type Runner struct { rootDir string createTime time.Time program sdkbuild.Program + proxy *exec.Cmd } // NewRunner creates a new runner for the given config. 
@@ -154,6 +204,22 @@ func (r *Runner) Run(ctx context.Context, patterns []string) error { return err } + var fn func(context.Context, *cmd.Run) error + switch r.config.Lang { + case "go": + fn = r.runGo + case "java": + fn = r.runJava + case "ts": + fn = r.runTypeScript + case "py": + fn = r.runPython + case "cs": + fn = r.runDotNet + default: + return fmt.Errorf("unrecognized language") + } + // Cannot generate history if a version isn't provided explicitly if r.config.GenerateHistory && r.config.Version == "" { return fmt.Errorf("must have explicit version to generate history") @@ -221,6 +287,49 @@ func (r *Runner) Run(ctx context.Context, patterns []string) error { return err } } + r.config.DirectServer = r.config.Server + + if r.config.ProxyExecutablePath == proxyExecutableAuto { + const suggestedPath = "./temporal-features-test-proxy" + fi, err := os.Stat(suggestedPath) + if err == nil && fi.Mode().IsRegular() && (fi.Mode()&0o111) != 0 { + r.config.ProxyExecutablePath = suggestedPath + } + } + if r.config.ProxyExecutablePath == proxyExecutableAuto { + const suggestedPath = "./temporal-features-test-proxy.exe" + fi, err := os.Stat(suggestedPath) + if err == nil && fi.Mode().IsRegular() { + r.config.ProxyExecutablePath = suggestedPath + } + } + if r.config.ProxyExecutablePath == proxyExecutableAuto { + r.config.ProxyExecutablePath = "" + } + if r.config.ProxyExecutablePath != "" { + if r.config.ProxyControlHostPort == "" { + r.config.ProxyControlHostPort, err = pickFreePort() + if err != nil { + return err + } + } + + if r.config.ProxyListenHostPort == "" { + r.config.ProxyListenHostPort, err = pickFreePort() + if err != nil { + return err + } + } + + err = r.startProxy(ctx) + if err != nil { + return err + } + r.log.Info("Started proxy", "Path", r.proxy.Path, "Args", r.proxy.Args) + r.config.Server = r.config.ProxyListenHostPort + } + + defer func() { _ = r.stopProxy() }() // Ensure any created temp dir is cleaned on ctrl-c or normal exit if r.config.DirName == 
"" && !r.config.RetainTempDir { @@ -234,70 +343,20 @@ func (r *Runner) Run(ctx context.Context, patterns []string) error { defer r.destroyTempDir() } - l, err := net.Listen("tcp", summaryListenAddr) + summaryListener, err := net.Listen("tcp", freePortListenAddr) if err != nil { return err } - defer l.Close() + defer summaryListener.Close() summaryChan := make(chan Summary) - go r.summaryServer(l, summaryChan) - r.config.SummaryURI = "tcp://" + l.Addr().String() + go r.summaryServer(summaryListener, summaryChan) + r.config.SummaryURI = "tcp://" + summaryListener.Addr().String() - err = nil - switch r.config.Lang { - case "go": - // If there's a version or prepared dir we run external, otherwise we run local - if r.config.Version != "" || r.config.DirName != "" { - if r.config.DirName != "" { - r.program, err = sdkbuild.GoProgramFromDir(filepath.Join(r.rootDir, r.config.DirName)) - } - if err == nil { - err = r.RunGoExternal(ctx, run) - } - } else { - err = cmd.NewRunner(cmd.RunConfig{ - Server: r.config.Server, - Namespace: r.config.Namespace, - ClientCertPath: r.config.ClientCertPath, - ClientKeyPath: r.config.ClientKeyPath, - SummaryURI: r.config.SummaryURI, - }).Run(ctx, run) - } - case "java": - if r.config.DirName != "" { - r.program, err = sdkbuild.JavaProgramFromDir(filepath.Join(r.rootDir, r.config.DirName)) - } - if err == nil { - err = r.RunJavaExternal(ctx, run) - } - case "ts": - if r.config.DirName != "" { - r.program, err = sdkbuild.TypeScriptProgramFromDir(filepath.Join(r.rootDir, r.config.DirName)) - } - if err == nil { - err = r.RunTypeScriptExternal(ctx, run) - } - case "py": - if r.config.DirName != "" { - r.program, err = sdkbuild.PythonProgramFromDir(filepath.Join(r.rootDir, r.config.DirName)) - } - if err == nil { - err = r.RunPythonExternal(ctx, run) - } - case "cs": - if r.config.DirName != "" { - r.program, err = sdkbuild.DotNetProgramFromDir(filepath.Join(r.rootDir, r.config.DirName)) - } - if err == nil { - err = r.RunDotNetExternal(ctx, 
run) - } - default: - err = fmt.Errorf("unrecognized language") - } + err = fn(ctx, run) if err != nil { return err } - l.Close() + summaryListener.Close() summary, ok := <-summaryChan if !ok { r.log.Debug("did not receive a test run summary - adopting legacy behavior of assuming no tests were skipped") @@ -305,7 +364,86 @@ func (r *Runner) Run(ctx context.Context, patterns []string) error { summary = append(summary, SummaryEntry{Name: feature.Dir, Outcome: FeaturePassed}) } } - return r.handleHistory(ctx, run, summary) + + err = r.handleHistory(ctx, run, summary) + if err != nil { + return err + } + + if r.proxy != nil { + err = r.stopProxy() + if err != nil { + return err + } + } + + return nil +} + +func (r *Runner) runGo(ctx context.Context, run *cmd.Run) error { + // If there's a version or prepared dir we run external, otherwise we run local + if r.config.Version == "" && r.config.DirName == "" { + return cmd.NewRunner(cmd.RunConfig{ + Server: r.config.Server, + Namespace: r.config.Namespace, + ClientCertPath: r.config.ClientCertPath, + ClientKeyPath: r.config.ClientKeyPath, + SummaryURI: r.config.SummaryURI, + }).Run(ctx, run) + } + + if r.config.DirName != "" { + var err error + r.program, err = sdkbuild.GoProgramFromDir(filepath.Join(r.rootDir, r.config.DirName)) + if err != nil { + return err + } + } + return r.RunGoExternal(ctx, run) +} + +func (r *Runner) runJava(ctx context.Context, run *cmd.Run) error { + if r.config.DirName != "" { + var err error + r.program, err = sdkbuild.JavaProgramFromDir(filepath.Join(r.rootDir, r.config.DirName)) + if err != nil { + return err + } + } + return r.RunJavaExternal(ctx, run) +} + +func (r *Runner) runTypeScript(ctx context.Context, run *cmd.Run) error { + if r.config.DirName != "" { + var err error + r.program, err = sdkbuild.TypeScriptProgramFromDir(filepath.Join(r.rootDir, r.config.DirName)) + if err != nil { + return err + } + } + return r.RunTypeScriptExternal(ctx, run) +} + +func (r *Runner) runPython(ctx 
context.Context, run *cmd.Run) error { + if r.config.DirName != "" { + var err error + r.program, err = sdkbuild.PythonProgramFromDir(filepath.Join(r.rootDir, r.config.DirName)) + if err != nil { + return err + } + } + return r.RunPythonExternal(ctx, run) +} + +func (r *Runner) runDotNet(ctx context.Context, run *cmd.Run) error { + if r.config.DirName != "" { + var err error + r.program, err = sdkbuild.DotNetProgramFromDir(filepath.Join(r.rootDir, r.config.DirName)) + if err != nil { + return err + } + } + return r.RunDotNetExternal(ctx, run) } func (r *Runner) handleHistory(ctx context.Context, run *cmd.Run, summary Summary) error { @@ -537,3 +675,58 @@ func (s Summary) Find(featureName string) (*SummaryEntry, bool) { } return nil, false } + +func (r *Runner) startProxy(ctx context.Context) error { + execPath, err := exec.LookPath(r.config.ProxyExecutablePath) + if err != nil { + return err + } + + r.proxy = exec.CommandContext( + ctx, + execPath, + "-control", r.config.ProxyControlHostPort, + "-listen", r.config.ProxyListenHostPort, + "-dial", r.config.Server, + ) + if err != nil { + return err + } + + r.proxy.Stderr = os.Stderr + err = r.proxy.Start() + if err != nil { + return err + } + return nil +} + +func (r *Runner) stopProxy() error { + if r.proxy == nil { + return nil + } + + if err := r.proxy.Process.Signal(os.Interrupt); err != nil { + return fmt.Errorf("failed to interrupt proxy subprocess: %w", err) + } + + if err := r.proxy.Wait(); err != nil { + return fmt.Errorf("proxy subprocess failed: %w", err) + } + + return nil +} + +func pickFreePort() (string, error) { + addr, err := net.ResolveTCPAddr("tcp", freePortListenAddr) + if err != nil { + return "", err + } + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return "", err + } + hostPort := l.Addr().String() + _ = l.Close() + return hostPort, nil +} diff --git a/cmd/run_dotnet.go b/cmd/run_dotnet.go index 05df1e40..c5176ad7 100644 --- a/cmd/run_dotnet.go +++ b/cmd/run_dotnet.go @@ -64,11 
+64,15 @@ func (r *Runner) RunDotNetExternal(ctx context.Context, run *cmd.Run) error { } } - args := []string{"--server", r.config.Server, "--namespace", r.config.Namespace} - if r.config.ClientCertPath != "" { - args = append(args, "--client-cert-path", r.config.ClientCertPath, "--client-key-path", r.config.ClientKeyPath) + // Build args + args := make([]string, 0, 64) + args, err := r.config.appendFlags(args) + if err != nil { + return err } args = append(args, run.ToArgs()...) + + // Run cmd, err := r.program.NewCommand(ctx, args...) if err == nil { r.log.Debug("Running Go separately", "Args", cmd.Args) diff --git a/cmd/run_go.go b/cmd/run_go.go index d45f3264..03aaa857 100644 --- a/cmd/run_go.go +++ b/cmd/run_go.go @@ -57,14 +57,16 @@ func (r *Runner) RunGoExternal(ctx context.Context, run *cmd.Run) error { } } - args := append([]string{ - "run", - "--server", r.config.Server, - "--namespace", r.config.Namespace, - "--client-cert-path", r.config.ClientCertPath, - "--client-key-path", r.config.ClientKeyPath, - "--summary-uri", r.config.SummaryURI, - }, run.ToArgs()...) + // Build args + args := make([]string, 0, 64) + args = append(args, "run") + args, err := r.config.appendFlags(args) + if err != nil { + return err + } + args = append(args, run.ToArgs()...) + + // Run cmd, err := r.program.NewCommand(ctx, args...) 
if err == nil { r.log.Debug("Running Go separately", "Args", cmd.Args) diff --git a/cmd/run_java.go b/cmd/run_java.go index a535e966..b6eeea1b 100644 --- a/cmd/run_java.go +++ b/cmd/run_java.go @@ -3,7 +3,6 @@ package cmd import ( "context" "fmt" - "path/filepath" "github.com/temporalio/features/harness/go/cmd" "github.com/temporalio/features/sdkbuild" @@ -40,23 +39,10 @@ func (r *Runner) RunJavaExternal(ctx context.Context, run *cmd.Run) error { } // Build args - args := []string{"--server", r.config.Server, "--namespace", r.config.Namespace} - if r.config.ClientCertPath != "" { - clientCertPath, err := filepath.Abs(r.config.ClientCertPath) - if err != nil { - return err - } - args = append(args, "--client-cert-path", clientCertPath) - } - if r.config.ClientKeyPath != "" { - clientKeyPath, err := filepath.Abs(r.config.ClientKeyPath) - if err != nil { - return err - } - args = append(args, "--client-key-path", clientKeyPath) - } - if r.config.SummaryURI != "" { - args = append(args, "--summary-uri", r.config.SummaryURI) + args := make([]string, 0, 64) + args, err := r.config.appendFlags(args) + if err != nil { + return err } args = append(args, run.ToArgs()...) 
diff --git a/cmd/run_python.go b/cmd/run_python.go index 6dc1849f..81707c0a 100644 --- a/cmd/run_python.go +++ b/cmd/run_python.go @@ -60,20 +60,11 @@ func (r *Runner) RunPythonExternal(ctx context.Context, run *cmd.Run) error { } // Build args - args := []string{"harness.python.main", "--server", r.config.Server, "--namespace", r.config.Namespace} - if r.config.ClientCertPath != "" { - clientCertPath, err := filepath.Abs(r.config.ClientCertPath) - if err != nil { - return err - } - args = append(args, "--client-cert-path", clientCertPath) - } - if r.config.ClientKeyPath != "" { - clientKeyPath, err := filepath.Abs(r.config.ClientKeyPath) - if err != nil { - return err - } - args = append(args, "--client-key-path", clientKeyPath) + args := make([]string, 0, 64) + args = append(args, "harness.python.main") + args, err := r.config.appendFlags(args) + if err != nil { + return err } args = append(args, run.ToArgs()...) diff --git a/cmd/run_typescript.go b/cmd/run_typescript.go index 62446908..28d5f47b 100644 --- a/cmd/run_typescript.go +++ b/cmd/run_typescript.go @@ -58,26 +58,11 @@ func (r *Runner) RunTypeScriptExternal(ctx context.Context, run *cmd.Run) error } // Build args - args := []string{ - "./tslib/harness/ts/main.js", - "--server", - r.config.Server, - "--namespace", - r.config.Namespace, - } - if r.config.ClientCertPath != "" { - clientCertPath, err := filepath.Abs(r.config.ClientCertPath) - if err != nil { - return err - } - args = append(args, "--client-cert-path", clientCertPath) - } - if r.config.ClientKeyPath != "" { - clientKeyPath, err := filepath.Abs(r.config.ClientKeyPath) - if err != nil { - return err - } - args = append(args, "--client-key-path", clientKeyPath) + args := make([]string, 0, 64) + args = append(args, "./tslib/harness/ts/main.js") + args, err := r.config.appendFlags(args) + if err != nil { + return err } args = append(args, run.ToArgs()...) 
diff --git a/features/data_converter/binary_protobuf/feature.py b/features/data_converter/binary_protobuf/feature.py index c6397b77..25d73685 100644 --- a/features/data_converter/binary_protobuf/feature.py +++ b/features/data_converter/binary_protobuf/feature.py @@ -19,6 +19,7 @@ EXPECTED_RESULT = DataBlob(data=bytes.fromhex("deadbeef")) + # An echo workflow @workflow.defn class Workflow: diff --git a/features/data_converter/codec/feature.py b/features/data_converter/codec/feature.py index 0e8d4678..c17c86c6 100644 --- a/features/data_converter/codec/feature.py +++ b/features/data_converter/codec/feature.py @@ -21,6 +21,7 @@ CODEC_ENCODING = "my_encoding" + # An echo workflow @workflow.defn class Workflow: diff --git a/features/data_converter/json/feature.py b/features/data_converter/json/feature.py index 002cbd33..6b734e2e 100644 --- a/features/data_converter/json/feature.py +++ b/features/data_converter/json/feature.py @@ -15,6 +15,7 @@ EXPECTED_RESULT: Result = {"spec": True} + # An echo workflow @workflow.defn class Workflow: diff --git a/features/data_converter/json_protobuf/feature.py b/features/data_converter/json_protobuf/feature.py index cbfb060a..d2d28e85 100644 --- a/features/data_converter/json_protobuf/feature.py +++ b/features/data_converter/json_protobuf/feature.py @@ -13,6 +13,7 @@ EXPECTED_RESULT = DataBlob(data=bytes.fromhex("deadbeef")) JSONP_decoder = JSONProtoPayloadConverter() + # An echo workflow @workflow.defn class Workflow: diff --git a/features/features.go b/features/features.go index 789c38e4..1dcb0627 100644 --- a/features/features.go +++ b/features/features.go @@ -24,6 +24,9 @@ import ( data_converter_json_protobuf "github.com/temporalio/features/features/data_converter/json_protobuf" eager_activity_non_remote_activities_worker "github.com/temporalio/features/features/eager_activity/non_remote_activities_worker" eager_workflow_successful_start "github.com/temporalio/features/features/eager_workflow/successful_start" + 
grpc_retry_server_frozen_for_initiator "github.com/temporalio/features/features/grpc_retry/server_frozen_for_initiator" + grpc_retry_server_restarted_for_initiator "github.com/temporalio/features/features/grpc_retry/server_restarted_for_initiator" + grpc_retry_server_unavailable_for_initiator "github.com/temporalio/features/features/grpc_retry/server_unavailable_for_initiator" query_successful_query "github.com/temporalio/features/features/query/successful_query" query_timeout_due_to_no_active_workers "github.com/temporalio/features/features/query/timeout_due_to_no_active_workers" query_unexpected_arguments "github.com/temporalio/features/features/query/unexpected_arguments" @@ -67,15 +70,18 @@ func init() { child_workflow_result.Feature, child_workflow_signal.Feature, continue_as_new_continue_as_same.Feature, - data_converter_binary_protobuf.Feature, data_converter_binary.Feature, + data_converter_binary_protobuf.Feature, data_converter_codec.Feature, data_converter_empty.Feature, data_converter_failure.Feature, - data_converter_json_protobuf.Feature, data_converter_json.Feature, + data_converter_json_protobuf.Feature, eager_activity_non_remote_activities_worker.Feature, eager_workflow_successful_start.Feature, + grpc_retry_server_frozen_for_initiator.Feature, + grpc_retry_server_restarted_for_initiator.Feature, + grpc_retry_server_unavailable_for_initiator.Feature, query_successful_query.Feature, query_timeout_due_to_no_active_workers.Feature, query_unexpected_arguments.Feature, @@ -92,8 +98,8 @@ func init() { update_activities.Feature, update_async_accepted.Feature, update_basic.Feature, - update_deduplication.Feature, update_client_interceptor.Feature, + update_deduplication.Feature, update_non_durable_reject.Feature, update_self.Feature, update_task_failure.Feature, diff --git a/features/grpc_retry/server_frozen_for_initiator/feature.go b/features/grpc_retry/server_frozen_for_initiator/feature.go new file mode 100644 index 00000000..859e63bc --- /dev/null +++ 
b/features/grpc_retry/server_frozen_for_initiator/feature.go @@ -0,0 +1,38 @@ +package server_frozen_for_initiator + +import ( + "context" + "sync" + "time" + + "github.com/temporalio/features/harness/go/harness" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/workflow" +) + +var Feature = harness.Feature{ + Workflows: Workflow, + Execute: func(ctx context.Context, runner *harness.Runner) (client.WorkflowRun, error) { + var wg sync.WaitGroup + defer wg.Wait() + if err := runner.ProxyFreezeAndThaw(ctx, &wg, 1*time.Second); err != nil { + return nil, err + } + + opts := client.StartWorkflowOptions{ + TaskQueue: runner.TaskQueue, + WorkflowExecutionTimeout: 1 * time.Minute, + RetryPolicy: &temporal.RetryPolicy{ + InitialInterval: 1 * time.Millisecond, + MaximumInterval: 100 * time.Millisecond, + BackoffCoefficient: 2.0, + }, + } + return runner.Client.ExecuteWorkflow(ctx, opts, Workflow) + }, +} + +func Workflow(ctx workflow.Context) (string, error) { + return "OK", nil +} diff --git a/features/grpc_retry/server_frozen_for_initiator/feature.java b/features/grpc_retry/server_frozen_for_initiator/feature.java new file mode 100644 index 00000000..56fc0a34 --- /dev/null +++ b/features/grpc_retry/server_frozen_for_initiator/feature.java @@ -0,0 +1,31 @@ +package grpc_retry.server_frozen_for_initiator; + +import io.temporal.client.WorkflowOptions; +import io.temporal.common.RetryOptions; +import io.temporal.sdkfeatures.Feature; +import io.temporal.sdkfeatures.Run; +import io.temporal.sdkfeatures.Runner; +import io.temporal.sdkfeatures.SimpleWorkflow; +import java.time.Duration; + +public interface feature extends Feature, SimpleWorkflow { + class Impl implements feature { + @Override + public Run execute(Runner runner) throws Exception { + return runner.proxyFreezeAndThaw(Duration.ofSeconds(1), () -> feature.super.execute(runner)); + } + + @Override + public void workflowOptions(WorkflowOptions.Builder builder) { + 
builder.setRetryOptions( + RetryOptions.newBuilder() + .setInitialInterval(Duration.ofMillis(1)) + .setMaximumInterval(Duration.ofMillis(100)) + .setBackoffCoefficient(2.0) + .validateBuildWithDefaults()); + } + + @Override + public void workflow() {} + } +} diff --git a/features/grpc_retry/server_restarted_for_initiator/feature.go b/features/grpc_retry/server_restarted_for_initiator/feature.go new file mode 100644 index 00000000..042e1c25 --- /dev/null +++ b/features/grpc_retry/server_restarted_for_initiator/feature.go @@ -0,0 +1,35 @@ +package server_restarted_for_initiator + +import ( + "context" + "time" + + "github.com/temporalio/features/harness/go/harness" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/workflow" +) + +var Feature = harness.Feature{ + Workflows: Workflow, + Execute: func(ctx context.Context, runner *harness.Runner) (client.WorkflowRun, error) { + if err := runner.ProxyRestart(ctx, 2*time.Second, true); err != nil { + return nil, err + } + + opts := client.StartWorkflowOptions{ + TaskQueue: runner.TaskQueue, + WorkflowExecutionTimeout: 1 * time.Minute, + RetryPolicy: &temporal.RetryPolicy{ + InitialInterval: 1 * time.Millisecond, + MaximumInterval: 100 * time.Millisecond, + BackoffCoefficient: 2.0, + }, + } + return runner.Client.ExecuteWorkflow(ctx, opts, Workflow) + }, +} + +func Workflow(ctx workflow.Context) (string, error) { + return "OK", nil +} diff --git a/features/grpc_retry/server_restarted_for_initiator/feature.java b/features/grpc_retry/server_restarted_for_initiator/feature.java new file mode 100644 index 00000000..29ac049c --- /dev/null +++ b/features/grpc_retry/server_restarted_for_initiator/feature.java @@ -0,0 +1,34 @@ +package grpc_retry.server_restarted_for_initiator; + +import io.temporal.activity.ActivityInterface; +import io.temporal.client.WorkflowOptions; +import io.temporal.common.RetryOptions; +import io.temporal.sdkfeatures.Feature; +import io.temporal.sdkfeatures.Run; +import 
io.temporal.sdkfeatures.Runner; +import io.temporal.sdkfeatures.SimpleWorkflow; +import java.time.Duration; + +@ActivityInterface +public interface feature extends Feature, SimpleWorkflow { + class Impl implements feature { + @Override + public Run execute(Runner runner) throws Exception { + runner.proxyRestart(Duration.ofSeconds(1), true); + return feature.super.execute(runner); + } + + @Override + public void workflowOptions(WorkflowOptions.Builder builder) { + builder.setRetryOptions( + RetryOptions.newBuilder() + .setInitialInterval(Duration.ofMillis(1)) + .setMaximumInterval(Duration.ofMillis(100)) + .setBackoffCoefficient(2.0) + .validateBuildWithDefaults()); + } + + @Override + public void workflow() {} + } +} diff --git a/features/grpc_retry/server_unavailable_for_initiator/feature.go b/features/grpc_retry/server_unavailable_for_initiator/feature.go new file mode 100644 index 00000000..6284dd17 --- /dev/null +++ b/features/grpc_retry/server_unavailable_for_initiator/feature.go @@ -0,0 +1,38 @@ +package server_unavailable_for_initiator + +import ( + "context" + "sync" + "time" + + "github.com/temporalio/features/harness/go/harness" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/workflow" +) + +var Feature = harness.Feature{ + Workflows: Workflow, + Execute: func(ctx context.Context, runner *harness.Runner) (client.WorkflowRun, error) { + var wg sync.WaitGroup + defer wg.Wait() + if err := runner.ProxyRejectAndAccept(ctx, &wg, 1*time.Second); err != nil { + return nil, err + } + + opts := client.StartWorkflowOptions{ + TaskQueue: runner.TaskQueue, + WorkflowExecutionTimeout: 1 * time.Minute, + RetryPolicy: &temporal.RetryPolicy{ + InitialInterval: 1 * time.Millisecond, + MaximumInterval: 100 * time.Millisecond, + BackoffCoefficient: 2.0, + }, + } + return runner.Client.ExecuteWorkflow(ctx, opts, Workflow) + }, +} + +func Workflow(ctx workflow.Context) (string, error) { + return "OK", nil +} diff --git 
a/features/grpc_retry/server_unavailable_for_initiator/feature.java b/features/grpc_retry/server_unavailable_for_initiator/feature.java new file mode 100644 index 00000000..b37362b6 --- /dev/null +++ b/features/grpc_retry/server_unavailable_for_initiator/feature.java @@ -0,0 +1,34 @@ +package grpc_retry.server_unavailable_for_initiator; + +import io.temporal.activity.ActivityInterface; +import io.temporal.client.WorkflowOptions; +import io.temporal.common.RetryOptions; +import io.temporal.sdkfeatures.Feature; +import io.temporal.sdkfeatures.Run; +import io.temporal.sdkfeatures.Runner; +import io.temporal.sdkfeatures.SimpleWorkflow; +import java.time.Duration; + +@ActivityInterface +public interface feature extends Feature, SimpleWorkflow { + class Impl implements feature { + @Override + public Run execute(Runner runner) throws Exception { + return runner.proxyRejectAndAccept( + Duration.ofSeconds(1), () -> feature.super.execute(runner)); + } + + @Override + public void workflowOptions(WorkflowOptions.Builder builder) { + builder.setRetryOptions( + RetryOptions.newBuilder() + .setInitialInterval(Duration.ofMillis(1)) + .setMaximumInterval(Duration.ofMillis(100)) + .setBackoffCoefficient(2.0) + .validateBuildWithDefaults()); + } + + @Override + public void workflow() {} + } +} diff --git a/harness/dotnet/Temporalio.Features.Harness/App.cs b/harness/dotnet/Temporalio.Features.Harness/App.cs index ac1d91cf..618b9585 100644 --- a/harness/dotnet/Temporalio.Features.Harness/App.cs +++ b/harness/dotnet/Temporalio.Features.Harness/App.cs @@ -14,6 +14,10 @@ public static class App description: "The host:port of the server") { IsRequired = true }; + private static readonly Option directServerOption = new( + name: "--direct-server", + description: "The host:port of the server, bypassing the temporal-features-test-proxy"); + private static readonly Option namespaceOption = new( name: "--namespace", description: "The namespace to use") @@ -27,6 +31,14 @@ public static class App 
name: "--client-key-path", description: "Path to a client key for TLS"); + private static readonly Option summaryUriOption = new( + name: "--summary-uri", + description: "Where to stream the test summary JSONL (not implemented)"); + + private static readonly Option proxyControlUriOption = new( + name: "--proxy-control-uri", + description: "URI for simulating network outages with temporal-features-test-proxy"); + private static readonly Argument> featuresArgument = new( name: "features", parse: result => result.Tokens.Select(token => @@ -53,9 +65,11 @@ private static Command CreateCommand() { var cmd = new RootCommand(".NET features harness"); cmd.AddOption(serverOption); + cmd.AddOption(directServerOption); cmd.AddOption(namespaceOption); cmd.AddOption(clientCertPathOption); cmd.AddOption(clientKeyPathOption); + cmd.AddOption(proxyControlUriOption); cmd.AddArgument(featuresArgument); cmd.SetHandler(RunCommandAsync); return cmd; @@ -120,4 +134,4 @@ private static async Task RunCommandAsync(InvocationContext ctx) logger.LogInformation("All features passed"); } -} \ No newline at end of file +} diff --git a/harness/go/cmd/run.go b/harness/go/cmd/run.go index 47cf122a..fa218259 100644 --- a/harness/go/cmd/run.go +++ b/harness/go/cmd/run.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "go.temporal.io/sdk/client" "io" "net" "net/url" @@ -14,6 +13,7 @@ import ( "github.com/temporalio/features/harness/go/harness" "github.com/urfave/cli/v2" + "go.temporal.io/sdk/client" "go.temporal.io/sdk/log" "go.uber.org/zap" ) @@ -86,11 +86,13 @@ type RunFeatureConfigGo struct { // RunConfig is configuration for NewRunner. 
type RunConfig struct { - Server string - Namespace string - ClientCertPath string - ClientKeyPath string - SummaryURI string + Server string + DirectServer string + Namespace string + ClientCertPath string + ClientKeyPath string + SummaryURI string + ProxyControlURI string } func (r *RunConfig) flags() []cli.Flag { @@ -100,6 +102,11 @@ func (r *RunConfig) flags() []cli.Flag { Usage: "The host:port of the server (default is to create ephemeral in-memory server)", Destination: &r.Server, }, + &cli.StringFlag{ + Name: "direct-server", + Usage: "The host:port of the server, bypassing the temporal-features-test-proxy", + Destination: &r.DirectServer, + }, &cli.StringFlag{ Name: "namespace", Usage: "The namespace to use (default is random)", @@ -120,6 +127,11 @@ func (r *RunConfig) flags() []cli.Flag { Usage: "where to stream the test summary JSONL", Destination: &r.SummaryURI, }, + &cli.StringFlag{ + Name: "proxy-control-uri", + Usage: "how to simulate network outages via temporal-features-test-proxy (optional)", + Destination: &r.ProxyControlURI, + }, } } @@ -164,11 +176,21 @@ func (r *Runner) Run(ctx context.Context, run *Run) error { if len(run.Features) == 0 { return fmt.Errorf("no features to run") } + summary, err := openSummary(r.config.SummaryURI) if err != nil { return err } defer summary.Close() + + var proxyControlURL *url.URL + if r.config.ProxyControlURI != "" { + proxyControlURL, err = url.Parse(r.config.ProxyControlURI) + if err != nil { + return err + } + } + var failureCount int failureSummary := "" allFeatures := harness.RegisteredFeatures() @@ -210,12 +232,14 @@ func (r *Runner) Run(ctx context.Context, run *Run) error { } runnerConfig := harness.RunnerConfig{ - ServerHostPort: r.config.Server, - Namespace: r.config.Namespace, - ClientCertPath: r.config.ClientCertPath, - ClientKeyPath: r.config.ClientKeyPath, - TaskQueue: runFeature.TaskQueue, - Log: r.log, + ServerHostPort: r.config.Server, + DirectHostPort: r.config.DirectServer, + Namespace: 
r.config.Namespace, + ClientCertPath: r.config.ClientCertPath, + ClientKeyPath: r.config.ClientKeyPath, + ProxyControlURL: proxyControlURL, + TaskQueue: runFeature.TaskQueue, + Log: r.log, } err := r.runFeature(ctx, runnerConfig, feature) diff --git a/harness/go/harness/feature.go b/harness/go/harness/feature.go index 14893dd1..a111ab3d 100644 --- a/harness/go/harness/feature.go +++ b/harness/go/harness/feature.go @@ -47,6 +47,12 @@ type Feature struct { // DisableWorkflowPanicPolicyOverride field to true. WorkerOptions worker.Options + // BeforeDial provides a hook that will be called just before calling client.Dial. + BeforeDial func(runner *Runner) error + + // BeforeWorkerStart provides a hook that will be called just before calling Worker.Start. + BeforeWorkerStart func(runner *Runner) error + // Can modify the workflow options that are used by the default executor. Some values such as // task queue and workflow execution timeout, are set by default (but may be overridden by this // mutator). @@ -72,6 +78,11 @@ type Feature struct { // If non-empty, this feature will be skipped without checking any other // values. SkipReason string + + // WorkerUsesProxy indicates if the client used to run the worker + // should be one that goes through the temporal-features-test-proxy + // instead of talking directly to the server. + WorkerUsesProxy bool } type WorkflowWithOptions struct { diff --git a/harness/go/harness/runner.go b/harness/go/harness/runner.go index be3f5445..5983195f 100644 --- a/harness/go/harness/runner.go +++ b/harness/go/harness/runner.go @@ -4,8 +4,13 @@ import ( "context" "errors" "fmt" + "io" + "net/http" + "net/url" + "path" "path/filepath" "reflect" + "strconv" "strings" "sync" "time" @@ -31,10 +36,11 @@ type skipFeatureError struct { // Runner represents a runner that can run a feature. 
type Runner struct { RunnerConfig - Client client.Client - Worker worker.Worker - Feature *PreparedFeature - CreateTime time.Time + Client client.Client + DirectClient client.Client + Worker worker.Worker + Feature *PreparedFeature + CreateTime time.Time Assert *assert.Assertions LastAssertErr error @@ -43,12 +49,14 @@ type Runner struct { // RunnerConfig is configuration for NewRunner. type RunnerConfig struct { - ServerHostPort string - Namespace string - TaskQueue string - ClientCertPath string - ClientKeyPath string - Log log.Logger + ServerHostPort string + DirectHostPort string + Namespace string + TaskQueue string + ClientCertPath string + ClientKeyPath string + ProxyControlURL *url.URL + Log log.Logger } // NewRunner creates a new runner for the given config and feature. @@ -89,10 +97,27 @@ func NewRunner(config RunnerConfig, feature *PreparedFeature) (*Runner, error) { } r.Feature.ClientOptions.ConnectionOptions.TLS = tlsCfg + if r.Feature.BeforeDial != nil { + if err = r.Feature.BeforeDial(r); err != nil { + return nil, err + } + } + if r.Client, err = client.Dial(r.Feature.ClientOptions); err != nil { return nil, fmt.Errorf("failed creating client: %w", err) } + if r.DirectHostPort == "" || r.DirectHostPort == r.ServerHostPort { + r.DirectClient = r.Client + } else { + savedValue := r.Feature.ClientOptions.HostPort + r.Feature.ClientOptions.HostPort = r.DirectHostPort + if r.DirectClient, err = client.Dial(r.Feature.ClientOptions); err != nil { + return nil, fmt.Errorf("failed creating client: %w", err) + } + r.Feature.ClientOptions.HostPort = savedValue + } + // Create worker r.CreateTime = time.Now() if !r.Feature.DisableWorkflowPanicPolicyOverride { @@ -350,6 +375,105 @@ func (r *Runner) DoUntilEventually( } } +func (r *Runner) ProxyRestart(ctx context.Context, sleep time.Duration, forceful bool) error { + sleepStr := sleep.String() + forcefulStr := strconv.FormatBool(forceful) + return r.proxySendCommand(ctx, "restart", "sleep", sleepStr, "forceful", 
forcefulStr) +} + +func (r *Runner) ProxyReject(ctx context.Context) error { + return r.proxySendCommand(ctx, "reject") +} + +func (r *Runner) ProxyAccept(ctx context.Context) error { + return r.proxySendCommand(ctx, "accept") +} + +func (r *Runner) ProxyFreeze(ctx context.Context) error { + return r.proxySendCommand(ctx, "freeze") +} + +func (r *Runner) ProxyThaw(ctx context.Context) error { + return r.proxySendCommand(ctx, "thaw") +} + +func (r *Runner) ProxyRejectAndAccept(ctx context.Context, wg *sync.WaitGroup, sleep time.Duration) error { + if err := r.ProxyReject(ctx); err != nil { + return err + } + + wg.Add(1) + go func() { + defer wg.Done() + time.Sleep(sleep) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + _ = r.ProxyAccept(ctx) + }() + return nil +} + +func (r *Runner) ProxyFreezeAndThaw(ctx context.Context, wg *sync.WaitGroup, sleep time.Duration) error { + if err := r.ProxyFreeze(ctx); err != nil { + return err + } + + wg.Add(1) + go func() { + defer wg.Done() + time.Sleep(sleep) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + _ = r.ProxyThaw(ctx) + }() + return nil +} + +func (r *Runner) proxySendCommand(ctx context.Context, command string, args ...string) error { + if r.ProxyControlURL == nil { + return r.Skip("temporal-features-test-proxy is required for this test") + } + + var u url.URL + u = *r.ProxyControlURL + u.Path = path.Join(u.Path, command) + if numArgs := len(args); numArgs != 0 { + q := make(url.Values, numArgs/2) + for i := 0; i < numArgs; i += 2 { + key, value := args[i], args[i+1] + q.Add(key, value) + } + u.RawQuery = q.Encode() + } + + reqMethod := http.MethodPost + reqURL := u.String() + reqName := fmt.Sprintf("%s %s", reqMethod, reqURL) + req, err := http.NewRequestWithContext(ctx, reqMethod, reqURL, nil) + if err != nil { + return fmt.Errorf("failed to create net/http.Request %q: %w", reqName, err) + } + + resp, err 
:= http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("failed to perform HTTP request %q: %w", reqName, err) + } + + _, err = io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + return fmt.Errorf("failed to read body for HTTP %03d response to request %q: %w", resp.StatusCode, reqName, err) + } + + if resp.StatusCode >= 400 { + return fmt.Errorf("HTTP %03d response to request %q", resp.StatusCode, reqName) + } + + return nil +} + // Close closes this runner. func (r *Runner) Close() { if r.Worker != nil { @@ -395,7 +519,12 @@ func (r *Runner) StartWorker() error { if r.Worker != nil { return errors.New("worker is currently running, cannot start a new one") } - r.Worker = worker.New(r.Client, r.RunnerConfig.TaskQueue, r.Feature.WorkerOptions) + + c := r.DirectClient + if r.Feature.WorkerUsesProxy { + c = r.Client + } + r.Worker = worker.New(c, r.RunnerConfig.TaskQueue, r.Feature.WorkerOptions) // Register the workflows and activities for _, workflow := range r.Feature.Workflows { @@ -417,6 +546,12 @@ func (r *Runner) StartWorker() error { } } + if r.Feature.BeforeWorkerStart != nil { + if err := r.Feature.BeforeWorkerStart(r); err != nil { + return err + } + } + // Start the worker if err := r.Worker.Start(); err != nil { return fmt.Errorf("failed starting worker: %w", err) diff --git a/harness/java/io/temporal/sdkfeatures/Feature.java b/harness/java/io/temporal/sdkfeatures/Feature.java index e9eb8f90..dc9ddc0c 100644 --- a/harness/java/io/temporal/sdkfeatures/Feature.java +++ b/harness/java/io/temporal/sdkfeatures/Feature.java @@ -12,11 +12,10 @@ public interface Feature { - @SuppressWarnings("unchecked") default T activities(Class activityIface, Consumer builderFunc) { var builder = ActivityOptions.newBuilder(); builderFunc.accept(builder); - return (T) Workflow.newActivityStub(activityIface, builder.build()); + return Workflow.newActivityStub(activityIface, builder.build()); } default void 
workflowServiceOptions(WorkflowServiceStubsOptions.Builder builder) {} @@ -29,6 +28,14 @@ default void workerOptions(WorkerOptions.Builder builder) {} default void workflowOptions(WorkflowOptions.Builder builder) {} + default boolean workerUsesProxy() { + return false; + } + + default boolean initiatorUsesProxy() { + return true; + } + default Run execute(Runner runner) throws Exception { return runner.executeSingleParameterlessWorkflow(); } diff --git a/harness/java/io/temporal/sdkfeatures/Main.java b/harness/java/io/temporal/sdkfeatures/Main.java index 794f202a..64d39d9f 100644 --- a/harness/java/io/temporal/sdkfeatures/Main.java +++ b/harness/java/io/temporal/sdkfeatures/Main.java @@ -10,6 +10,8 @@ import java.io.*; import java.net.Socket; import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; import java.util.NoSuchElementException; @@ -57,7 +59,8 @@ BufferedWriter createSummaryServerWriter() { switch (uri.getScheme()) { case "tcp": Socket socket = new Socket(uri.getHost(), uri.getPort()); - return new BufferedWriter(new OutputStreamWriter(socket.getOutputStream(), "UTF-8")); + return new BufferedWriter( + new OutputStreamWriter(socket.getOutputStream(), StandardCharsets.UTF_8)); case "file": FileWriter fileWriter = new FileWriter(uri.getPath(), true); return new BufferedWriter(fileWriter); @@ -74,9 +77,19 @@ BufferedWriter createSummaryServerWriter() { @Option(names = "--summary-uri", description = "The URL of the summary server", required = true) private String summaryUri; + @Option( + names = "--proxy-control-uri", + description = "The URL of temporal-features-test-proxy (optional)") + private String proxyControlUri; + @Option(names = "--server", description = "The host:port of the server", required = true) private String server; + @Option( + names = "--direct-server", + description = "The host:port of the server, bypassing the temporal-features-test-proxy") + 
private String directServer; + @Option(names = "--namespace", description = "The namespace to use", required = true) private String namespace; @@ -110,6 +123,19 @@ public void run() { throw new RuntimeException("Client cert path must be specified since key path is"); } + // Parse proxyControlUri if present + URI proxyControl = null; + if (proxyControlUri != null && !proxyControlUri.isEmpty()) { + try { + proxyControl = new URI(proxyControlUri); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + final String processedDirectServer = + (directServer != null && !directServer.isEmpty()) ? directServer : server; + try (BufferedWriter writer = createSummaryServerWriter()) { ObjectMapper mapper = new ObjectMapper(); @@ -134,8 +160,10 @@ public void run() { log.info("Running feature {}", feature.dir); var config = new Runner.Config(); config.serverHostPort = server; + config.directHostPort = processedDirectServer; config.namespace = namespace; config.sslContext = sslContext; + config.proxyControl = proxyControl; config.taskQueue = pieces[1]; Outcome outcome = Outcome.PASSED; String message = ""; diff --git a/harness/java/io/temporal/sdkfeatures/PreparedFeature.java b/harness/java/io/temporal/sdkfeatures/PreparedFeature.java index 0bf2c1af..fb500e61 100644 --- a/harness/java/io/temporal/sdkfeatures/PreparedFeature.java +++ b/harness/java/io/temporal/sdkfeatures/PreparedFeature.java @@ -7,8 +7,8 @@ public class PreparedFeature { static PreparedFeature[] ALL = PreparedFeature.prepareFeatures( activity.basic_no_workflow_timeout.feature.Impl.class, - activity.retry_on_error.feature.Impl.class, activity.cancel_try_cancel.feature.Impl.class, + activity.retry_on_error.feature.Impl.class, child_workflow.result.feature.Impl.class, child_workflow.signal.feature.Impl.class, continue_as_new.continue_as_same.feature.Impl.class, @@ -19,6 +19,9 @@ public class PreparedFeature { data_converter.json.feature.Impl.class, 
data_converter.json_protobuf.feature.Impl.class, eager_activity.non_remote_activities_worker.feature.Impl.class, + grpc_retry.server_frozen_for_initiator.feature.Impl.class, + grpc_retry.server_restarted_for_initiator.feature.Impl.class, + grpc_retry.server_unavailable_for_initiator.feature.Impl.class, query.successful_query.feature.Impl.class, query.timeout_due_to_no_active_workers.feature.Impl.class, query.unexpected_arguments.feature.Impl.class, @@ -32,12 +35,12 @@ public class PreparedFeature { signal.external.feature.Impl.class, update.activities.feature.Impl.class, update.async_accepted.feature.Impl.class, - update.deduplication.feature.Impl.class, update.client_interceptor.feature.Impl.class, + update.deduplication.feature.Impl.class, update.non_durable_reject.feature.Impl.class, update.task_failure.feature.Impl.class, - update.worker_restart.feature.Impl.class, update.validation_replay.feature.Impl.class, + update.worker_restart.feature.Impl.class, update.self.feature.Impl.class); @SafeVarargs diff --git a/harness/java/io/temporal/sdkfeatures/Runner.java b/harness/java/io/temporal/sdkfeatures/Runner.java index a5b398f7..1efb64e2 100644 --- a/harness/java/io/temporal/sdkfeatures/Runner.java +++ b/harness/java/io/temporal/sdkfeatures/Runner.java @@ -12,6 +12,7 @@ import io.temporal.api.common.v1.Payload; import io.temporal.api.common.v1.WorkflowExecution; import io.temporal.api.history.v1.History; +import io.temporal.api.history.v1.HistoryEvent; import io.temporal.api.workflow.v1.WorkflowExecutionInfo; import io.temporal.api.workflowservice.v1.DescribeWorkflowExecutionRequest; import io.temporal.client.*; @@ -24,6 +25,8 @@ import io.temporal.worker.WorkerFactoryOptions; import io.temporal.worker.WorkerOptions; import java.io.Closeable; +import java.io.IOException; +import java.net.*; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.*; @@ -38,22 +41,27 @@ public class Runner implements Closeable { public static class Config 
{ public String serverHostPort; + public String directHostPort; public String namespace; public String taskQueue; public Scope metricsScope = new NoopScope(); public SslContext sslContext; + public URI proxyControl; } public final Config config; public final PreparedFeature featureInfo; public final Feature feature; public final WorkflowServiceStubs service; + public final WorkflowServiceStubs directService; public final WorkflowClient client; + public final WorkflowClient directClient; private WorkerFactory workerFactory; private Worker worker; Runner(Config config, PreparedFeature featureInfo) { Objects.requireNonNull(config.serverHostPort); + Objects.requireNonNull(config.directHostPort); Objects.requireNonNull(config.namespace); Objects.requireNonNull(config.taskQueue); this.config = config; @@ -61,28 +69,170 @@ public static class Config { feature = featureInfo.newInstance(); // Build service - var serviceBuild = + final var serviceBuild = WorkflowServiceStubsOptions.newBuilder() .setTarget(config.serverHostPort) .setSslContext(config.sslContext) .setMetricsScope(config.metricsScope); feature.workflowServiceOptions(serviceBuild); - service = WorkflowServiceStubs.newServiceStubs(serviceBuild.build()); + final var serviceOptions = serviceBuild.build(); + final var directServiceOptions = serviceBuild.setTarget(config.directHostPort).build(); + + service = WorkflowServiceStubs.newServiceStubs(serviceOptions); // Shutdown service on failure try { - // Build client - var clientBuild = WorkflowClientOptions.newBuilder().setNamespace(config.namespace); - feature.workflowClientOptions(clientBuild); - client = WorkflowClient.newInstance(service, clientBuild.build()); - - // Build worker - restartWorker(); + directService = WorkflowServiceStubs.newServiceStubs(directServiceOptions); + try { + // Build client + var clientBuild = WorkflowClientOptions.newBuilder().setNamespace(config.namespace); + feature.workflowClientOptions(clientBuild); + client = 
WorkflowClient.newInstance(service, clientBuild.build()); + directClient = WorkflowClient.newInstance(directService, clientBuild.build()); + + // Build worker + restartWorker(); + } catch (Throwable e) { + try { + directService.shutdownNow(); + } catch (Throwable ignored) { + } + throw e; + } } catch (Throwable e) { - service.shutdownNow(); + try { + service.shutdownNow(); + } catch (Throwable ignored) { + } throw e; } } + public WorkflowClient workerClient() { + if (feature.workerUsesProxy()) { + return client; + } else { + return directClient; + } + } + + public WorkflowClient initiatorClient() { + if (feature.initiatorUsesProxy()) { + return client; + } else { + return directClient; + } + } + + public WorkflowServiceStubs initiatorService() { + if (feature.initiatorUsesProxy()) { + return service; + } else { + return directService; + } + } + + public void proxyReject() throws IOException { + proxySendCommand("reject"); + } + + public void proxyAccept() throws IOException { + proxySendCommand("accept"); + } + + public void proxyFreeze() throws IOException { + proxySendCommand("freeze"); + } + + public void proxyThaw() throws IOException { + proxySendCommand("thaw"); + } + + public void proxyRestart(Duration sleep, boolean forceful) throws IOException { + final var sleepStr = sleep.toMillis() + "ms"; + final var forcefulStr = forceful ? 
"true" : "false"; + proxySendCommand("restart", "sleep", sleepStr, "forceful", forcefulStr); + } + + public T proxyRejectAndAccept( + Duration sleep, CheckedCallable runnable) throws E, IOException { + return proxyFirstAndSecond(sleep, runnable, this::proxyReject, this::proxyAccept); + } + + public T proxyFreezeAndThaw( + Duration sleep, CheckedCallable callable) throws E, IOException { + return proxyFirstAndSecond(sleep, callable, this::proxyFreeze, this::proxyThaw); + } + + private T proxyFirstAndSecond( + Duration sleep, + CheckedCallable callable, + CheckedRunnable first, + CheckedRunnable second) + throws E, IOException { + first.run(); + final var thread = + new Thread( + () -> { + try { + Thread.sleep(sleep.toMillis()); + } catch (InterruptedException ignored) { + Thread.currentThread().interrupt(); + } + try { + second.run(); + } catch (IOException ignored) { + } + }); + thread.start(); + try { + return callable.call(); + } finally { + try { + thread.join(); + } catch (InterruptedException ignored) { + Thread.currentThread().interrupt(); + } + } + } + + public void proxySendCommand(String method, String... 
args) throws IOException { + if (config.proxyControl == null) { + skip("temporal-features-test-proxy is required for this test"); + } + + final StringBuilder sb = new StringBuilder(); + sb.append('/'); + sb.append(method); + if (args != null && args.length != 0) { + char separator = '?'; + for (int i = 0; i < args.length; i += 2) { + String key = args[i]; + String value = args[i + 1]; + sb.append(separator); + sb.append(URLEncoder.encode(key, StandardCharsets.UTF_8)); + sb.append('='); + sb.append(URLEncoder.encode(value, StandardCharsets.UTF_8)); + separator = '&'; + } + } + final URI uri = config.proxyControl.resolve(sb.toString()); + log.info("proxySendCommand: {}", uri); + var connection = (HttpURLConnection) uri.toURL().openConnection(); + connection.setConnectTimeout(1000); + connection.setReadTimeout(10000); + connection.setInstanceFollowRedirects(false); + connection.setRequestMethod("POST"); + try { + connection.connect(); + final int code = connection.getResponseCode(); + if (code >= 400) { + throw new IOException("proxy command failed with HTTP code " + code); + } + } finally { + connection.disconnect(); + } + } + /** * Instantiates a new worker, replacing the existing worker and workerFactory. You should shut * down the worker factory before calling this. @@ -90,7 +240,7 @@ public static class Config { public void restartWorker() { var factoryBuild = WorkerFactoryOptions.newBuilder(); feature.workerFactoryOptions(factoryBuild); - this.workerFactory = WorkerFactory.newInstance(client, factoryBuild.build()); + this.workerFactory = WorkerFactory.newInstance(workerClient(), factoryBuild.build()); var workerBuild = WorkerOptions.newBuilder(); feature.workerOptions(workerBuild); this.worker = workerFactory.newWorker(config.taskQueue, workerBuild.build()); @@ -155,7 +305,7 @@ public Run executeSingleWorkflow(WorkflowOptions options, Object... 
args) { options = builder.build(); } - var stub = client.newUntypedWorkflowStub(methods.get(0).getName(), options); + var stub = initiatorClient().newUntypedWorkflowStub(methods.get(0).getName(), options); // Call workflow with args return new Run(methods.get(0), stub.start(args)); @@ -169,7 +319,7 @@ public Object waitForRunResult(Run run) { } public T waitForRunResult(Run run, Class type) { - var stub = client.newUntypedWorkflowStub(run.execution, Optional.empty()); + var stub = initiatorClient().newUntypedWorkflowStub(run.execution, Optional.empty()); return stub.getResult(type); } @@ -179,14 +329,14 @@ public WorkflowExecution executeWorkflow(String workflowType, Object... args) { .setTaskQueue(config.taskQueue) .setWorkflowExecutionTimeout(Duration.ofMinutes(1)); feature.workflowOptions(builder); - var stub = client.newUntypedWorkflowStub(workflowType, builder.build()); + var stub = initiatorClient().newUntypedWorkflowStub(workflowType, builder.build()); return stub.start(args); } public History getWorkflowHistory(Run run) throws Exception { var eventIter = WorkflowClientHelper.getHistory( - service, config.namespace, run.execution, config.metricsScope); + initiatorService(), config.namespace, run.execution, config.metricsScope); return History.newBuilder().addAllEvents(() -> eventIter).build(); } @@ -194,7 +344,7 @@ public Payload getWorkflowResultPayload(Run run) throws Exception { var history = getWorkflowHistory(run); var event = history.getEventsList().stream() - .filter(e -> e.hasWorkflowExecutionCompletedEventAttributes()) + .filter(HistoryEvent::hasWorkflowExecutionCompletedEventAttributes) .findFirst(); return event.get().getWorkflowExecutionCompletedEventAttributes().getResult().getPayloads(0); } @@ -203,7 +353,7 @@ public Payload getWorkflowArgumentPayload(Run run) throws Exception { var history = getWorkflowHistory(run); var event = history.getEventsList().stream() - .filter(e -> e.hasWorkflowExecutionStartedEventAttributes()) + 
.filter(HistoryEvent::hasWorkflowExecutionStartedEventAttributes) .findFirst(); return event.get().getWorkflowExecutionStartedEventAttributes().getInput().getPayloads(0); } @@ -215,7 +365,7 @@ public WorkflowExecutionInfo getWorkflowExecutionInfo(Run run) throws Exception .setExecution(run.execution) .build(); var exec = - this.client + this.initiatorClient() .getWorkflowServiceStubs() .blockingStub() .describeWorkflowExecution(describeRequest); @@ -241,7 +391,6 @@ public void checkCurrentAndPastHistories(Run run) throws Exception { } } - @SuppressWarnings("UnstableApiUsage") public Map loadPastHistories() throws Exception { var pkg = featureInfo.dir.replace('/', '.') + ".history"; var jsonPaths = new Reflections(pkg, Scanners.Resources).getResources(".*\\.json"); @@ -292,6 +441,20 @@ public Map loadPastHistories() throws Excep public void close() { try { workerFactory.shutdownNow(); + } catch (Throwable e) { + try { + directService.shutdownNow(); + } catch (Throwable ignored) { + } + try { + service.shutdownNow(); + } catch (Throwable ignored) { + } + throw e; + } + + try { + directService.shutdownNow(); } catch (Throwable e) { try { service.shutdownNow(); @@ -299,6 +462,7 @@ public void close() { } throw e; } + service.shutdownNow(); } @@ -314,14 +478,14 @@ public void requireNoUpdateRejectedEvents(Run run) throws Exception { var history = getWorkflowHistory(run); var event = history.getEventsList().stream() - .filter(e -> e.hasWorkflowExecutionUpdateRejectedEventAttributes()) + .filter(HistoryEvent::hasWorkflowExecutionUpdateRejectedEventAttributes) .findFirst(); Assertions.assertFalse(event.isPresent()); } public void skipIfUpdateNotSupported() { try { - client.newUntypedWorkflowStub("fake").update("also_fake", Void.class); + initiatorClient().newUntypedWorkflowStub("fake").update("also_fake", Void.class); } catch (WorkflowNotFoundException exception) { return; } catch (WorkflowServiceException exception) { @@ -339,7 +503,7 @@ public void 
skipIfUpdateNotSupported() { public void skipIfAsyncAcceptedUpdateNotSupported() { try { - client.newUntypedWorkflowStub("fake").startUpdate("also_fake", Void.class); + initiatorClient().newUntypedWorkflowStub("fake").startUpdate("also_fake", Void.class); } catch (WorkflowNotFoundException exception) { return; } catch (WorkflowServiceException exception) { @@ -372,4 +536,14 @@ public void retry(Supplier fn, int retries, Duration sleepBetweenRetrie } Assertions.fail("retry limit exceeded"); } + + @FunctionalInterface + public interface CheckedRunnable { + void run() throws E; + } + + @FunctionalInterface + public interface CheckedCallable { + T call() throws E; + } } diff --git a/harness/python/main.py b/harness/python/main.py index e89a24da..e2421cdf 100644 --- a/harness/python/main.py +++ b/harness/python/main.py @@ -16,11 +16,23 @@ async def run(): # Parse args parser = argparse.ArgumentParser() parser.add_argument("--server", help="The host:port of the server", required=True) + parser.add_argument( + "--direct-server", + help="The host:port of the server, bypassing the temporal-features-test-proxy", + ) parser.add_argument("--namespace", help="The namespace to use", required=True) parser.add_argument( "--client-cert-path", help="Path to a client certificate for TLS" ) parser.add_argument("--client-key-path", help="Path to a client key for TLS") + parser.add_argument( + "--summary-uri", + help="where to stream the test summary JSONL (not implemented)", + ) + parser.add_argument( + "--proxy-control-uri", + help="Base URI for simulating network outages via temporal-features-test-proxy", + ) parser.add_argument("--log-level", help="Log level", default="INFO") parser.add_argument( "features", help="Features as dir + ':' + task queue", nargs="+" diff --git a/harness/ts/main.ts b/harness/ts/main.ts index b5e4b613..f8d68529 100644 --- a/harness/ts/main.ts +++ b/harness/ts/main.ts @@ -10,9 +10,12 @@ async function run() { const program = new Command(); program 
.requiredOption('--server
', 'The host:port of the server') + .option('--direct-server
', 'The host:port of the server, bypassing the temporal-features-test-proxy') .requiredOption('--namespace ', 'The namespace to use') .option('--client-cert-path ', 'Path to a client certificate for TLS') .option('--client-key-path ', 'Path to a client key for TLS') + .option('--summary-uri ', 'where to stream the test summary JSONL (not implemented)') + .option('--proxy-control-uri ', 'Base URL for simulating network outages via temporal-features-test-proxy') .argument('', 'Features as dir + ":" + task queue'); const opts = program.parse(process.argv).opts<{ @@ -20,6 +23,7 @@ async function run() { namespace: string; clientCertPath: string; clientKeyPath: string; + proxyControlUri: string; featureAndTaskQueues: string[]; }>(); opts.featureAndTaskQueues = program.args; diff --git a/internal/cmd/temporal-features-test-proxy/main.go b/internal/cmd/temporal-features-test-proxy/main.go new file mode 100644 index 00000000..1b272c1c --- /dev/null +++ b/internal/cmd/temporal-features-test-proxy/main.go @@ -0,0 +1,752 @@ +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "io" + "io/fs" + "net" + "net/http" + "net/url" + "os" + "os/signal" + "path" + "strconv" + "strings" + "sync" + "time" + + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/sdk/client" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" +) + +const HelpText = `The test proxy exposes the following control endpoints: + +- POST /quit + Shut down the proxy and exit. + +- POST /restart + Gracefully shut down the gRPC server, then start it again. + + - Query param: sleep= + Forces the restart to block for the given duration; default: 0s. + + - Query param: forceful= + If true, forces a non-graceful shutdown; default: false. + +- POST /reject + Immediately reject incoming gRPC requests with UNAVAILABLE. 
+ +- POST /accept + Accept incoming gRPC requests; this is the default. + +- POST /freeze + Block on incoming accepted gRPC requests. + +- POST /thaw + Process incoming accepted gRPC requests immediately; this is the default. +` + +var ErrUnknownCommand = errors.New("unknown command") + +var ( + flagTrace bool + flagControl string + flagListen string + flagDial string + + gListenConfig net.ListenConfig + gExitCh chan struct{} + gRootContext context.Context + gServerMutex sync.Mutex + gControlServer ControlServer + gProxyServer ProxyServer + + gStateMutex sync.Mutex + gStateCond *sync.Cond + gStateRejecting bool + gStateFrozen bool +) + +func init() { + flag.BoolVar(&flagTrace, "trace", false, "enable tracing logs") + flag.StringVar(&flagControl, "control", "", "TCP host:port to listen on for HTTP control commands") + flag.StringVar(&flagListen, "listen", "", "TCP host:port to listen on for proxying to -dial") + flag.StringVar(&flagDial, "dial", "", "TCP host:port to connect to") +} + +func main() { + flag.Parse() + + if flagControl == "" { + Fatal(1, "must specify -control") + panic(nil) + } + if flagListen == "" { + Fatal(1, "must specify -listen") + panic(nil) + } + if flagDial == "" { + Fatal(1, "must specify -dial") + panic(nil) + } + + gExitCh = make(chan struct{}) + gStateCond = sync.NewCond(&gStateMutex) + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + ctx, stop := signal.NotifyContext(ctx, os.Interrupt) + defer stop() + gRootContext = ctx + + gControlServer.Init(flagControl) + err := gControlServer.Run(ctx) + if err != nil { + Fatal(2, "%v", err) + panic(nil) + } + + needControlClose := true + defer func() { + if needControlClose { + gControlServer.ForceClose() + } + }() + + gProxyServer.Init(flagListen, flagDial) + err = gProxyServer.Run(ctx) + if err != nil { + Fatal(2, "%v", err) + panic(nil) + } + + needProxyClose := true + defer func() { + if needProxyClose { + gProxyServer.ForceClose() + } + 
}() + + Info("HTTP control server is running on: %s", flagControl) + Info("gRPC proxy server is running on: %s", flagListen) + Info("gRPC proxy server is connected to: %s", flagDial) + + select { + case <-gExitCh: + case <-ctx.Done(): + } + + Info("terminating") + + err = gControlServer.Shutdown(ctx) + if IsErrClosed(err) { + err = nil + } + if err != nil { + Warn("failed to gracefully shut down HTTP control server: %v", err) + } + + needControlClose = false + gControlServer.ForceClose() + + gServerMutex.Lock() + defer gServerMutex.Unlock() + + err = gProxyServer.Shutdown(ctx) + if IsErrClosed(err) { + err = nil + } + if err != nil { + Warn("failed to gracefully shut down gRPC proxy server: %v", err) + } + + needProxyClose = false + gProxyServer.ForceClose() + + Info("done") +} + +type queryKeyType struct{} + +var queryKey queryKeyType + +type ActionFunc = func(context.Context) error + +func ActionQuit(ctx context.Context) error { + close(gExitCh) + Info("/quit: proxy is shutting down") + return nil +} + +func ActionRestart(ctx context.Context) error { + q, _ := ctx.Value(queryKey).(url.Values) + + var sleep time.Duration + if q.Has("sleep") { + d, err := time.ParseDuration(q.Get("sleep")) + if err != nil { + return err + } + if d < 0 { + d = 0 + } + sleep = d + } + + var forceful bool + if q.Has("forceful") { + b, err := strconv.ParseBool(q.Get("forceful")) + if err != nil { + return err + } + forceful = b + } + + gServerMutex.Lock() + defer gServerMutex.Unlock() + + Info("/restart: restarting proxy, forceful=%t", forceful) + + mode := "gracefully" + fn := gProxyServer.Shutdown + if forceful { + mode = "forcefully" + fn = func(context.Context) error { + return gProxyServer.Close() + } + } + + err := fn(ctx) + if IsErrClosed(err) { + err = nil + } + if err != nil { + Warn("failed to %s shut down gRPC proxy server: %v", mode, err) + } + + gProxyServer.ForceClose() + + if sleep > 0 { + Info("/restart: sleeping for %v", sleep) + time.Sleep(sleep) + } + + err = 
gProxyServer.Run(gRootContext)
+	if err != nil {
+		close(gExitCh)
+		return err
+	}
+
+	Info("/restart: proxy has been restarted")
+	return nil
+}
+
+// ActionReject switches the proxy into the "rejecting" state. It sets
+// gStateRejecting under gStateMutex and is a no-op when the flag is already
+// set. The ctx argument is not used. It always returns nil.
+func ActionReject(ctx context.Context) error {
+	gStateMutex.Lock()
+	defer gStateMutex.Unlock()
+
+	if gStateRejecting {
+		return nil
+	}
+	gStateRejecting = true
+	Info("/reject: proxy is rejecting requests")
+	return nil
+}
+
+// ActionAccept clears the "rejecting" state set by ActionReject. It is a
+// no-op when the flag is already clear. The ctx argument is not used. It
+// always returns nil.
+func ActionAccept(ctx context.Context) error {
+	gStateMutex.Lock()
+	defer gStateMutex.Unlock()
+
+	if !gStateRejecting {
+		return nil
+	}
+	gStateRejecting = false
+	Info("/accept: proxy is NOT rejecting requests")
+	return nil
+}
+
+// ActionFreeze switches the proxy into the "frozen" state, in which
+// AwaitPermitted blocks its callers. It is a no-op when the flag is already
+// set. The ctx argument is not used. It always returns nil.
+func ActionFreeze(ctx context.Context) error {
+	gStateMutex.Lock()
+	defer gStateMutex.Unlock()
+
+	if gStateFrozen {
+		return nil
+	}
+	gStateFrozen = true
+	Info("/freeze: proxy is stalling requests")
+	return nil
+}
+
+// ActionThaw clears the "frozen" state and wakes every goroutine waiting on
+// gStateCond. It is a no-op when the flag is already clear. The ctx argument
+// is not used. It always returns nil.
+func ActionThaw(ctx context.Context) error {
+	gStateMutex.Lock()
+	defer gStateMutex.Unlock()
+
+	if !gStateFrozen {
+		return nil
+	}
+	gStateFrozen = false
+	// Release everything parked in AwaitPermitted.
+	gStateCond.Broadcast()
+	Info("/thaw: proxy is NOT stalling requests")
+	return nil
+}
+
+// HandleHelp serves HelpText as plain text at the root path. Any request
+// whose cleaned path is not "/" gets a 404, and any method other than GET or
+// HEAD gets a 405 via CheckMethod.
+func HandleHelp(w http.ResponseWriter, r *http.Request) {
+	if path.Clean(r.URL.Path) != "/" {
+		http.NotFound(w, r)
+		return
+	}
+	if !CheckMethod(w, r, http.MethodGet, http.MethodHead) {
+		return
+	}
+	// Convert LF line endings to CRLF before sending.
+	body := []byte(strings.ReplaceAll(HelpText, "\n", "\r\n"))
+	h := w.Header()
+	h.Set("Content-Type", "text/plain; charset=utf-8")
+	h.Set("Content-Length", fmt.Sprint(len(body)))
+	w.WriteHeader(http.StatusOK)
+	// The write error is ignored: the status line has already been sent, so
+	// nothing useful can be reported to the client at this point.
+	w.Write(body)
+}
+
+// HandleAction adapts an ActionFunc into an HTTP handler. The handler only
+// accepts POST, parses the raw query string, stores the parsed values in the
+// request context under queryKey, and invokes action with that context.
+//
+// Responses: 400 when the query string cannot be parsed, 500 with the error
+// text when action fails, and 204 on success.
+func HandleAction(action ActionFunc) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		if !CheckMethod(w, r, http.MethodPost) {
+			return
+		}
+
+		q, err := url.ParseQuery(r.URL.RawQuery)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+			return
+		}
+
+		ctx := r.Context()
+		ctx = context.WithValue(ctx, queryKey, q)
+		if err := action(ctx); err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		w.WriteHeader(http.StatusNoContent)
+	}
+}
+
+// CheckMethod reports whether r.Method is one of allowed. When it is not, it
+// writes a 405 response carrying an Allow header that lists the permitted
+// methods and returns false; the caller should then return without writing
+// anything else.
+func CheckMethod(w http.ResponseWriter, r *http.Request, allowed ...string) bool {
+	for _, item := range allowed {
+		if r.Method == item {
+			return true
+		}
+	}
+	h := w.Header()
+	h.Set("Allow", strings.Join(allowed, ", "))
+	http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+	return false
+}
+
+// closeQuitOnce closes quitCh unless it has already been closed. Nothing ever
+// sends on a quit channel, so a successful receive can only mean "closed".
+//
+// The caller must hold the mutex that guards the channel, which makes the
+// check-then-close sequence atomic. quitCh must not be nil.
+func closeQuitOnce(quitCh chan struct{}) {
+	select {
+	case <-quitCh:
+		// Already closed by an earlier Shutdown/Close that is still waiting
+		// for the serve goroutine to finish.
+	default:
+		close(quitCh)
+	}
+}
+
+// ControlServer is the HTTP server that exposes the proxy's control
+// endpoints (/quit, /restart, /reject, /accept, /freeze, /thaw) plus a help
+// page at "/". Call Init before Run.
+type ControlServer struct {
+	listen string        // address handed to Listen in Run
+	mu     sync.Mutex    // guards l and quitCh
+	cv     *sync.Cond    // bound to mu; broadcast by finish when serving ends
+	l      net.Listener  // active listener, nil while stopped
+	quitCh chan struct{} // non-nil while running; closed to release closeThread
+	server http.Server
+	mux    http.ServeMux
+}
+
+// Init records the listen address, resets the running state, configures the
+// HTTP timeouts and registers all control routes on the internal mux.
+func (s *ControlServer) Init(listen string) {
+	s.listen = listen
+	s.cv = sync.NewCond(&s.mu)
+	s.l = nil
+	s.quitCh = nil
+	s.server.Handler = &s.mux
+	s.server.ReadTimeout = 30 * time.Second
+	s.server.WriteTimeout = 30 * time.Second
+	s.server.IdleTimeout = 60 * time.Second
+	s.mux.HandleFunc("/", HandleHelp)
+	s.mux.HandleFunc("/quit", HandleAction(ActionQuit))
+	s.mux.HandleFunc("/restart", HandleAction(ActionRestart))
+	s.mux.HandleFunc("/reject", HandleAction(ActionReject))
+	s.mux.HandleFunc("/accept", HandleAction(ActionAccept))
+	s.mux.HandleFunc("/freeze", HandleAction(ActionFreeze))
+	s.mux.HandleFunc("/thaw", HandleAction(ActionThaw))
+}
+
+// Run opens the listener and starts serving in the background. It returns
+// once the listener is open; serving continues on serveThread, and
+// closeThread closes the server when ctx is cancelled.
+//
+// It panics if the server is already running and returns an error when the
+// listen address cannot be bound.
+//
+// NOTE(review): net/http documents that an http.Server cannot be reused after
+// Shutdown/Close; a second Run on the same ControlServer would have Serve
+// return http.ErrServerClosed, which serveThread treats as a clean exit —
+// confirm this type is never restarted.
+func (s *ControlServer) Run(ctx context.Context) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.quitCh != nil {
+		panic("BUG! ControlServer is already running")
+	}
+
+	l, err := gListenConfig.Listen(ctx, "tcp", s.listen)
+	if err != nil {
+		return fmt.Errorf("failed to listen on %q: %w", s.listen, err)
+	}
+
+	s.l = l
+	s.quitCh = make(chan struct{})
+
+	// Request contexts derive from the process-wide root context rather than
+	// from the ctx passed to Run.
+	s.server.BaseContext = func(l net.Listener) context.Context {
+		return gRootContext
+	}
+
+	go s.serveThread()
+	go s.closeThread(ctx, s.quitCh, &s.server)
+	return nil
+}
+
+// Shutdown gracefully stops the HTTP server via http.Server.Shutdown and
+// then blocks until serveThread has finished. It returns nil immediately if
+// the server is not running, otherwise the error from http.Server.Shutdown.
+//
+// It is safe to call while another Shutdown or Close is still waiting.
+func (s *ControlServer) Shutdown(ctx context.Context) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.quitCh == nil {
+		return nil
+	}
+
+	// cv.Wait below releases mu, so a concurrent Shutdown/Close can get here
+	// while quitCh is still non-nil; closing it twice would panic.
+	closeQuitOnce(s.quitCh)
+	err := s.server.Shutdown(ctx)
+	for s.quitCh != nil {
+		s.cv.Wait()
+	}
+	return err
+}
+
+// Close stops the HTTP server immediately via http.Server.Close and then
+// blocks until serveThread has finished. It returns nil immediately if the
+// server is not running, otherwise the error from http.Server.Close.
+//
+// It is safe to call while another Shutdown or Close is still waiting.
+func (s *ControlServer) Close() error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.quitCh == nil {
+		return nil
+	}
+
+	// See Shutdown: guard against closing quitCh a second time.
+	closeQuitOnce(s.quitCh)
+	err := s.server.Close()
+	for s.quitCh != nil {
+		s.cv.Wait()
+	}
+	return err
+}
+
+// ForceClose calls Close and logs, rather than returns, any error that is
+// not an "already closed" condition.
+func (s *ControlServer) ForceClose() {
+	err := s.Close()
+	if IsErrClosed(err) {
+		err = nil
+	}
+	if err != nil {
+		Warn("failed to stop HTTP control server: %v", err)
+	}
+}
+
+// closeThread waits for whichever comes first: ctx being cancelled, in which
+// case it closes closer and logs unexpected errors, or quitCh being closed,
+// in which case it simply exits because an explicit Shutdown/Close is already
+// under way.
+func (s *ControlServer) closeThread(ctx context.Context, quitCh <-chan struct{}, closer io.Closer) {
+	select {
+	case <-ctx.Done():
+		err := closer.Close()
+		if IsErrClosed(err) {
+			err = nil
+		}
+		if err != nil {
+			Error("failed to stop HTTP control server: %v", err)
+		}
+	case <-quitCh:
+	}
+}
+
+// serveThread runs the HTTP serve loop until the server stops, then closes
+// the listener. Errors that merely indicate a closed server or listener are
+// ignored; anything else is logged. finish runs on exit to mark the server
+// as stopped.
+func (s *ControlServer) serveThread() {
+	defer s.finish()
+
+	err := s.server.Serve(s.l)
+	if IsErrClosed(err) {
+		err = nil
+	}
+	if err != nil {
+		Error("failed to serve HTTP control server: %v", err)
+	}
+
+	err = s.l.Close()
+	if IsErrClosed(err) {
+		err = nil
+	}
+	if err != nil {
+		Error("failed to close listener for HTTP control server: %v", err)
+	}
+}
+
+// finish clears the running state under mu and wakes any Shutdown/Close
+// callers waiting on cv.
+func (s *ControlServer) finish() {
+	s.mu.Lock()
+	s.l = nil
+	s.quitCh = nil
+	s.cv.Broadcast()
+	s.mu.Unlock()
+}
+
+// ProxyServer is the gRPC server that forwards WorkflowService calls to the
+// server reachable at the dial address. Call Init before Run.
+type ProxyServer struct {
+	listen string                                 // address handed to Listen in Run
+	dial   string                                 // upstream address handed to grpc.DialContext
+	mu     sync.Mutex                             // guards the fields below
+	cv     *sync.Cond                             // bound to mu; broadcast by finish when serving ends
+	gc     *grpc.ClientConn                       // upstream connection, nil while stopped
+	gs     *grpc.Server                           // proxy gRPC server, nil while stopped
+	l      net.Listener                           // active listener, nil while stopped
+	wc     workflowservice.WorkflowServiceClient  // client bound to gc
+	ws     workflowservice.WorkflowServiceServer  // proxy handler wrapping wc
+	quitCh chan struct{}                          // non-nil while running; closed to release closeThread
+}
+
+// Init records the listen and dial addresses and creates the condition
+// variable used to wait for the serve goroutine.
+func (s *ProxyServer) Init(listen, dial string) {
+	s.listen = listen
+	s.dial = dial
+	s.cv = sync.NewCond(&s.mu)
+}
+
+// Run opens the listener, dials the upstream server (blocking until the
+// connection is established or ctx ends), builds the WorkflowService proxy
+// and starts serving in the background. closeThread stops the gRPC server
+// when ctx is cancelled.
+//
+// It panics if the proxy is already running. On any setup failure it returns
+// a wrapped error after closing whatever was opened so far.
+func (s *ProxyServer) Run(ctx context.Context) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.quitCh != nil {
+		panic("BUG! gRPC proxy server is already running")
+	}
+
+	l, err := gListenConfig.Listen(ctx, "tcp", s.listen)
+	if err != nil {
+		return fmt.Errorf("failed to listen on %q: %w", s.listen, err)
+	}
+
+	// Close the listener on every early return; the flag is cleared once
+	// ownership passes to serveThread.
+	needListenerClose := true
+	defer func() {
+		if needListenerClose {
+			err := l.Close()
+			if IsErrClosed(err) {
+				err = nil
+			}
+			if err != nil {
+				Warn("failed to close listener for gRPC proxy server: %v", err)
+			}
+		}
+	}()
+
+	gc, err := grpc.DialContext(
+		ctx,
+		s.dial,
+		grpc.WithTransportCredentials(insecure.NewCredentials()),
+		grpc.WithBlock(),
+	)
+	if err != nil {
+		return fmt.Errorf("failed to dial %q: %w", s.dial, err)
+	}
+
+	// Same pattern as the listener: release the client connection unless
+	// setup completes.
+	needClientClose := true
+	defer func() {
+		if needClientClose {
+			err := gc.Close()
+			if IsErrClosed(err) {
+				err = nil
+			}
+			if err != nil {
+				Warn("failed to close gRPC client connection: %v", err)
+			}
+		}
+	}()
+
+	wc := workflowservice.NewWorkflowServiceClient(gc)
+	ws, err := client.NewWorkflowServiceProxyServer(client.WorkflowServiceProxyOptions{Client: wc})
+	if err != nil {
+		return fmt.Errorf("failed to create WorkflowService proxy server: %w", err)
+	}
+
+	// Both interceptors gate every call through AwaitPermitted.
+	gs := grpc.NewServer(
+		grpc.UnaryInterceptor(ProxyUnaryInterceptor),
+		grpc.StreamInterceptor(ProxyStreamInterceptor),
+	)
+	grpc_health_v1.RegisterHealthServer(gs, &TrivialHealthServer{})
+	workflowservice.RegisterWorkflowServiceServer(gs, ws)
+
+	// Setup succeeded: serveThread now owns the listener and the client
+	// connection and closes them when serving ends.
+	needClientClose = false
+	needListenerClose = false
+	s.gc = gc
+	s.gs = gs
+	s.l = l
+	s.wc = wc
+	s.ws = ws
+	s.quitCh = make(chan struct{})
+
+	go s.serveThread()
+	go s.closeThread(ctx, s.quitCh, s.gs)
+	return nil
+}
+
+// Shutdown gracefully stops the gRPC server via GracefulStop and then blocks
+// until serveThread has finished. It returns nil immediately if the proxy is
+// not running, and always returns nil otherwise. The ctx argument is not
+// used.
+//
+// It is safe to call while another Shutdown or Close is still waiting. Note
+// that mu is held for the duration of GracefulStop, so a concurrent Close
+// cannot start until GracefulStop returns.
+func (s *ProxyServer) Shutdown(ctx context.Context) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.quitCh == nil {
+		return nil
+	}
+
+	// cv.Wait below releases mu, so a concurrent Shutdown/Close can get here
+	// while quitCh is still non-nil; closing it twice would panic.
+	closeQuitOnce(s.quitCh)
+	s.gs.GracefulStop()
+	for s.quitCh != nil {
+		s.cv.Wait()
+	}
+	return nil
+}
+
+// Close stops the gRPC server immediately via Stop and then blocks until
+// serveThread has finished. It returns nil immediately if the proxy is not
+// running, and always returns nil otherwise.
+//
+// It is safe to call while another Shutdown or Close is still waiting.
+func (s *ProxyServer) Close() error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.quitCh == nil {
+		return nil
+	}
+
+	// See Shutdown: guard against closing quitCh a second time.
+	closeQuitOnce(s.quitCh)
+	s.gs.Stop()
+	for s.quitCh != nil {
+		s.cv.Wait()
+	}
+	return nil
+}
+
+// ForceClose calls Close and logs, rather than returns, any error that is
+// not an "already closed" condition. Close currently always returns nil, so
+// the warning path exists only to mirror ControlServer.ForceClose.
+func (s *ProxyServer) ForceClose() {
+	err := s.Close()
+	if IsErrClosed(err) {
+		err = nil
+	}
+	if err != nil {
+		Warn("failed to stop gRPC proxy server: %v", err)
+	}
+}
+
+// Stopper is anything that can be stopped without reporting an error; Run
+// passes the *grpc.Server to closeThread through this interface.
+type Stopper interface {
+	Stop()
+}
+
+// closeThread waits for whichever comes first: ctx being cancelled, in which
+// case it calls stopper.Stop, or quitCh being closed, in which case it simply
+// exits because an explicit Shutdown/Close is already under way.
+func (s *ProxyServer) closeThread(ctx context.Context, quitCh <-chan struct{}, stopper Stopper) {
+	select {
+	case <-ctx.Done():
+		stopper.Stop()
+	case <-quitCh:
+	}
+}
+
+// serveThread runs the gRPC serve loop until the server stops, then closes
+// the listener and the upstream client connection. Errors that merely
+// indicate an already-closed resource are ignored; anything else is logged.
+// finish runs on exit to mark the proxy as stopped.
+func (s *ProxyServer) serveThread() {
+	defer s.finish()
+
+	err := s.gs.Serve(s.l)
+	if IsErrClosed(err) {
+		err = nil
+	}
+	if err != nil {
+		Error("failed to serve gRPC proxy server: %v", err)
+	}
+
+	err = s.l.Close()
+	if IsErrClosed(err) {
+		err = nil
+	}
+	if err != nil {
+		Warn("failed to close listener for gRPC proxy server: %v", err)
+	}
+
+	err = s.gc.Close()
+	if IsErrClosed(err) {
+		err = nil
+	}
+	if err != nil {
+		Warn("failed to close gRPC client connection: %v", err)
+	}
+}
+
+// finish clears the running state under mu and wakes any Shutdown/Close
+// callers waiting on cv.
+func (s *ProxyServer) finish() {
+	s.mu.Lock()
+	s.gc = nil
+	s.gs = nil
+	s.l = nil
+	s.wc = nil
+	s.ws = nil
+	s.quitCh = nil
+	s.cv.Broadcast()
+	s.mu.Unlock()
+}
+
+// AwaitPermitted gates a proxied call on the current proxy state. It returns
+// a codes.Unavailable status error when the proxy is rejecting; otherwise it
+// blocks for as long as the proxy is frozen and then returns nil.
+//
+// NOTE(review): the rejecting flag is only checked before waiting, so a call
+// stalled by a freeze proceeds after the thaw even if rejecting was turned on
+// in the meantime — confirm that is intended.
+// NOTE(review): the wait takes no context, so a frozen call keeps blocking
+// until ActionThaw even if the client has gone away.
+func AwaitPermitted() error {
+	gStateMutex.Lock()
+	defer gStateMutex.Unlock()
+
+	if gStateRejecting {
+		return status.Error(codes.Unavailable, "proxy unavailable")
+	}
+	for gStateFrozen {
+		gStateCond.Wait()
+	}
+	return nil
+}
+
+// ProxyUnaryInterceptor applies AwaitPermitted to every unary RPC and only
+// invokes handler when the call is permitted.
+func ProxyUnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
+	if err := AwaitPermitted(); err != nil {
+		return nil, err
+	}
+	return handler(ctx, req)
+}
+
+// ProxyStreamInterceptor applies AwaitPermitted to every streaming RPC and
+// only invokes handler when the call is permitted. The gate is applied once,
+// when the stream is opened.
+func ProxyStreamInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+	if err := AwaitPermitted(); err != nil {
+		return err
+	}
+	return handler(srv, ss)
+}
+
+// TrivialHealthServer is a minimal gRPC health service; every method other
+// than Check comes from the embedded UnimplementedHealthServer.
+type TrivialHealthServer struct {
+	grpc_health_v1.UnimplementedHealthServer
+}
+
+// Check answers every health probe with an empty response and a nil error.
+//
+// NOTE(review): the response status is left at its zero value, which the
+// grpc health proto appears to define as UNKNOWN rather than SERVING —
+// confirm callers only look at the error.
+func (*TrivialHealthServer) Check(ctx context.Context, req *grpc_health_v1.HealthCheckRequest) (resp *grpc_health_v1.HealthCheckResponse, err error) {
+	return &grpc_health_v1.HealthCheckResponse{}, nil
+}
+
+// IsErrClosed reports whether err merely signals that a resource was already
+// closed or a server was already stopped: io.EOF, fs.ErrClosed,
+// net.ErrClosed or http.ErrServerClosed (including wrapped forms). It
+// returns false for nil.
+func IsErrClosed(err error) bool {
+	switch {
+	case err == nil:
+		return false
+	case errors.Is(err, io.EOF):
+		return true
+	case errors.Is(err, fs.ErrClosed):
+		return true
+	case errors.Is(err, net.ErrClosed):
+		return true
+	case errors.Is(err, http.ErrServerClosed):
+		return true
+	default:
+		return false
+	}
+}
+
+// Trace writes a "trace: "-prefixed, newline-terminated message to stderr,
+// but only when flagTrace is set.
+func Trace(format string, args ...any) {
+	if !flagTrace {
+		return
+	}
+	fmt.Fprintf(os.Stderr, "trace: "+format+"\n", args...)
+}
+
+// Info writes an "info: "-prefixed, newline-terminated message to stderr.
+func Info(format string, args ...any) {
+	fmt.Fprintf(os.Stderr, "info: "+format+"\n", args...)
+}
+
+// Warn writes a "warn: "-prefixed, newline-terminated message to stderr.
+func Warn(format string, args ...any) {
+	fmt.Fprintf(os.Stderr, "warn: "+format+"\n", args...)
+}
+
+// Error writes an "error: "-prefixed, newline-terminated message to stderr.
+func Error(format string, args ...any) {
+	fmt.Fprintf(os.Stderr, "error: "+format+"\n", args...)
+}
+
+// Fatal writes a "fatal: "-prefixed, newline-terminated message to stderr
+// and terminates the process with exit code rc. Deferred functions do not
+// run.
+func Fatal(rc int, format string, args ...any) {
+	fmt.Fprintf(os.Stderr, "fatal: "+format+"\n", args...)
+	os.Exit(rc)
+}