Skip to content

Commit

Permalink
Merge pull request #315 from toctan/fix-grpc-transcoding-syntax
Browse files Browse the repository at this point in the history
Support gRPC transcoding syntax with '.=*' in path template
  • Loading branch information
zaquestion authored Apr 16, 2021
2 parents 07f56d6 + e9b5e99 commit 9bfafba
Show file tree
Hide file tree
Showing 11 changed files with 291 additions and 47 deletions.
37 changes: 37 additions & 0 deletions gengokit/httptransport/embeddable_funcs.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package httptransport

import (
"encoding/json"
"fmt"
"strings"
)
Expand Down Expand Up @@ -69,3 +70,39 @@ func RemoveBraces(val string) string {
val = strings.Replace(val, "}", "", -1)
return val
}

// encodePathParams encodes `mux.Vars()` with dot notations into JSON objects
// to be unmarshaled into non-basetype fields.
// e.g. {"book.name": "books/1"} -> {"book": {"name": "books/1"}}
func encodePathParams(vars map[string]string) map[string]string {
var recur func(path, value string, data map[string]interface{})
recur = func(path, value string, data map[string]interface{}) {
parts := strings.SplitN(path, ".", 2)
key := parts[0]
if len(parts) == 1 {
data[key] = value
} else {
if _, ok := data[key]; !ok {
data[key] = make(map[string]interface{})
}
recur(parts[1], value, data[key].(map[string]interface{}))
}
}

data := make(map[string]interface{})
for key, val := range vars {
recur(key, val, data)
}

ret := make(map[string]string)
for key, val := range data {
switch val := val.(type) {
case string:
ret[key] = val
case map[string]interface{}:
m, _ := json.Marshal(val)
ret[key] = string(m)
}
}
return ret
}
69 changes: 69 additions & 0 deletions gengokit/httptransport/embeddable_funcs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package httptransport

import (
"reflect"
"testing"
)

func TestEncodePathParams(t *testing.T) {
tests := []struct {
name string
vars map[string]string
want map[string]string
}{
{
name: "simple",
vars: map[string]string{
"parent": "shelves/shelf1",
},
want: map[string]string{
"parent": "shelves/shelf1",
},
},
{
name: "dot notation - single value",
vars: map[string]string{
"book.name": "shelves/shelf1/books/book1",
},
want: map[string]string{
"book": `{"name":"shelves/shelf1/books/book1"}`,
},
},
{
name: "dot notation - multiple values",
vars: map[string]string{
"book.name": "shelves/shelf1/books/book1",
"book.version": "v1",
},
want: map[string]string{
"book": `{"name":"shelves/shelf1/books/book1","version":"v1"}`,
},
},
{
name: "dot notation - multiple levels",
vars: map[string]string{
"book.version.name": "versions/v1",
},
want: map[string]string{
"book": `{"version":{"name":"versions/v1"}}`,
},
},
{
name: "dot notation - multiple values in multiple levels",
vars: map[string]string{
"book.name": "shelves/shelf1/books/book1",
"book.version.name": "versions/v1",
},
want: map[string]string{
"book": `{"name":"shelves/shelf1/books/book1","version":{"name":"versions/v1"}}`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := encodePathParams(tt.vars); !reflect.DeepEqual(got, tt.want) {
t.Errorf("encodePathParams() = %v, want %v", got, tt.want)
}
})
}
}
50 changes: 39 additions & 11 deletions gengokit/httptransport/httptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"bytes"
"fmt"
"go/format"
"regexp"
"strconv"
"strings"
"text/template"
Expand Down Expand Up @@ -71,7 +72,7 @@ func NewBinding(i int, meth *svcdef.ServiceMethod) *Binding {
binding := meth.Bindings[i]
nBinding := Binding{
Label: meth.Name + EnglishNumber(i),
PathTemplate: binding.Path,
PathTemplate: getMuxPathTemplate(binding.Path),
BasePath: basePath(binding.Path),
Verb: binding.Verb,
}
Expand Down Expand Up @@ -190,7 +191,11 @@ func GenServerTemplate(exec interface{}) (string, error) {
if err != nil {
return "", err
}
code = FormatCode(code)
encodeFuncSource, err := FuncSourceCode(encodePathParams)
if err != nil {
return "", err
}
code = FormatCode(code + encodeFuncSource)
return code, nil
}

Expand Down Expand Up @@ -240,6 +245,12 @@ func (b *Binding) GenClientEncode() (string, error) {
// "fmt.Sprint(req.A)",
// }
func (b *Binding) PathSections() []string {
path := b.PathTemplate
re := regexp.MustCompile(`{.+:.+}`)
path = re.ReplaceAllStringFunc(path, func(v string) string {
return strings.Split(v, ":")[0] + "}"
})

isEnum := make(map[string]struct{})
for _, v := range b.Fields {
if v.IsEnum {
Expand All @@ -248,16 +259,22 @@ func (b *Binding) PathSections() []string {
}

rv := []string{}
parts := strings.Split(b.PathTemplate, "/")
parts := strings.Split(path, "/")
for _, part := range parts {
if len(part) > 2 && part[0] == '{' && part[len(part)-1] == '}' {
name := RemoveBraces(part)
if _, ok := isEnum[gogen.CamelCase(name)]; ok {
convert := fmt.Sprintf("fmt.Sprintf(\"%%d\", req.%v)", gogen.CamelCase(name))
name := part[1 : len(part)-1]
parts := strings.Split(name, ".")
for idx, part := range parts {
parts[idx] = gogen.CamelCase(part)
}
camelName := strings.Join(parts, ".")

if _, ok := isEnum[camelName]; ok {
convert := fmt.Sprintf("fmt.Sprintf(\"%%d\", req.%v)", camelName)
rv = append(rv, convert)
continue
}
convert := fmt.Sprintf("fmt.Sprint(req.%v)", gogen.CamelCase(name))
convert := fmt.Sprintf("fmt.Sprint(req.%v)", camelName)
rv = append(rv, convert)
} else {
// Add quotes around things which'll be embeded as string literals,
Expand All @@ -284,7 +301,7 @@ if {{.LocalName}}StrArr, ok := {{.Location}}Params["{{.QueryParamName}}"]; ok {
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("Error while extracting {{.LocalName}} from {{.Location}}, {{.Location}}Params: %v", {{.Location}}Params))
}{{end}}
req.{{.CamelName}} = {{.TypeConversion}}
{{if or .Repeated .IsBaseType .IsEnum}}req.{{.CamelName}} = {{.TypeConversion}}{{end}}
`
mergedLogic := queryParamLogic + genericLogic + "}"
if f.Location == "path" {
Expand Down Expand Up @@ -382,9 +399,7 @@ func createDecodeConvertFunc(f Field) (string, bool) {
// pointer as well. So we special case args of a single custom message
// type so that the variable LocalName is declared as a pointer.
singleCustomTypeUnmarshalTmpl := `
var {{.LocalName}} *{{.GoType}}
{{.LocalName}} = &{{.GoType}}{}
err = json.Unmarshal([]byte({{.LocalName}}Str), {{.LocalName}})`
err = json.Unmarshal([]byte({{.LocalName}}Str), req.{{.CamelName}})`

errorCheckingTmpl := `
if err != nil {
Expand Down Expand Up @@ -490,6 +505,19 @@ func getZeroValue(f Field) string {
}
}

// getMuxPathTemplate translates gRPC Transcoding path into gorilla/mux
// compatible path template.
func getMuxPathTemplate(path string) string {
re := regexp.MustCompile(`{.+=.+}`)
stars := regexp.MustCompile(`\*{2,}`)
return re.ReplaceAllStringFunc(path, func(v string) string {
v = strings.Replace(v, "=", ":", 1)
v = stars.ReplaceAllLiteralString(v, `.+`)
v = strings.ReplaceAll(v, "*", `[^/]+`)
return v
})
}

// The 'basePath' of a path is the section from the start of the string till
// the first '{' character.
func basePath(path string) string {
Expand Down
93 changes: 93 additions & 0 deletions gengokit/httptransport/httptransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,96 @@ func TestLowCamelName(t *testing.T) {
}
}
}

func Test_getMuxPathTemplate(t *testing.T) {
tests := []struct {
name string
path string
want string
}{
{
name: "no pattern",
path: "/v1/{parent}/books",
want: "/v1/{parent}/books",
},
{
name: "no *",
path: "/v1/{parent=shelves}/books",
want: "/v1/{parent:shelves}/books",
},
{
name: "single *",
path: "/v1/{parent=shelves/*}/books",
want: `/v1/{parent:shelves/[^/]+}/books`,
},
{
name: "multiple *",
path: "/v1/{name=shelves/*/books/*}",
want: `/v1/{name:shelves/[^/]+/books/[^/]+}`,
},
{
name: "**",
path: "/v1/shelves/{name=books/**}",
want: `/v1/shelves/{name:books/.+}`,
},
{
name: "mixed * and **",
path: "/v1/{name=shelves/*/books/**}",
want: `/v1/{name:shelves/[^/]+/books/.+}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getMuxPathTemplate(tt.path); got != tt.want {
t.Errorf("getMuxPathTemplate() = %v, want %v", got, tt.want)
}
})
}
}

func TestBinding_PathSections(t *testing.T) {
tests := []struct {
name string
pathTemplate string
want []string
}{
{
name: "simple",
pathTemplate: "/sum/{a}",
want: []string{
`""`,
`"sum"`,
"fmt.Sprint(req.A)",
},
},
{
name: "pattern",
pathTemplate: `/v1/{parent:shelves/[^/]+}/books`,
want: []string{
`""`,
`"v1"`,
"fmt.Sprint(req.Parent)",
`"books"`,
},
},
{
name: "dot notation",
pathTemplate: `/v1/{book.name:shelves/[^/]+/books/[^/]+}`,
want: []string{
`""`,
`"v1"`,
"fmt.Sprint(req.Book.Name)",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b := &Binding{
PathTemplate: tt.pathTemplate,
}
if got := b.PathSections(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Binding.PathSections() = %v, want %v", got, tt.want)
}
})
}
}
2 changes: 1 addition & 1 deletion gengokit/httptransport/templates/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ var ServerDecodeTemplate = `
}
}
pathParams := mux.Vars(r)
pathParams := encodePathParams(mux.Vars(r))
_ = pathParams
queryParams := r.URL.Query()
Expand Down
2 changes: 1 addition & 1 deletion gengokit/httptransport/templates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func DecodeHTTPSumZeroRequest(_ context.Context, r *http.Request) (interface{},
}
}
pathParams := mux.Vars(r)
pathParams := encodePathParams(mux.Vars(r))
_ = pathParams
queryParams := r.URL.Query()
Expand Down
36 changes: 18 additions & 18 deletions gengokit/template/template.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 9bfafba

Please sign in to comment.