diff --git a/cmd/ftl/cmd_schema_get.go b/cmd/ftl/cmd_schema_get.go index 2807adc34..94398ed51 100644 --- a/cmd/ftl/cmd_schema_get.go +++ b/cmd/ftl/cmd_schema_get.go @@ -4,8 +4,10 @@ import ( "context" "fmt" "os" + "slices" "connectrpc.com/connect" + "golang.org/x/exp/maps" "google.golang.org/protobuf/proto" ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" @@ -15,7 +17,8 @@ import ( ) type getSchemaCmd struct { - Protobuf bool `help:"Output the schema as binary protobuf."` + Protobuf bool `help:"Output the schema as binary protobuf."` + Modules []string `help:"Modules to include" type:"string" optional:""` } func (g *getSchemaCmd) Run(ctx context.Context, client ftlv1connect.ControllerServiceClient) error { @@ -26,25 +29,47 @@ func (g *getSchemaCmd) Run(ctx context.Context, client ftlv1connect.ControllerSe if g.Protobuf { return g.generateProto(resp) } + remainingNames := make(map[string]bool) + for _, name := range g.Modules { + remainingNames[name] = true + } for resp.Receive() { msg := resp.Msg() module, err := schema.ModuleFromProto(msg.Schema) - if err != nil { - return fmt.Errorf("%s: %w", "invalid module schema", err) + if len(g.Modules) == 0 || remainingNames[msg.Schema.Name] { + if err != nil { + return fmt.Errorf("%s: %w", "invalid module schema", err) + } + fmt.Println(module) + delete(remainingNames, msg.Schema.Name) } - fmt.Println(module) if !msg.More { break } } - return resp.Err() + if err := resp.Err(); err != nil { + return resp.Err() + } + missingNames := maps.Keys(remainingNames) + slices.Sort(missingNames) + if len(missingNames) > 0 { + return fmt.Errorf("missing modules: %v", missingNames) + } + return nil } func (g *getSchemaCmd) generateProto(resp *connect.ServerStreamForClient[ftlv1.PullSchemaResponse]) error { + remainingNames := make(map[string]bool) + for _, name := range g.Modules { + remainingNames[name] = true + } schema := &schemapb.Schema{} for resp.Receive() { msg := resp.Msg() - schema.Modules = append(schema.Modules, msg.Schema) + if len(g.Modules) == 0 || remainingNames[msg.Schema.Name] { + schema.Modules = append(schema.Modules, msg.Schema) + delete(remainingNames, msg.Schema.Name) + } if !msg.More { break } @@ -57,5 +82,13 @@ func (g *getSchemaCmd) generateProto(resp *connect.ServerStreamForClient[ftlv1.P return err } _, err = os.Stdout.Write(pb) - return err + if err != nil { + return err + } + missingNames := maps.Keys(remainingNames) + slices.Sort(missingNames) + if len(missingNames) > 0 { + return fmt.Errorf("missing modules: %v", missingNames) + } + return nil }