Skip to content

Commit

Permalink
Merge pull request #330 from lelandbatey/allow-enable-defaults-option
Browse files Browse the repository at this point in the history
Allow hooks.go to control server configuration e.g. EmitDefaults
  • Loading branch information
lelandbatey authored Jun 8, 2021
2 parents 9bfafba + d784398 commit 68ba2c3
Show file tree
Hide file tree
Showing 13 changed files with 710 additions and 49 deletions.
2 changes: 1 addition & 1 deletion cmd/_integration-tests/transport/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func TestMain(m *testing.M) {
}

// http test server
h := svc.MakeHTTPHandler(endpoints)
h := svc.MakeHTTPHandler(endpoints, svc.EncodeHTTPGenericResponse)
httpTestServer := httptest.NewServer(h)

// grpc test server
Expand Down
69 changes: 54 additions & 15 deletions gengokit/handlers/handlers.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Package handlers manages the exported methods in the service handler code
// adding/removing exported methods to match the service definition.
// Package handlers renders the Go source files found in <svcname>/handlers/.
// Most importantly, it handles rendering and modifying the
// <svcname>/handlers/handlers.go file, while making sure that existing code in
// that handlers.go file is not deleted.
package handlers

import (
Expand All @@ -11,8 +13,8 @@ import (
"io"
"strings"

log "github.com/sirupsen/logrus"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"

"github.com/metaverse/truss/gengokit"
"github.com/metaverse/truss/gengokit/handlers/templates"
Expand Down Expand Up @@ -47,8 +49,11 @@ func New(svc *svcdef.Service, prev io.Reader) (gengokit.Renderable, error) {
return &h, nil
}

// methodMap stores all defined service methods by name and is updated to
// remove service methods already in the handler file.
// methodMap stores all the service methods defined in the service.proto. It
// stores these methods by their string name. In order to not overwrite
// existing methods in the 'handlers/handlers.go' file, methods which already
// exist in the 'handlers/handlers.go' file will be removed from this
// methodMap.
type methodMap map[string]*svcdef.ServiceMethod

func newMethodMap(meths []*svcdef.ServiceMethod) methodMap {
Expand All @@ -63,20 +68,26 @@ type handler struct {
fset *token.FileSet
service *svcdef.Service
mMap methodMap
ast *ast.File
// The Abstract Syntax Tree (AST) of the existing go code found in
// 'handlers/handlers.go'. If the 'handlers/handlers.go' file does not
// exist, then ast will be nil.
ast *ast.File
}

type handlerData struct {
ServiceName string
Methods []*svcdef.ServiceMethod
}

// Render returns a go code server handler that has functions for all
// ServiceMethods in the service definition.
// Render returns an io.Reader with the go code of the server handler. That
// server handler ('handlers.go') has functions for all ServiceMethods in the
// service definition.
func (h *handler) Render(alias string, data *gengokit.Data) (io.Reader, error) {
if alias != ServerHandlerPath {
return nil, errors.Errorf("cannot render unknown file: %q", alias)
}
// implies that there is not an existing 'handlers/handlers.go' file and we
// can safely render the default template without worry.
if h.ast == nil {
return applyServerTempl(data)
}
Expand All @@ -90,7 +101,7 @@ func (h *handler) Render(alias string, data *gengokit.Data) (io.Reader, error) {
h.ast.Decls = h.mMap.pruneDecls(h.ast.Decls, strings.ToLower(data.Service.Name))
log.WithField("Service Methods", len(h.mMap)).Debug("After prune")

// create a new handlerData, and add all methods not defined in the previous file
// create a new handlerData, and add all methods not defined in the existing 'handlers/handlers.go' file
ex := handlerData{
ServiceName: data.Service.Name,
}
Expand Down Expand Up @@ -175,8 +186,15 @@ func (m methodMap) pruneDecls(decls []ast.Decl, svcName string) []ast.Decl {
return newDecls
}

// updateParams updates the second param of f to be `X`.(m.RequestType.Name).
// func ProtoMethod(ctx context.Context, *pb.Old) ...-> func ProtoMethod(ctx context.Context, *pb.(m.RequestType.Name))...
// updateParams updates the second param of f to be `X`.{m.RequestType.Name}.
// For example, this function signature:
//
// func ProtoMethod(ctx context.Context, *pb.Old)
//
// will become the following kind of function signature, where the old input type is
// replaced by the new input type defined in m.RequestType.Name:
//
// func ProtoMethod(ctx context.Context, *pb.{m.RequestType.Name})...
func updateParams(f *ast.FuncDecl, m *svcdef.ServiceMethod) {
if f.Type.Params.NumFields() != 2 {
log.WithField("Function", f.Name.Name).
Expand All @@ -186,8 +204,15 @@ func updateParams(f *ast.FuncDecl, m *svcdef.ServiceMethod) {
updatePBFieldType(f.Type.Params.List[1].Type, m.RequestType.Name)
}

// updateResults updates the first result of f to be `X`.(m.ResponseType.Name).
// func ProtoMethod(...) (*pb.Old, error) -> func ProtoMethod(...) (*pb.(m.ResponseType.Name), error)
// updateResults updates the first result of f to be `X`.{m.ResponseType.Name}.
// For example, this function signature:
//
// func ProtoMethod(...) (*pb.Old, error)
//
// will become the following function signature, where the prior return type is
// replaced with the return type defined in m.ResponseType.Name:
//
// func ProtoMethod(...) (*pb.{m.ResponseType.Name}, error)
func updateResults(f *ast.FuncDecl, m *svcdef.ServiceMethod) {
if f.Type.Results.NumFields() != 2 {
log.WithField("Function", f.Name.Name).
Expand All @@ -210,8 +235,22 @@ func updatePBFieldType(t ast.Expr, newType string) {
}
}

// isVaidFunc returns false if f is exported and does no exist in m with
// reciever svcName + "Service".
// isValidFunc indicates whether the function declaration here is a function
// declaration which is allowed to exist in handlers/handlers.go. The criteria
// for functions which are allowed in 'handlers/handlers.go' are any of the
// following:
//
// 1. The function is private
// 2. The function is a method of our server struct (e.g. fooStruct) AND
// it's also a method defined in the generated .pb.go server interface.
//
// These criteria are pretty strict, making many things invalid and thus will
// be removed. Some of the things which are invalid include:
//
// - Any public function which is not a method of the truss-created server
// struct is not valid and will be removed.
// - Any public method of the truss-created server which doesn't exist on
// the .pb.go server interface is not valid and will be removed.
func isValidFunc(f *ast.FuncDecl, m methodMap, svcName string) bool {
name := f.Name.String()
if !ast.IsExported(name) {
Expand Down
118 changes: 112 additions & 6 deletions gengokit/handlers/hooks.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package handlers

import (
"bytes"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"io"
"io/ioutil"
"strings"

"github.com/metaverse/truss/gengokit"
Expand All @@ -21,12 +27,112 @@ type HookRender struct {
prev io.Reader
}

// Render will return the existing file if it exists, otherwise it will return
// a brand new copy from the template.
func (h *HookRender) Render(_ string, _ *gengokit.Data) (io.Reader, error) {
// Render returns an io.Reader with the contents of
// <svcname>/handlers/hooks.go. If hooks.go does not already exist, then it's
// rendered anew from the templates defined in
// 'gengokit/handlers/templates/hook.go'. If hooks.go does exist already, then:
//
// 1. Modify the new code so that it will import
// "{{.ImportPath}}/svc/server" if it doesn't already.
// 2. Add the InterruptHandler if it doesn't exist already
// 3. Add the SetConfig function if it doesn't exist already
func (h *HookRender) Render(_ string, data *gengokit.Data) (io.Reader, error) {
if h.prev == nil {
return strings.NewReader(templates.Hook), nil
} else {
return h.prev, nil
return data.ApplyTemplate(templates.Hook+templates.HookInterruptHandler+templates.HookSetConfig, "HooksFullTemplate")
}
rawprev, err := ioutil.ReadAll(h.prev)
if err != nil {
return nil, err
}
code := bytes.NewBuffer(rawprev)

fset := token.NewFileSet()
past, err := parser.ParseFile(fset, "", code, parser.ParseComments)
if err != nil {
return nil, err
}
err = addServerImportIfNotPresent(past, data)
if err != nil {
return nil, err
}

var existingFuncs map[string]bool = map[string]bool{}
for _, d := range past.Decls {
switch x := d.(type) {
case *ast.FuncDecl:
name := x.Name.Name
existingFuncs[name] = true
}
}
code = bytes.NewBuffer(nil)
err = printer.Fprint(code, fset, past)
if err != nil {
return nil, err
}

// Both of these functions need to be in hooks.go in order for the service to start.
hookFuncs := map[string]string{
"InterruptHandler": templates.HookInterruptHandler,
"SetConfig": templates.HookSetConfig,
}

for name, f := range hookFuncs {
if _, ok := existingFuncs[name]; !ok {
code.ReadFrom(strings.NewReader(f))
}
}
return code, nil
}

// addServerImportIfNotPresent ensures that the hooks.go file imports the
// "{{.ImportPath -}} /svc/server" file since the SetConfig function requires
// that import in order to compile. It does this by mutating the handlerfile
// provided as parameter hf in place.
func addServerImportIfNotPresent(hf *ast.File, exec *gengokit.Data) error {
var imports *ast.GenDecl
for _, decl := range hf.Decls {
switch decl.(type) {
case *ast.GenDecl:
imports = decl.(*ast.GenDecl)
break
}
}

targetPathTmpl := `"{{.ImportPath -}} /svc"`
r, err := exec.ApplyTemplate(targetPathTmpl, "ServerPathTempl")
if err != nil {
return err
}
tmp, err := ioutil.ReadAll(r)
if err != nil {
return err
}

targetpath := string(tmp)

for _, spec := range imports.Specs {
switch spec.(type) {
case *ast.ImportSpec:
imp := spec.(*ast.ImportSpec)
if imp.Path.Value == targetpath {
return nil
}
}
}

nimp := ast.ImportSpec{
Doc: &ast.CommentGroup{
List: []*ast.Comment{
&ast.Comment{
Text: "// This Service",
},
},
},
Path: &ast.BasicLit{
Kind: token.STRING,
Value: targetpath,
},
}
imports.Specs = append(imports.Specs, &nimp)
return nil
}
107 changes: 107 additions & 0 deletions gengokit/handlers/hooks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package handlers

import (
"io"
"io/ioutil"
"strings"
"testing"

"github.com/metaverse/truss/gengokit"
"github.com/metaverse/truss/gengokit/httptransport"
"github.com/metaverse/truss/svcdef"

"github.com/pkg/errors"
"github.com/stretchr/testify/require"
)

func TestHooksAddingImport(t *testing.T) {
const def = `
syntax = "proto3";
package echo;
service _Foo_Bar {
rpc Echo (EchoRequest) returns (EchoResponse) {}
}
message EchoRequest {
string In = 1;
}
message EchoResponse {
string Out = 1;
}
`

const prev = `
package handlers
import (
"fmt"
"os"
"os/signal"
"syscall"
)
func InterruptHandler(errc chan<- error) {
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
terminateError := fmt.Errorf("%s", <-c)
// Place whatever shutdown handling you want here
errc <- terminateError
}
`

sd, err := svcdef.NewFromString(def, gopath)
require.NoError(t, err)

conf := gengokit.Config{
GoPackage: "github.com/metaverse/truss/gengokit",
PBPackage: "github.com/metaverse/truss/gengokit/echo-service",
}

te, err := gengokit.NewData(sd, conf)
require.NoError(t, err)
newHooksf, err := renderHooksFile(prev, te)
require.NoError(t, err)

c1 := httptransport.FormatCode(prev)
c2 := httptransport.FormatCode(newHooksf)

require.Greater(t, len(c2), len(c1), "new code should be longer than the previous go code")
require.Contains(t, c2, "svc")
require.Contains(t, c2, "SetConfig")
require.Contains(t, c2, "InterruptHandler")
require.NotContains(t, c2, "server")

}

// renderHooksFile takes in a previous file as a string and returns the
// generated handlers/hooks.go file as a string. This helper method exists
// because the logic for reading the io.Reader to a string is repeated.
func renderHooksFile(prev string, data *gengokit.Data) (string, error) {
var prevFile io.Reader
if prev != "" {
prevFile = strings.NewReader(prev)
}

h := NewHook(prevFile)

next, err := h.Render(ServerHandlerPath, data)
if err != nil {
return "", err
}

nextBytes, err := ioutil.ReadAll(next)
if err != nil {
return "", err
}

nextCode, err := testFormat(string(nextBytes))
if err != nil {
return "", errors.Wrap(err, "cannot format")
}

nextCode = strings.TrimSpace(nextCode)

return nextCode, nil
}
Loading

0 comments on commit 68ba2c3

Please sign in to comment.