From b25d0870260a540f8bf3c7c6d014ec208556afc8 Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Wed, 19 May 2021 01:52:33 -0700 Subject: [PATCH] fix: add graphql function to js script engine --- core/core.go | 4 +- core/core_test.go | 7 ++- core/query1_test.go | 44 ++++++++++++++- core/script.go | 132 ++++++++++++++++++++++++++++++++++++++------ 4 files changed, 164 insertions(+), 23 deletions(-) diff --git a/core/core.go b/core/core.go index 327052be..53c5b772 100644 --- a/core/core.go +++ b/core/core.go @@ -226,7 +226,7 @@ func (c *gcontext) execQuery(qr queryReq, role string) (queryResp, error) { } if c.sc != nil && c.sc.RespFunc != nil { - res.data, err = c.scriptCallResp(res.data) + res.data, err = c.scriptCallResp(res.data, res.role) } return res, err @@ -274,7 +274,7 @@ func (c *gcontext) resolveSQL(qr queryReq, role string) (queryResp, error) { } if c.sc != nil && c.sc.ReqFunc != nil { - qr.vars, err = c.scriptCallReq(qr.vars) + qr.vars, err = c.scriptCallReq(qr.vars, res.qc.st.role.Name) if err != nil { return res, err } diff --git a/core/core_test.go b/core/core_test.go index 88f7ed04..7cb7b780 100644 --- a/core/core_test.go +++ b/core/core_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "flag" "fmt" + "os" "testing" "github.com/orlangure/gnomock" @@ -113,8 +114,8 @@ func TestMain(m *testing.M) { db.SetMaxIdleConns(100) dbType = v.name - // if res := m.Run(); res != 0 { - // os.Exit(res) - // } + if res := m.Run(); res != 0 { + os.Exit(res) + } } } diff --git a/core/query1_test.go b/core/query1_test.go index 2a1c106b..d8c8f486 100644 --- a/core/query1_test.go +++ b/core/query1_test.go @@ -1041,10 +1041,52 @@ func Example_queryWithScriptDirective() { } else { fmt.Println(string(res.Data)) } - // Output: {"usersbyid":{"email":"u...@test.com","id":2}} } +func Example_queryWithScriptDirectiveUsingGraphQL() { + gql := `query @script(name: "test.js") { + usersById(id: 2) { + id + email + } + }` + + script := ` + function response(json) { + let val = graphql('query { users(id: 1) { id email } }') + json.usersbyid.email = val.users.email + return json; + } + ` + + dir, err := ioutil.TempDir("", "test") + if err != nil { + panic(err) + } + defer os.RemoveAll(dir) + + err = ioutil.WriteFile(path.Join(dir, "test.js"), []byte(script), 0644) + if err != nil { + panic(err) + } + + conf := &core.Config{DBType: dbType, DisableAllowList: true, ScriptPath: dir} + gj, err := core.NewGraphJin(conf, db) + if err != nil { + panic(err) + } + + res, err := gj.GraphQL(context.Background(), gql, nil, nil) + if err != nil { + fmt.Println(err) + } else { + fmt.Println(string(res.Data)) + } + + // Output: {"usersbyid":{"email":"user1@test.com","id":2}} +} + func Example_queryWithView() { gql := `query { hot_products(limit: 3) { diff --git a/core/script.go b/core/script.go index 5adeb8b9..4f867658 100644 --- a/core/script.go +++ b/core/script.go @@ -11,11 +11,13 @@ import ( "time" "github.com/dop251/goja" + "github.com/dop251/goja/parser" + "github.com/dosco/graphjin/core/internal/qcode" babel "github.com/jvatic/goja-babel" ) -type reqFunc func(map[string]interface{}) map[string]interface{} -type respFunc func(map[string]interface{}) map[string]interface{} +type reqFunc func(map[string]interface{}, string, interface{}) map[string]interface{} +type respFunc func(map[string]interface{}, string, interface{}) map[string]interface{} func (gj *GraphJin) initScripting() error { if err := babel.Init(5); err != nil { @@ -47,7 +49,7 @@ func (c *gcontext) loadScript(name string) error { return nil } -func (c *gcontext) scriptCallReq(vars []byte) (_ []byte, err error) { +func (c *gcontext) scriptCallReq(vars []byte, role string) (_ []byte, err error) { if c.sc.ReqFunc == nil { return vars, nil } @@ -69,7 +71,16 @@ func (c *gcontext) scriptCallReq(vars []byte) (_ []byte, err error) { } }() - val := c.sc.ReqFunc(rj) + if err := c.sc.vm.Set("graphql", c.newGraphQLFunc(role)); err != nil { + return nil, err + } + + var userID interface{} + if v := c.Value(UserIDKey); v == nil { + userID = v + } + + val := c.sc.ReqFunc(rj, role, userID) if val == nil { return vars, nil } @@ -77,7 +88,7 @@ func (c *gcontext) scriptCallReq(vars []byte) (_ []byte, err error) { return json.Marshal(val) } -func (c *gcontext) scriptCallResp(data []byte) (_ []byte, err error) { +func (c *gcontext) scriptCallResp(data []byte, role string) (_ []byte, err error) { if c.sc.RespFunc == nil { return data, nil } @@ -93,13 +104,22 @@ func (c *gcontext) scriptCallResp(data []byte) (_ []byte, err error) { c.sc.vm.Interrupt("halt") }) + if err := c.sc.vm.Set("graphql", c.newGraphQLFunc(role)); err != nil { + return nil, err + } + + var userID interface{} + if v := c.Value(UserIDKey); v == nil { + userID = v + } + defer func() { if err1 := recover(); err1 != nil { err = fmt.Errorf("script: %w", err1) } }() - val := c.sc.RespFunc(rj) + val := c.sc.RespFunc(rj, role, userID) if val == nil { return data, nil } @@ -136,48 +156,126 @@ func (c *gcontext) scriptInit(s *script, name string) error { return err } - s.vm = goja.New() + var vm *goja.Runtime - console := s.vm.NewObject() - console.Set("log", logFunc) //nolint: errcheck - if err := s.vm.Set("console", console); err != nil { - return err + if s.vm != nil { + s.vm.ClearInterrupt() + + } else { + vm = goja.New() + + vm.SetParserOptions(parser.WithDisableSourceMaps) + + exports := vm.NewObject() + vm.Set("exports", exports) //nolint: errcheck + + module := vm.NewObject() + _ = module.Set("exports", exports) + vm.Set("module", module) //nolint: errcheck + + env := make(map[string]string, len(os.Environ())) + for _, e := range os.Environ() { + if strings.HasPrefix(e, "SG_") || strings.HasPrefix(e, "GJ_") { + continue + } + v := strings.SplitN(e, "=", 2) + env[v[0]] = v[1] + } + vm.Set("__ENV", env) //nolint: errcheck + vm.Set("global", vm.GlobalObject()) //nolint: errcheck + + console := vm.NewObject() + console.Set("log", logFunc) //nolint: errcheck + vm.Set("console", console) //nolint: errcheck } time.AfterFunc(500*time.Millisecond, func() { - s.vm.Interrupt("halt") + vm.Interrupt("halt") }) - if _, err = s.vm.RunProgram(ast); err != nil { + if _, err = vm.RunProgram(ast); err != nil { return err } - req := s.vm.Get("request") + req := vm.Get("request") if req != nil { if _, ok := goja.AssertFunction(req); !ok { return fmt.Errorf("script: function 'request' not found") } - if err := s.vm.ExportTo(req, &s.ReqFunc); err != nil { + if err := vm.ExportTo(req, &s.ReqFunc); err != nil { return err } } - resp := s.vm.Get("response") + resp := vm.Get("response") if resp != nil { if _, ok := goja.AssertFunction(resp); !ok { return fmt.Errorf("script: function 'response' not found") } - if err := s.vm.ExportTo(resp, &s.RespFunc); err != nil { + if err := vm.ExportTo(resp, &s.RespFunc); err != nil { return err } } + + s.vm = vm return nil } +func (c *gcontext) newGraphQLFunc(role string) func(string, map[string]interface{}, map[string]string) map[string]interface{} { + + return func( + query string, + vars map[string]interface{}, + opt map[string]string) map[string]interface{} { + var err error + + op, name := qcode.GetQType(query) + + qreq := queryReq{ + op: op, + name: name, + query: []byte(query), + } + + ct := gcontext{ + Context: c.Context, + gj: c.gj, + op: c.op, + rc: c.rc, + } + + if len(vars) != 0 { + if qreq.vars, err = json.Marshal(vars); err != nil { + panic(fmt.Errorf("variables: %s", err)) + } + } + + var r1 string + + if v, ok := opt["role"]; ok && len(v) != 0 { + r1 = v + } else { + r1 = role + } + + qres, err := ct.execQuery(qreq, r1) + if err != nil { + panic(err) + } + + jres := make(map[string]interface{}) + if err = json.Unmarshal(qres.data, &jres); err != nil { + panic(fmt.Errorf("json: %s", err)) + } + + return jres + } +} + func logFunc(args ...interface{}) { for _, arg := range args { if _, ok := arg.(map[string]interface{}); ok {