Skip to content

Commit

Permalink
fix: add graphql function to js script engine
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed May 19, 2021
1 parent e5ea6a9 commit b25d087
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 23 deletions.
4 changes: 2 additions & 2 deletions core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func (c *gcontext) execQuery(qr queryReq, role string) (queryResp, error) {
}

if c.sc != nil && c.sc.RespFunc != nil {
res.data, err = c.scriptCallResp(res.data)
res.data, err = c.scriptCallResp(res.data, res.role)
}

return res, err
Expand Down Expand Up @@ -274,7 +274,7 @@ func (c *gcontext) resolveSQL(qr queryReq, role string) (queryResp, error) {
}

if c.sc != nil && c.sc.ReqFunc != nil {
qr.vars, err = c.scriptCallReq(qr.vars)
qr.vars, err = c.scriptCallReq(qr.vars, res.qc.st.role.Name)
if err != nil {
return res, err
}
Expand Down
7 changes: 4 additions & 3 deletions core/core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"flag"
"fmt"
"os"
"testing"

"github.com/orlangure/gnomock"
Expand Down Expand Up @@ -113,8 +114,8 @@ func TestMain(m *testing.M) {
db.SetMaxIdleConns(100)
dbType = v.name

// if res := m.Run(); res != 0 {
// os.Exit(res)
// }
if res := m.Run(); res != 0 {
os.Exit(res)
}
}
}
44 changes: 43 additions & 1 deletion core/query1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1041,10 +1041,52 @@ func Example_queryWithScriptDirective() {
} else {
fmt.Println(string(res.Data))
}

// Output: {"usersbyid":{"email":"[email protected]","id":2}}
}

func Example_queryWithScriptDirectiveUsingGraphQL() {
gql := `query @script(name: "test.js") {
usersById(id: 2) {
id
email
}
}`

script := `
function response(json) {
let val = graphql('query { users(id: 1) { id email } }')
json.usersbyid.email = val.users.email
return json;
}
`

dir, err := ioutil.TempDir("", "test")
if err != nil {
panic(err)
}
defer os.RemoveAll(dir)

err = ioutil.WriteFile(path.Join(dir, "test.js"), []byte(script), 0644)
if err != nil {
panic(err)
}

conf := &core.Config{DBType: dbType, DisableAllowList: true, ScriptPath: dir}
gj, err := core.NewGraphJin(conf, db)
if err != nil {
panic(err)
}

res, err := gj.GraphQL(context.Background(), gql, nil, nil)
if err != nil {
fmt.Println(err)
} else {
fmt.Println(string(res.Data))
}

// Output: {"usersbyid":{"email":"[email protected]","id":2}}
}

func Example_queryWithView() {
gql := `query {
hot_products(limit: 3) {
Expand Down
132 changes: 115 additions & 17 deletions core/script.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ import (
"time"

"github.com/dop251/goja"
"github.com/dop251/goja/parser"
"github.com/dosco/graphjin/core/internal/qcode"
babel "github.com/jvatic/goja-babel"
)

type reqFunc func(map[string]interface{}) map[string]interface{}
type respFunc func(map[string]interface{}) map[string]interface{}
type reqFunc func(map[string]interface{}, string, interface{}) map[string]interface{}
type respFunc func(map[string]interface{}, string, interface{}) map[string]interface{}

func (gj *GraphJin) initScripting() error {
if err := babel.Init(5); err != nil {
Expand Down Expand Up @@ -47,7 +49,7 @@ func (c *gcontext) loadScript(name string) error {
return nil
}

func (c *gcontext) scriptCallReq(vars []byte) (_ []byte, err error) {
func (c *gcontext) scriptCallReq(vars []byte, role string) (_ []byte, err error) {
if c.sc.ReqFunc == nil {
return vars, nil
}
Expand All @@ -69,15 +71,24 @@ func (c *gcontext) scriptCallReq(vars []byte) (_ []byte, err error) {
}
}()

val := c.sc.ReqFunc(rj)
if err := c.sc.vm.Set("graphql", c.newGraphQLFunc(role)); err != nil {
return nil, err
}

var userID interface{}
if v := c.Value(UserIDKey); v == nil {
userID = v
}

val := c.sc.ReqFunc(rj, role, userID)
if val == nil {
return vars, nil
}

return json.Marshal(val)
}

func (c *gcontext) scriptCallResp(data []byte) (_ []byte, err error) {
func (c *gcontext) scriptCallResp(data []byte, role string) (_ []byte, err error) {
if c.sc.RespFunc == nil {
return data, nil
}
Expand All @@ -93,13 +104,22 @@ func (c *gcontext) scriptCallResp(data []byte) (_ []byte, err error) {
c.sc.vm.Interrupt("halt")
})

if err := c.sc.vm.Set("graphql", c.newGraphQLFunc(role)); err != nil {
return nil, err
}

var userID interface{}
if v := c.Value(UserIDKey); v == nil {
userID = v
}

defer func() {
if err1 := recover(); err1 != nil {
err = fmt.Errorf("script: %w", err1)
}
}()

val := c.sc.RespFunc(rj)
val := c.sc.RespFunc(rj, role, userID)
if val == nil {
return data, nil
}
Expand Down Expand Up @@ -136,48 +156,126 @@ func (c *gcontext) scriptInit(s *script, name string) error {
return err
}

s.vm = goja.New()
var vm *goja.Runtime

console := s.vm.NewObject()
console.Set("log", logFunc) //nolint: errcheck
if err := s.vm.Set("console", console); err != nil {
return err
if s.vm != nil {
s.vm.ClearInterrupt()

} else {
vm = goja.New()

vm.SetParserOptions(parser.WithDisableSourceMaps)

exports := vm.NewObject()
vm.Set("exports", exports) //nolint: errcheck

module := vm.NewObject()
_ = module.Set("exports", exports)
vm.Set("module", module) //nolint: errcheck

env := make(map[string]string, len(os.Environ()))
for _, e := range os.Environ() {
if strings.HasPrefix(e, "SG_") || strings.HasPrefix(e, "GJ_") {
continue
}
v := strings.SplitN(e, "=", 2)
env[v[0]] = v[1]
}
vm.Set("__ENV", env) //nolint: errcheck
vm.Set("global", vm.GlobalObject()) //nolint: errcheck

console := vm.NewObject()
console.Set("log", logFunc) //nolint: errcheck
vm.Set("console", console) //nolint: errcheck
}

time.AfterFunc(500*time.Millisecond, func() {
s.vm.Interrupt("halt")
vm.Interrupt("halt")
})

if _, err = s.vm.RunProgram(ast); err != nil {
if _, err = vm.RunProgram(ast); err != nil {
return err
}

req := s.vm.Get("request")
req := vm.Get("request")

if req != nil {
if _, ok := goja.AssertFunction(req); !ok {
return fmt.Errorf("script: function 'request' not found")
}

if err := s.vm.ExportTo(req, &s.ReqFunc); err != nil {
if err := vm.ExportTo(req, &s.ReqFunc); err != nil {
return err
}
}

resp := s.vm.Get("response")
resp := vm.Get("response")

if resp != nil {
if _, ok := goja.AssertFunction(resp); !ok {
return fmt.Errorf("script: function 'response' not found")
}

if err := s.vm.ExportTo(resp, &s.RespFunc); err != nil {
if err := vm.ExportTo(resp, &s.RespFunc); err != nil {
return err
}
}

s.vm = vm
return nil
}

func (c *gcontext) newGraphQLFunc(role string) func(string, map[string]interface{}, map[string]string) map[string]interface{} {

return func(
query string,
vars map[string]interface{},
opt map[string]string) map[string]interface{} {
var err error

op, name := qcode.GetQType(query)

qreq := queryReq{
op: op,
name: name,
query: []byte(query),
}

ct := gcontext{
Context: c.Context,
gj: c.gj,
op: c.op,
rc: c.rc,
}

if len(vars) != 0 {
if qreq.vars, err = json.Marshal(vars); err != nil {
panic(fmt.Errorf("variables: %s", err))
}
}

var r1 string

if v, ok := opt["role"]; ok && len(v) != 0 {
r1 = v
} else {
r1 = role
}

qres, err := ct.execQuery(qreq, r1)
if err != nil {
panic(err)
}

jres := make(map[string]interface{})
if err = json.Unmarshal(qres.data, &jres); err != nil {
panic(fmt.Errorf("json: %s", err))
}

return jres
}
}

func logFunc(args ...interface{}) {
for _, arg := range args {
if _, ok := arg.(map[string]interface{}); ok {
Expand Down

0 comments on commit b25d087

Please sign in to comment.