Skip to content

Commit

Permalink
Merge pull request GoogleCloudPlatform#3337 from jingyih/refactor
Browse files Browse the repository at this point in the history
tool: correctly insert field into Spec or ObservedState.
  • Loading branch information
google-oss-prow[bot] authored Dec 12, 2024
2 parents 05b63d0 + cf7727e commit 4db732a
Show file tree
Hide file tree
Showing 8 changed files with 325 additions and 232 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,24 @@
package updatetypes

import (
"bytes"
"context"
"fmt"
"os"
"strings"

"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/codegen"
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/gocode"
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/options"
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/typeupdater"

"github.com/spf13/cobra"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/klog/v2"
)

const kccProtoPrefix = "+kcc:proto="

type UpdateTypeOptions struct {
*options.GenerateOptions

parentMessageFullName string
newField string
ignoredFields string // TODO: could be part of GenerateOptions
apiDirectory string
goPackagePath string
parentNessage string // The fully qualified name of the parent prroto message of the field to be inserted
fieldToInsert string
ignoredFields string // TODO: could be part of GenerateOptions
apiDirectory string
goPackagePath string
}

func (o *UpdateTypeOptions) InitDefaults() error {
Expand All @@ -56,8 +46,9 @@ func (o *UpdateTypeOptions) InitDefaults() error {
}

func (o *UpdateTypeOptions) BindFlags(cmd *cobra.Command) {
cmd.Flags().StringVar(&o.parentMessageFullName, "parent-message-full-name", o.parentMessageFullName, "Fully qualified name of the proto message holding the new field")
cmd.Flags().StringVar(&o.newField, "new-field", o.newField, "Name of the new field")
cmd.Flags().StringVar(&o.parentNessage, "parent-message", o.parentNessage, "Fully qualified name of the proto message holding the new field. e.g. `google.cloud.bigquery.datatransfer.v1.TransferConfig`")
cmd.Flags().StringVar(&o.fieldToInsert, "field-to-insert", o.fieldToInsert, "Name of the new field to be inserted, e.g. `schedule_options_v2`")
// TODO: Update this flag to accept a file path pointing to the ignored fields YAML file.
cmd.Flags().StringVar(&o.ignoredFields, "ignored-fields", o.ignoredFields, "Comma-separated list of fields to ignore")
cmd.Flags().StringVar(&o.apiDirectory, "api-dir", o.apiDirectory, "Base directory for APIs")
cmd.Flags().StringVar(&o.goPackagePath, "api-go-package-path", o.goPackagePath, "Package path")
Expand All @@ -77,8 +68,8 @@ func BuildCommand(baseOptions *options.GenerateOptions) *cobra.Command {
Use: "update-types",
Short: "update KRM types for a proto service",
RunE: func(cmd *cobra.Command, args []string) error {
updater := NewTypeUpdater(opt)
if err := updater.Run(); err != nil {
ctx := cmd.Context()
if err := runTypeUpdater(ctx, opt); err != nil {
return err
}
return nil
Expand All @@ -90,159 +81,22 @@ func BuildCommand(baseOptions *options.GenerateOptions) *cobra.Command {
return cmd
}

type TypeUpdater struct {
opts *UpdateTypeOptions
newField newProtoField
dependentMessages map[string]protoreflect.MessageDescriptor // key: fully qualified name of proto message
generatedGoField generatedGoField // TODO: support multiple new fields
generatedGoStructs []generatedGoStruct
}

type newProtoField struct {
field protoreflect.FieldDescriptor
parentMessage protoreflect.MessageDescriptor
}

type generatedGoField struct {
parentMessage string // fully qualified name of the parent proto message of this field
content []byte // the content of the generated Go field
}

type generatedGoStruct struct {
name string // fully qualified name of the proto message
content []byte // the content of the generated Go struct
}

func NewTypeUpdater(opts *UpdateTypeOptions) *TypeUpdater {
return &TypeUpdater{
opts: opts,
}
}

func (u *TypeUpdater) Run() error {
// 1. find new field and its dependent proto messages that needs to be generated
if err := u.analyze(); err != nil {
return nil
func runTypeUpdater(ctx context.Context, opt *UpdateTypeOptions) error {
if opt.apiDirectory == "" {
return fmt.Errorf("--api-dir is required")
}

// 2. generate Go types for the new field and its dependent proto messages
if err := u.generate(); err != nil {
return err
}

// 3. insert the generated Go code back to files
if err := u.insertGoField(); err != nil {
return err
typeUpdaterOpts := &typeupdater.UpdaterOptions{
ProtoSourcePath: opt.GenerateOptions.ProtoSourcePath,
ParentMessageFullName: opt.parentNessage,
FieldToInsert: opt.fieldToInsert,
IgnoredFields: opt.ignoredFields,
APIDirectory: opt.apiDirectory,
GoPackagePath: opt.goPackagePath,
}
if err := u.insertGoMessages(); err != nil {
return err
}

return nil
}

// anaylze finds the new field, its parent message, and all dependent messages that need to be generated.
func (u *TypeUpdater) analyze() error {
parentMessage, newField, err := findNewField(u.opts.ProtoSourcePath, u.opts.parentMessageFullName, u.opts.newField)
if err != nil {
return err
}
u.newField = newProtoField{
field: newField,
parentMessage: parentMessage,
}

msgs, err := findDependentMsgs(newField, sets.NewString(strings.Split(u.opts.ignoredFields, ",")...))
if err != nil {
return err
}

codegen.RemoveNotMappedToGoStruct(msgs)

if err := removeAlreadyGenerated(u.opts.goPackagePath, u.opts.apiDirectory, msgs); err != nil {
updater := typeupdater.NewTypeUpdater(typeUpdaterOpts)
if err := updater.Run(); err != nil {
return err
}
u.dependentMessages = msgs
return nil
}

// findNewField locates the parent message and the new field in the proto file
func findNewField(pbSourcePath, parentMsgFullName, newFieldName string) (protoreflect.MessageDescriptor, protoreflect.FieldDescriptor, error) {
fileData, err := os.ReadFile(pbSourcePath)
if err != nil {
return nil, nil, fmt.Errorf("reading %q: %w", pbSourcePath, err)
}

fds := &descriptorpb.FileDescriptorSet{}
if err := proto.Unmarshal(fileData, fds); err != nil {
return nil, nil, fmt.Errorf("unmarshalling %q: %w", pbSourcePath, err)
}

files, err := protodesc.NewFiles(fds)
if err != nil {
return nil, nil, fmt.Errorf("building file description: %w", err)
}

// Find the parent message
messageDesc, err := files.FindDescriptorByName(protoreflect.FullName(parentMsgFullName))
if err != nil {
return nil, nil, err
}
msgDesc, ok := messageDesc.(protoreflect.MessageDescriptor)
if !ok {
return nil, nil, fmt.Errorf("unexpected descriptor type: %T", msgDesc)
}

// Find the new field in parent message
fieldDesc := msgDesc.Fields().ByName(protoreflect.Name(newFieldName))
if fieldDesc == nil {
return nil, nil, fmt.Errorf("field not found in message")
}

return msgDesc, fieldDesc, nil
}

// findDependentMsgs finds all dependent proto messages for the given field, ignoring specified fields
func findDependentMsgs(field protoreflect.FieldDescriptor, ignoredProtoFields sets.String) (map[string]protoreflect.MessageDescriptor, error) {
deps := make(map[string]protoreflect.MessageDescriptor)
codegen.FindDependenciesForField(field, deps, ignoredProtoFields)
return deps, nil
}

// removeAlreadyGenerated removes proto messages that have already been generated (including manually edited)
func removeAlreadyGenerated(goPackagePath, outputAPIDirectory string, targets map[string]protoreflect.MessageDescriptor) error {
packages, err := gocode.LoadPackageTree(goPackagePath, outputAPIDirectory)
if err != nil {
return err
}
for _, p := range packages {
for _, s := range p.Structs {
if annotation := s.GetAnnotation("+kcc:proto"); annotation != "" {
delete(targets, annotation)
}
}
}
return nil
}

func (u *TypeUpdater) generate() error {
var buf bytes.Buffer
klog.Infof("generate Go code for field %s", u.newField.field.Name())
codegen.WriteField(&buf, u.newField.field, u.newField.parentMessage, 0)
u.generatedGoField = generatedGoField{
parentMessage: string(u.newField.parentMessage.FullName()),
content: buf.Bytes(),
}

for _, msg := range u.dependentMessages {
var buf bytes.Buffer
klog.Infof("generate Go code for messge %s", msg.FullName())
codegen.WriteMessage(&buf, msg)
u.generatedGoStructs = append(u.generatedGoStructs,
generatedGoStruct{
name: string(msg.FullName()),
content: buf.Bytes(),
})
}
return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package updatetypes
package typeupdater

import (
"fmt"
Expand All @@ -25,16 +25,24 @@ import (
"strings"

"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/gocode"
"github.com/GoogleCloudPlatform/k8s-config-connector/pkg/controller/direct/common"
"google.golang.org/genproto/googleapis/api/annotations"

"k8s.io/klog/v2"
)

type target struct {
goName string
endPos int
}

func (u *TypeUpdater) insertGoField() error {
klog.Infof("inserting the generated Go code for field %s", u.newField.field.Name())
klog.Infof("inserting the generated Go code for field %s", u.newField.proto.Name())

targetComment := fmt.Sprintf("+kcc:proto=%s", u.generatedGoField.parentMessage)
targetComment := fmt.Sprintf("+kcc:proto=%s", u.newField.parent.FullName())
outputOnly := common.IsFieldBehavior(u.newField.proto, annotations.FieldBehavior_OUTPUT_ONLY)

filepath.WalkDir(u.opts.apiDirectory, func(path string, d fs.DirEntry, err error) error {
filepath.WalkDir(u.opts.APIDirectory, func(path string, d fs.DirEntry, err error) error {
if err != nil || d.IsDir() || filepath.Ext(path) != ".go" {
return nil
}
Expand All @@ -55,13 +63,12 @@ func (u *TypeUpdater) insertGoField() error {
// use a CommentMap to associate comments with nodes
docMap := gocode.NewDocMap(fset, file)

// find the target Go struct and its ending position in the source
var endPos int
// find the target Go struct and the ending position in the source
// there are 2 cases considered.
// - case 1, there is only 1 matching target.
// - case 2, there are two matching targets (Spec and ObservedState).
var targets []target
ast.Inspect(file, func(n ast.Node) bool {
if endPos != 0 {
return false // already found the target
}

ts, ok := n.(*ast.TypeSpec)
if !ok {
return true
Expand All @@ -80,19 +87,35 @@ func (u *TypeUpdater) insertGoField() error {
return true // empty struct? this should not happen
}

klog.Infof("found target Go struct %s", ts.Name.Name)

endPos = int(fset.Position(st.End()).Offset)
return false // stop searching, we found the target Go struct
klog.Infof("found potential target Go struct %s", ts.Name.Name)
targets = append(targets, target{
goName: ts.Name.Name,
endPos: int(fset.Position(st.End()).Offset),
})
return true // continue searching for potential target Go struct
})

// if the target Go struct was found, modify the source bytes
if endPos != 0 {
var chosenTarget *target
if len(targets) == 0 { // no target, continue to next file
return nil
} else if len(targets) == 1 { // case 1, one matching Go struct
chosenTarget = &targets[0]
} else if len(targets) == 2 { // case 2, Spec/ObservedState pair
for _, t := range targets {
if !outputOnly && strings.HasSuffix(t.goName, "Spec") ||
outputOnly && strings.HasSuffix(t.goName, "ObservedState") {
chosenTarget = &t
break
}
}
}

if chosenTarget != nil { // target Go struct was found, modify the source bytes
var newSrcBytes []byte
// TODO: ues the same field ordering as in proto message
newSrcBytes = append(newSrcBytes, srcBytes[:endPos-1]...) // up to before '}'
newSrcBytes = append(newSrcBytes, u.generatedGoField.content...) // insert new field
newSrcBytes = append(newSrcBytes, srcBytes[endPos-1:]...) // include the '}'
// TODO: use the same field ordering as in proto message?
newSrcBytes = append(newSrcBytes, srcBytes[:chosenTarget.endPos-1]...) // up to before '}'
newSrcBytes = append(newSrcBytes, u.newField.generatedContent...) // insert new field
newSrcBytes = append(newSrcBytes, srcBytes[chosenTarget.endPos-1:]...) // include the '}'

if err := os.WriteFile(path, newSrcBytes, d.Type()); err != nil {
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package updatetypes
package typeupdater

import (
"context"
Expand All @@ -28,7 +28,7 @@ import (
)

func (u *TypeUpdater) insertGoFieldGemini() error {
klog.Infof("inserting the generated Go code for field %s", u.newField.field.Name())
klog.Infof("inserting the generated Go code for field %s", u.newField.proto.Name())
ctx := context.Background()
client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY")))
if err != nil {
Expand All @@ -50,20 +50,20 @@ func (u *TypeUpdater) insertGoFieldGemini() error {
Could you find the Go struct which has comment "+kcc:proto=%s" with no following suffix,
and insert the Go field into the found Go struct.
In your response, only include what is asked for.
`, u.generatedGoField.parentMessage)),
`, u.newField.parent.FullName())),
},
Role: "user",
},
}
// provide the content of the new Go field
session.History = append(session.History, &genai.Content{
Parts: []genai.Part{
genai.Text(fmt.Sprintf("new Go field:\n%s\n\n", u.generatedGoField.content)),
genai.Text(fmt.Sprintf("new Go field:\n%s\n\n", u.newField.generatedContent)),
},
Role: "user",
})
// provide content of the existing Go files
files, err := listFiles(u.opts.apiDirectory)
files, err := listFiles(u.opts.APIDirectory)
if err != nil {
return fmt.Errorf("error listing files: %w", err)
}
Expand Down
Loading

0 comments on commit 4db732a

Please sign in to comment.