Skip to content

Commit

Permalink
feat: Add capability to replace proto contents in patch-proto
Browse files Browse the repository at this point in the history
Use this new capability to codify the Apigee proto hack that was
originally made (manually) in
#3183.

This addresses the suggested change in
#3183 (review).
  • Loading branch information
jasonvigil committed Dec 10, 2024
1 parent 3e6eb1a commit 66bb80c
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 23 deletions.
11 changes: 11 additions & 0 deletions mockgcp/apply-proto-patches.sh
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,14 @@ cd tools/patch-proto
# // Optional. The configuration for Private Service Connect (PSC) for the cluster.
# PscConfig psc_config = 31 [(google.api.field_behavior) = OPTIONAL];
# EOF

go run . --file ${REPO_ROOT}/mockgcp/apis/mockgcp/cloud/apigee/v1/service.proto --service "ProjectsServer" --mode "replace" <<EOF
// Provisions a new Apigee organization with a functioning runtime. This is the standard way to create trial organizations for a free Apigee trial.
rpc ProvisionOrganizationProject(ProvisionOrganizationProjectRequest) returns (.google.longrunning.Operation) {
option (google.api.http) = {
post: "/v1/{name=projects/*}:provisionOrganization"
body: "project"
};
};
EOF
2 changes: 2 additions & 0 deletions mockgcp/tools/patch-proto/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ require (
github.com/go-logr/logr v1.4.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/thediveo/enumflag/v2 v2.0.5 // indirect
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
)
4 changes: 4 additions & 0 deletions mockgcp/tools/patch-proto/go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

136 changes: 113 additions & 23 deletions mockgcp/tools/patch-proto/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
sitter "github.com/smacker/go-tree-sitter"
"github.com/smacker/go-tree-sitter/protobuf"
"github.com/spf13/cobra"
"github.com/thediveo/enumflag/v2"
)

func main() {
Expand All @@ -43,35 +44,67 @@ func run(ctx context.Context) error {

cmd := cobra.Command{
Use: "proto-patch",
Short: "patches a message in a proto file",
Long: `Inserts the contents of stdin at the end of the message in the proto file.`,
Short: "patches a proto file",
Long: `Patches the contents of stdin into a proto file.`,
RunE: func(cmd *cobra.Command, args []string) error {
insert, err := io.ReadAll(os.Stdin)
patch, err := io.ReadAll(os.Stdin)
if err != nil {
return fmt.Errorf("reading stdin: %w", err)
}
opt.Insertion = insert
opt.Patch = patch

return RunProtoPatch(ctx, opt)
},
}
cmd.Flags().StringVar(&opt.ProtoPath, "file", opt.ProtoPath, "path to proto file to patch")
cmd.Flags().StringVar(&opt.Message, "message", opt.Message, "message file to patch")
cmd.Flags().VarP(
enumflag.New(&opt.Mode, "mode", ProtoPatchModeIDs, enumflag.EnumCaseInsensitive),
"mode", "m",
"patch mode; can be 'append' or 'replace'")

cmd.Flags().StringVar(&opt.Message, "message", opt.Message, "message name to patch")
cmd.Flags().StringVar(&opt.Service, "service", opt.Service, "service name to patch")

cmd.MarkFlagsOneRequired("message", "service")
cmd.MarkFlagsMutuallyExclusive("message", "service")

return cmd.Execute()
}

type ProtoPatchMode enumflag.Flag

const (
ProtoPatchModeAppend ProtoPatchMode = iota
ProtoPatchModeReplace
)

var ProtoPatchModeIDs = map[ProtoPatchMode][]string{
ProtoPatchModeAppend: {"append"},
ProtoPatchModeReplace: {"replace"},
}

type ProtoPatchOptions struct {
ProtoPath string
Message string
Insertion []byte
Service string
Mode ProtoPatchMode
Patch []byte
}

func RunProtoPatch(ctx context.Context, opt ProtoPatchOptions) error {
x := &insertPatchIntoMessage{
Message: opt.Message,
Insertion: opt.Insertion,
x := &patchProto{
Patch: opt.Patch,
Mode: opt.Mode,
}

if opt.Message != "" {
x.Id = ProtoIdentifierMessage
x.Name = opt.Message
} else if opt.Service != "" {
x.Id = ProtoIdentifierService
x.Name = opt.Service
}

protoPath := opt.ProtoPath

srcBytes, err := os.ReadFile(protoPath)
Expand All @@ -96,7 +129,7 @@ func RunProtoPatch(ctx context.Context, opt ProtoPatchOptions) error {
}

if x.Out == nil {
return fmt.Errorf("message %q not found in file %q", x.Message, protoPath)
return fmt.Errorf("identifier %q not found in file %q", x.Name, protoPath)
}
if err := os.WriteFile(protoPath, x.Out, 0644); err != nil {
return fmt.Errorf("writing to file %q: %w", protoPath, err)
Expand All @@ -105,16 +138,27 @@ func RunProtoPatch(ctx context.Context, opt ProtoPatchOptions) error {
return nil
}

type insertPatchIntoMessage struct {
Source []byte
Insertion []byte
Message string
Errors []error
type ProtoIdentifier string

const (
ProtoIdentifierMessage ProtoIdentifier = "message"
ProtoIdentifierService ProtoIdentifier = "service"
)

type patchProto struct {
Mode ProtoPatchMode
Id ProtoIdentifier

Source []byte
Patch []byte
Name string

Errors []error

Out []byte
}

func (x *insertPatchIntoMessage) VisitNode(depth int, cursor *sitter.TreeCursor) {
func (x *patchProto) VisitNode(depth int, cursor *sitter.TreeCursor) {
node := cursor.CurrentNode()

// fmt.Printf("%s[%d:%s] %s\n", strings.Repeat(" ", depth), node.Symbol(), node.Type(), node.Content(x.Source))
Expand Down Expand Up @@ -145,7 +189,16 @@ func (x *insertPatchIntoMessage) VisitNode(depth int, cursor *sitter.TreeCursor)
case "message":
// e.g. message MyMessage { ... }
descend = false
x.VisitMessage(depth, cursor.CurrentNode())
if x.Id == ProtoIdentifierMessage {
x.VisitMessage(depth, cursor.CurrentNode())
}

case "service":
// e.g. service MyService { ... }
descend = false
if x.Id == ProtoIdentifierService {
x.VisitService(depth, cursor.CurrentNode())
}

default:
x.Errors = append(x.Errors, fmt.Errorf("unknown top-level node %q", protobuf.GetLanguage().SymbolName(node.Symbol())))
Expand All @@ -162,7 +215,7 @@ func (x *insertPatchIntoMessage) VisitNode(depth int, cursor *sitter.TreeCursor)
}
}

func (x *insertPatchIntoMessage) VisitMessage(depth int, node *sitter.Node) {
func (x *patchProto) VisitMessage(depth int, node *sitter.Node) {
klog.V(2).Infof("%s[%d:%s] %s\n", strings.Repeat(" ", depth), node.Symbol(), node.Type(), node.Content(x.Source))

messageName := ""
Expand All @@ -182,21 +235,58 @@ func (x *insertPatchIntoMessage) VisitMessage(depth int, node *sitter.Node) {
}
}

if messageName == x.Message {
if messageName == x.Name {
if messageBody == nil {
x.Errors = append(x.Errors, fmt.Errorf("could not find message definition for message %q", messageName))
return
}

var out bytes.Buffer
out.Write(x.Source[:messageBody.StartByte()])
messageBodyContents := string(x.Source[messageBody.StartByte():messageBody.EndByte()])
messageBodyContents = strings.TrimSuffix(messageBodyContents, "}")
out.WriteString(messageBodyContents)
out.Write(x.Insertion)
if x.Mode == ProtoPatchModeAppend {
messageBodyContents := string(x.Source[messageBody.StartByte():messageBody.EndByte()])
messageBodyContents = strings.TrimSuffix(messageBodyContents, "}")
out.WriteString(messageBodyContents)
} else if x.Mode == ProtoPatchModeReplace {
out.WriteString("{\n")
}
out.Write(x.Patch)
out.WriteString("\n}")
out.Write(x.Source[messageBody.EndByte():])

x.Out = out.Bytes()
}
}

func (x *patchProto) VisitService(depth int, node *sitter.Node) {
klog.V(2).Infof("%s[%d:%s] %s\n", strings.Repeat(" ", depth), node.Symbol(), node.Type(), node.Content(x.Source))

childCount := int(node.ChildCount())
var serviceName *sitter.Node
var serviceBodyIdx int
for i := 0; i < childCount; i++ {
child := node.Child(i)
if protobuf.GetLanguage().SymbolName(child.Symbol()) == "service_name" {
serviceName = child
serviceBodyIdx = i + 1
break
}
}

if serviceName.Content(x.Source) == x.Name {
serviceBodyStart := node.Child(serviceBodyIdx)
lastChild := node.Child(childCount - 1)

var out bytes.Buffer
out.Write(x.Source[:serviceBodyStart.EndByte()])
if x.Mode == ProtoPatchModeAppend {
out.Write(x.Source[serviceBodyStart.EndByte():lastChild.StartByte()])
}
out.WriteString("\n")
out.Write(x.Patch)
out.WriteString(string(x.Source[lastChild.StartByte():lastChild.EndByte()]))
out.Write(x.Source[lastChild.EndByte():])

x.Out = out.Bytes()
}
}

0 comments on commit 66bb80c

Please sign in to comment.