Skip to content

Commit

Permalink
feat: add gemini prompting to controllerbuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
justinsb committed Nov 5, 2024
1 parent 0605b8b commit 1c28e80
Show file tree
Hide file tree
Showing 16 changed files with 921 additions and 29 deletions.
3 changes: 3 additions & 0 deletions dev/tools/controllerbuilder/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"os"
"strings"

"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/commands/exportcsv"
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/commands/generatemapper"
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/commands/generatetypes"
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/commands/updatetypes"
Expand Down Expand Up @@ -94,6 +95,8 @@ func Execute() {
rootCmd.AddCommand(generatetypes.BuildCommand(&generateOptions))
rootCmd.AddCommand(generatemapper.BuildCommand(&generateOptions))
rootCmd.AddCommand(updatetypes.BuildCommand(&generateOptions))
rootCmd.AddCommand(exportcsv.BuildCommand(&generateOptions))
rootCmd.AddCommand(exportcsv.BuildPromptCommand(&generateOptions))

if err := rootCmd.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
Expand Down
110 changes: 110 additions & 0 deletions dev/tools/controllerbuilder/pkg/commands/exportcsv/exportcsvcommand.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package exportcsv

import (
"context"
"fmt"
"os"
"strings"

"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/options"
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/toolbot"

"github.com/spf13/cobra"
)

type ExportCSVOptions struct {
*options.GenerateOptions

ProtoDir string
SrcDir string
OutputDir string
}

func (o *ExportCSVOptions) BindFlags(cmd *cobra.Command) {
cmd.Flags().StringVar(&o.ProtoDir, "proto-dir", o.ProtoDir, "base directory for checkout of proto API definitions")
cmd.Flags().StringVar(&o.SrcDir, "src-dir", o.SrcDir, "base directory for source code")
cmd.Flags().StringVar(&o.OutputDir, "output-dir", o.OutputDir, "base directory for writing CSVs")
}

func BuildCommand(baseOptions *options.GenerateOptions) *cobra.Command {
opt := &ExportCSVOptions{
GenerateOptions: baseOptions,
}

cmd := &cobra.Command{
Use: "export-csv",
Short: "generate CSV from tool annotations",
RunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
if err := RunExportCSV(ctx, opt); err != nil {
return err
}
return nil
},
}

opt.BindFlags(cmd)

return cmd
}

func rewriteFilePath(p *string) error {
if strings.HasPrefix(*p, "~/") {
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("getting home directory: %w", err)
}
*p = strings.Replace(*p, "~", homeDir, 1)
}
return nil
}

func RunExportCSV(ctx context.Context, o *ExportCSVOptions) error {
if err := rewriteFilePath(&o.ProtoDir); err != nil {
return err
}

if o.ProtoDir == "" {
return fmt.Errorf("--proto-dir is required")
}
if o.SrcDir == "" {
return fmt.Errorf("--src-dir is required")
}
if o.OutputDir == "" {
return fmt.Errorf("--output-dir is required")
}

extractor := &toolbot.ExtractToolMarkers{}
addProtoDefinition, err := toolbot.NewEnhanceWithProtoDefinition(o.ProtoDir)
if err != nil {
return err
}
x, err := toolbot.NewCSVExporter(extractor, addProtoDefinition)
if err != nil {
return err
}
if err := x.VisitCodeDir(ctx, o.SrcDir); err != nil {
return err
}

if err := x.WriteCSVForAllTools(ctx, o.OutputDir); err != nil {
return err
}

return nil
}

113 changes: 113 additions & 0 deletions dev/tools/controllerbuilder/pkg/commands/exportcsv/prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package exportcsv

import (
"context"
"fmt"
"io"
"os"

"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/options"
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/toolbot"
"k8s.io/klog/v2"

"github.com/spf13/cobra"
)

type PromptOptions struct {
*options.GenerateOptions

ProtoDir string
SrcDir string
}

func (o *PromptOptions) BindFlags(cmd *cobra.Command) {
cmd.Flags().StringVar(&o.SrcDir, "src-dir", o.SrcDir, "base directory for source code")
cmd.Flags().StringVar(&o.ProtoDir, "proto-dir", o.ProtoDir, "base directory for checkout of proto API definitions")
}

func BuildPromptCommand(baseOptions *options.GenerateOptions) *cobra.Command {
opt := &PromptOptions{
GenerateOptions: baseOptions,
}

cmd := &cobra.Command{
Use: "prompt",
Short: "build prompt",
RunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
if err := RunPrompt(ctx, opt); err != nil {
return err
}
return nil
},
}

opt.BindFlags(cmd)

return cmd
}

func RunPrompt(ctx context.Context, o *PromptOptions) error {
log := klog.FromContext(ctx)

if err := rewriteFilePath(&o.ProtoDir); err != nil {
return err
}

if o.ProtoDir == "" {
return fmt.Errorf("--proto-dir is required")
}
extractor := &toolbot.ExtractToolMarkers{}
addProtoDefinition, err := toolbot.NewEnhanceWithProtoDefinition(o.ProtoDir)
if err != nil {
return err
}
x, err := toolbot.NewCSVExporter(extractor, addProtoDefinition)
if err != nil {
return err
}

if o.SrcDir != "" {
if err := x.VisitCodeDir(ctx, o.SrcDir); err != nil {
return err
}
}

b, err := io.ReadAll(os.Stdin)
if err != nil {
return fmt.Errorf("reading from stdin: %w", err)
}

dataPoints, err := x.BuildDataPoints(ctx, b)
if err != nil {
return err
}

if len(dataPoints) != 1 {
return fmt.Errorf("expected exactly one data point, got %d", len(dataPoints))
}

dataPoint := dataPoints[0]

log.Info("built data point", "dataPoint", dataPoint)

if err := x.RunGemini(ctx, dataPoint, os.Stdout); err != nil {
return fmt.Errorf("running LLM inference: %w", err)

}
return nil
}
Loading

0 comments on commit 1c28e80

Please sign in to comment.