From 7354e2da9e432240400e2eba791f660907387dae Mon Sep 17 00:00:00 2001 From: Jon Johnson <113393155+jonathanj-square@users.noreply.github.com> Date: Wed, 10 Jul 2024 10:00:36 -0700 Subject: [PATCH] fix: ftltest context option order insensivity (#2018) fixes #1866 introduces a rudimentary partial ordering scheme for context options - this is used to sort them before they get executed to build the context --- .../sql/testdata/go/database/database_test.go | 13 + go-runtime/ftl/ftltest/ftltest.go | 316 ++++++++++-------- 2 files changed, 197 insertions(+), 132 deletions(-) diff --git a/backend/controller/sql/testdata/go/database/database_test.go b/backend/controller/sql/testdata/go/database/database_test.go index 6453a531c..d5b026600 100644 --- a/backend/controller/sql/testdata/go/database/database_test.go +++ b/backend/controller/sql/testdata/go/database/database_test.go @@ -33,6 +33,19 @@ func TestDatabase(t *testing.T) { assert.Equal(t, "unit test 2", list[0]) } +func TestOptionOrdering(t *testing.T) { + ctx := ftltest.Context( + ftltest.WithDatabase(db), // <--- consumes DSNs + ftltest.WithProjectFile("ftl-project.toml"), // <--- provides DSNs + ) + + Insert(ctx, InsertRequest{Data: "unit test 1"}) + list, err := getAll(ctx) + assert.NoError(t, err) + assert.Equal(t, 1, len(list)) + assert.Equal(t, "unit test 1", list[0]) +} + func getAll(ctx context.Context) ([]string, error) { rows, err := db.Get(ctx).Query("SELECT data FROM requests ORDER BY created_at;") if err != nil { diff --git a/go-runtime/ftl/ftltest/ftltest.go b/go-runtime/ftl/ftltest/ftltest.go index 3682676ec..f61cd09f3 100644 --- a/go-runtime/ftl/ftltest/ftltest.go +++ b/go-runtime/ftl/ftltest/ftltest.go @@ -9,6 +9,7 @@ import ( "net/url" "os" "reflect" + "sort" "strings" _ "github.com/jackc/pgx/v5/stdlib" // SQL driver @@ -30,7 +31,17 @@ type OptionsState struct { allowDirectVerbBehavior bool } -type Option func(context.Context, *OptionsState) error +type optionRank int + +const ( + profile optionRank = iota + other +) + +type Option struct { + rank optionRank + apply func(context.Context, *OptionsState) error +} // Context suitable for use in testing FTL verbs with provided options func Context(options ...Option) context.Context { @@ -43,8 +54,12 @@ func Context(options ...Option) context.Context { ctx = internal.WithContext(ctx, newFakeFTL(ctx)) name := reflection.Module() + sort.Slice(options, func(i, j int) bool { + return options[i].rank < options[j].rank + }) + for _, option := range options { - err := option(ctx, state) + err := option.apply(ctx, state) if err != nil { panic(fmt.Sprintf("error applying option: %v", err)) } @@ -85,44 +100,48 @@ func WithProjectFile(path string) Option { preprocessingErr = fmt.Errorf("could not find default project file in $FTL_CONFIG or git") } } - return func(ctx context.Context, state *OptionsState) error { - if preprocessingErr != nil { - return preprocessingErr - } - if _, err := os.Stat(path); err != nil { - return fmt.Errorf("error accessing project file: %w", err) - } - cm, err := cf.NewDefaultConfigurationManagerFromConfig(ctx, path) - if err != nil { - return fmt.Errorf("could not set up configs: %w", err) - } - configs, err := cm.MapForModule(ctx, reflection.Module()) - if err != nil { - return fmt.Errorf("could not read configs: %w", err) - } + return Option{ + rank: profile, + apply: func(ctx context.Context, state *OptionsState) error { + if preprocessingErr != nil { + return preprocessingErr + } + if _, err := os.Stat(path); err != nil { + return fmt.Errorf("error accessing project file: %w", err) + } + cm, err := cf.NewDefaultConfigurationManagerFromConfig(ctx, path) + if err != nil { + return fmt.Errorf("could not set up configs: %w", err) + } + configs, err := cm.MapForModule(ctx, reflection.Module()) + if err != nil { + return fmt.Errorf("could not read configs: %w", err) + } - fftl := internal.FromContext(ctx).(*fakeFTL) //nolint:forcetypeassert - for name, data := range configs { - if err := fftl.setConfig(name, json.RawMessage(data)); err != nil { - return err + fftl := internal.FromContext(ctx).(*fakeFTL) //nolint:forcetypeassert + for name, data := range configs { + if err := fftl.setConfig(name, json.RawMessage(data)); err != nil { + return err + } } - } - sm, err := cf.NewDefaultSecretsManagerFromConfig(ctx, path, "") - if err != nil { - return fmt.Errorf("could not set up secrets: %w", err) - } - secrets, err := sm.MapForModule(ctx, reflection.Module()) - if err != nil { - return fmt.Errorf("could not read secrets: %w", err) - } - for name, data := range secrets { - if err := fftl.setSecret(name, json.RawMessage(data)); err != nil { - return err + sm, err := cf.NewDefaultSecretsManagerFromConfig(ctx, path, "") + if err != nil { + return fmt.Errorf("could not set up secrets: %w", err) } - } - return nil + secrets, err := sm.MapForModule(ctx, reflection.Module()) + if err != nil { + return fmt.Errorf("could not read secrets: %w", err) + } + for name, data := range secrets { + if err := fftl.setSecret(name, json.RawMessage(data)); err != nil { + return err + } + } + return nil + }, } + } // WithConfig sets a configuration for the current module @@ -134,15 +153,18 @@ func WithProjectFile(path string) Option { // // ... other options // ) func WithConfig[T ftl.ConfigType](config ftl.ConfigValue[T], value T) Option { - return func(ctx context.Context, state *OptionsState) error { - if config.Module != reflection.Module() { - return fmt.Errorf("config %v does not match current module %s", config.Module, reflection.Module()) - } - fftl := internal.FromContext(ctx).(*fakeFTL) //nolint:forcetypeassert - if err := fftl.setConfig(config.Name, value); err != nil { - return err - } - return nil + return Option{ + rank: other, + apply: func(ctx context.Context, state *OptionsState) error { + if config.Module != reflection.Module() { + return fmt.Errorf("config %v does not match current module %s", config.Module, reflection.Module()) + } + fftl := internal.FromContext(ctx).(*fakeFTL) //nolint:forcetypeassert + if err := fftl.setConfig(config.Name, value); err != nil { + return err + } + return nil + }, } } @@ -155,15 +177,18 @@ func WithConfig[T ftl.ConfigType](config ftl.ConfigValue[T], value T) Option { // // ... other options // ) func WithSecret[T ftl.SecretType](secret ftl.SecretValue[T], value T) Option { - return func(ctx context.Context, state *OptionsState) error { - if secret.Module != reflection.Module() { - return fmt.Errorf("secret %v does not match current module %s", secret.Module, reflection.Module()) - } - fftl := internal.FromContext(ctx).(*fakeFTL) //nolint:forcetypeassert - if err := fftl.setSecret(secret.Name, value); err != nil { - return err - } - return nil + return Option{ + rank: other, + apply: func(ctx context.Context, state *OptionsState) error { + if secret.Module != reflection.Module() { + return fmt.Errorf("secret %v does not match current module %s", secret.Module, reflection.Module()) + } + fftl := internal.FromContext(ctx).(*fakeFTL) //nolint:forcetypeassert + if err := fftl.setSecret(secret.Name, value); err != nil { + return err + } + return nil + }, } } @@ -176,32 +201,34 @@ func WithSecret[T ftl.SecretType](secret ftl.SecretValue[T], value T) Option { // // ... other options // ) func WithDatabase(dbHandle ftl.Database) Option { - return func(ctx context.Context, state *OptionsState) error { - fftl := internal.FromContext(ctx) - originalDSN, err := getDSNFromSecret(fftl, reflection.Module(), dbHandle.Name) - if err != nil { - return err - } + return Option{ + rank: other, + apply: func(ctx context.Context, state *OptionsState) error { + fftl := internal.FromContext(ctx) + originalDSN, err := getDSNFromSecret(fftl, reflection.Module(), dbHandle.Name) + if err != nil { + return err + } - // convert DSN by appending "_test" to table name - // postgres DSN format: postgresql://[user[:password]@][netloc][:port][/dbname][?param1=value1&...] - // source: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING - dsnURL, err := url.Parse(originalDSN) - if err != nil { - return fmt.Errorf("could not parse DSN: %w", err) - } - if dsnURL.Path == "" { - return fmt.Errorf("DSN for %s must include table name: %s", dbHandle.Name, originalDSN) - } - dsnURL.Path += "_test" - dsn := dsnURL.String() + // convert DSN by appending "_test" to table name + // postgres DSN format: postgresql://[user[:password]@][netloc][:port][/dbname][?param1=value1&...] + // source: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING + dsnURL, err := url.Parse(originalDSN) + if err != nil { + return fmt.Errorf("could not parse DSN: %w", err) + } + if dsnURL.Path == "" { + return fmt.Errorf("DSN for %s must include table name: %s", dbHandle.Name, originalDSN) + } + dsnURL.Path += "_test" + dsn := dsnURL.String() - // connect to db and clear out the contents of each table - sqlDB, err := sql.Open("pgx", dsn) - if err != nil { - return fmt.Errorf("could not create database %q with DSN %q: %w", dbHandle.Name, dsn, err) - } - _, err = sqlDB.ExecContext(ctx, `DO $$ + // connect to db and clear out the contents of each table + sqlDB, err := sql.Open("pgx", dsn) + if err != nil { + return fmt.Errorf("could not create database %q with DSN %q: %w", dbHandle.Name, dsn, err) + } + _, err = sqlDB.ExecContext(ctx, `DO $$ DECLARE table_name text; BEGIN @@ -212,17 +239,18 @@ func WithDatabase(dbHandle ftl.Database) Option { EXECUTE 'ALTER TABLE ' || quote_ident(table_name) || ' ENABLE TRIGGER ALL;'; END LOOP; END $$;`) - if err != nil { - return fmt.Errorf("could not clear tables in database %q: %w", dbHandle.Name, err) - } + if err != nil { + return fmt.Errorf("could not clear tables in database %q: %w", dbHandle.Name, err) + } - // replace original database with test database - replacementDB, err := modulecontext.NewTestDatabase(modulecontext.DBTypePostgres, dsn) - if err != nil { - return fmt.Errorf("could not create database %q with DSN %q: %w", dbHandle.Name, dsn, err) - } - state.databases[dbHandle.Name] = replacementDB - return nil + // replace original database with test database + replacementDB, err := modulecontext.NewTestDatabase(modulecontext.DBTypePostgres, dsn) + if err != nil { + return fmt.Errorf("could not create database %q with DSN %q: %w", dbHandle.Name, dsn, err) + } + state.databases[dbHandle.Name] = replacementDB + return nil + }, } } @@ -237,16 +265,19 @@ func WithDatabase(dbHandle ftl.Database) Option { // // ... other options // ) func WhenVerb[Req any, Resp any](verb ftl.Verb[Req, Resp], fake ftl.Verb[Req, Resp]) Option { - return func(ctx context.Context, state *OptionsState) error { - ref := reflection.FuncRef(verb) - state.mockVerbs[schema.RefKey(ref)] = func(ctx context.Context, req any) (resp any, err error) { - request, ok := req.(Req) - if !ok { - return nil, fmt.Errorf("invalid request type %T for %v, expected %v", req, ref, reflect.TypeFor[Req]()) + return Option{ + rank: other, + apply: func(ctx context.Context, state *OptionsState) error { + ref := reflection.FuncRef(verb) + state.mockVerbs[schema.RefKey(ref)] = func(ctx context.Context, req any) (resp any, err error) { + request, ok := req.(Req) + if !ok { + return nil, fmt.Errorf("invalid request type %T for %v, expected %v", req, ref, reflect.TypeFor[Req]()) + } + return fake(ctx, request) } - return fake(ctx, request) - } - return nil + return nil + }, } } @@ -261,12 +292,15 @@ func WhenVerb[Req any, Resp any](verb ftl.Verb[Req, Resp], fake ftl.Verb[Req, Re // // ... other options // ) func WhenSource[Resp any](source ftl.Source[Resp], fake func(ctx context.Context) (resp Resp, err error)) Option { - return func(ctx context.Context, state *OptionsState) error { - ref := reflection.FuncRef(source) - state.mockVerbs[schema.RefKey(ref)] = func(ctx context.Context, req any) (resp any, err error) { - return fake(ctx) - } - return nil + return Option{ + rank: other, + apply: func(ctx context.Context, state *OptionsState) error { + ref := reflection.FuncRef(source) + state.mockVerbs[schema.RefKey(ref)] = func(ctx context.Context, req any) (resp any, err error) { + return fake(ctx) + } + return nil + }, } } @@ -281,16 +315,19 @@ func WhenSource[Resp any](source ftl.Source[Resp], fake func(ctx context.Context // // ... other options // ) func WhenSink[Req any](sink ftl.Sink[Req], fake func(ctx context.Context, req Req) error) Option { - return func(ctx context.Context, state *OptionsState) error { - ref := reflection.FuncRef(sink) - state.mockVerbs[schema.RefKey(ref)] = func(ctx context.Context, req any) (resp any, err error) { - request, ok := req.(Req) - if !ok { - return nil, fmt.Errorf("invalid request type %T for %v, expected %v", req, ref, reflect.TypeFor[Req]()) + return Option{ + rank: other, + apply: func(ctx context.Context, state *OptionsState) error { + ref := reflection.FuncRef(sink) + state.mockVerbs[schema.RefKey(ref)] = func(ctx context.Context, req any) (resp any, err error) { + request, ok := req.(Req) + if !ok { + return nil, fmt.Errorf("invalid request type %T for %v, expected %v", req, ref, reflect.TypeFor[Req]()) + } + return ftl.Unit{}, fake(ctx, request) } - return ftl.Unit{}, fake(ctx, request) - } - return nil + return nil + }, } } @@ -304,12 +341,15 @@ func WhenSink[Req any](sink ftl.Sink[Req], fake func(ctx context.Context, req Re // }), // ) func WhenEmpty(empty ftl.Empty, fake func(ctx context.Context) (err error)) Option { - return func(ctx context.Context, state *OptionsState) error { - ref := reflection.FuncRef(empty) - state.mockVerbs[schema.RefKey(ref)] = func(ctx context.Context, req any) (resp any, err error) { - return ftl.Unit{}, fake(ctx) - } - return nil + return Option{ + rank: other, + apply: func(ctx context.Context, state *OptionsState) error { + ref := reflection.FuncRef(empty) + state.mockVerbs[schema.RefKey(ref)] = func(ctx context.Context, req any) (resp any, err error) { + return ftl.Unit{}, fake(ctx) + } + return nil + }, } } @@ -317,9 +357,12 @@ func WhenEmpty(empty ftl.Empty, fake func(ctx context.Context) (err error)) Opti // // Any overrides provided by calling WhenVerb(...) will take precedence func WithCallsAllowedWithinModule() Option { - return func(ctx context.Context, state *OptionsState) error { - state.allowDirectVerbBehavior = true - return nil + return Option{ + rank: other, + apply: func(ctx context.Context, state *OptionsState) error { + state.allowDirectVerbBehavior = true + return nil + }, } } @@ -334,10 +377,13 @@ func WithCallsAllowedWithinModule() Option { // // ... other options // ) func WhenMap[T, U any](mapper *ftl.MapHandle[T, U], fake func(context.Context) (U, error)) Option { - return func(ctx context.Context, state *OptionsState) error { - fftl := internal.FromContext(ctx).(*fakeFTL) //nolint:forcetypeassert - addMapMock(fftl, mapper, fake) - return nil + return Option{ + rank: other, + apply: func(ctx context.Context, state *OptionsState) error { + fftl := internal.FromContext(ctx).(*fakeFTL) //nolint:forcetypeassert + addMapMock(fftl, mapper, fake) + return nil + }, } } @@ -346,10 +392,13 @@ func WhenMap[T, U any](mapper *ftl.MapHandle[T, U], fake func(context.Context) ( // // Any overrides provided by calling WhenMap(...) will take precedence. func WithMapsAllowed() Option { - return func(ctx context.Context, state *OptionsState) error { - fftl := internal.FromContext(ctx).(*fakeFTL) //nolint:forcetypeassert - fftl.startAllowingMapCalls() - return nil + return Option{ + rank: other, + apply: func(ctx context.Context, state *OptionsState) error { + fftl := internal.FromContext(ctx).(*fakeFTL) //nolint:forcetypeassert + fftl.startAllowingMapCalls() + return nil + }, } } @@ -384,10 +433,13 @@ func getDSNFromSecret(ftl internal.FTL, module, name string) (string, error) { // // ... other options // ) func WithSubscriber[E any](subscription ftl.SubscriptionHandle[E], sink ftl.Sink[E]) Option { - return func(ctx context.Context, state *OptionsState) error { - fftl := internal.FromContext(ctx).(*fakeFTL) //nolint:forcetypeassert - addSubscriber(fftl.pubSub, subscription, sink) - return nil + return Option{ + rank: other, + apply: func(ctx context.Context, state *OptionsState) error { + fftl := internal.FromContext(ctx).(*fakeFTL) //nolint:forcetypeassert + addSubscriber(fftl.pubSub, subscription, sink) + return nil + }, } }