diff --git a/core/api.go b/core/api.go index 7956b5dd..189a5e05 100644 --- a/core/api.go +++ b/core/api.go @@ -1,46 +1,5 @@ // Package core provides an API to include and use the GraphJin compiler with your own code. // For detailed documentation visit https://graphjin.com -// -// Example usage: -/* - package main - - import ( - "database/sql" - "fmt" - "time" - "github.com/dosco/graphjin/v2/core" - _ "github.com/jackc/pgx/v5/stdlib" - ) - - func main() { - db, err := sql.Open("pgx", "postgres://postgrs:@localhost:5432/example_db") - if err != nil { - log.Fatal(err) - } - - gj, err := core.NewGraphJin(nil, db) - if err != nil { - log.Fatal(err) - } - - query := ` - query { - posts { - id - title - } - }` - - ctx = context.WithValue(ctx, core.UserIDKey, 1) - - res, err := gj.GraphQL(ctx, query, nil) - if err != nil { - log.Fatal(err) - } - - } -*/ package core import ( @@ -121,10 +80,6 @@ type GraphJin struct { type Option func(*graphjin) error -var ( - errPersistedQueryNotFound = errors.New("persisted query not found") -) - // NewGraphJin creates the GraphJin struct, this involves querying the database to learn its // schemas and relationships func NewGraphJin(conf *Config, db *sql.DB, options ...Option) (g *GraphJin, err error) { @@ -171,7 +126,7 @@ func newGraphJin(conf *Config, options ...Option) (*graphjin, error) { if conf == nil { - conf = &Config{Debug: true, DisableAllowList: true} + conf = &Config{Debug: true} } t := time.Now() @@ -287,7 +242,6 @@ type Result struct { sql string role string cacheControl string - actionJSON json.RawMessage Errors []Error `json:"errors,omitempty"` Vars json.RawMessage `json:"-"` Data json.RawMessage `json:"data,omitempty"` @@ -324,57 +278,75 @@ func (rc *ReqConfig) GetNamespace() (string, bool) { // // In developer mode all named queries are saved into the queries folder and in production mode only // queries from these saved queries can be used -func (g *GraphJin) GraphQL( - c context.Context, +func (g *GraphJin) GraphQL(c 
context.Context, query string, vars json.RawMessage, - rc *ReqConfig) (*Result, error) { + rc *ReqConfig) (res *Result, err error) { gj := g.Load().(*graphjin) - ns := gj.namespace c1, span := gj.spanStart(c, "GraphJin Query") defer span.End() - if rc != nil { - if rc.ns != nil { - ns = *rc.ns - } - if rc.APQKey != "" && query == "" { - if v, ok := gj.apq.Get(ns, rc.APQKey); ok { - query = v.query - } else { - return nil, errPersistedQueryNotFound - } + var queryBytes []byte + var inCache bool + + // get query from apq cache if apq key exists + if rc != nil && rc.APQKey != "" { + queryBytes, inCache = gj.apq.Get(rc.APQKey) + } + + // query not found in apq cache so use original query + if len(queryBytes) == 0 { + queryBytes = []byte(query) + } + + // fast extract name and query type from query + var h graph.FPInfo + if h, err = graph.FastParseBytes(queryBytes); err != nil { + return + } + r := gj.newGraphqlReq(rc, h.Operation, h.Name, queryBytes, vars) + + // if production then get query and metadata from allow list + if gj.prod { + var item allow.Item + item, err = gj.allowList.GetByName(h.Name, gj.prod) + if err != nil { + err = fmt.Errorf("%w: %s", err, h.Name) + return } + r.Set(item) } - res, err := gj.graphQL(c1, query, vars, rc) - if err != nil { - return res, err + // do the query + var resp graphqlResp + if resp, err = gj.query(c1, r); err != nil { + return } + res = resp.res - if rc != nil && rc.APQKey != "" { - gj.apq.Set(ns, rc.APQKey, apqInfo{query: query}) + // save to apq cache is apq key exists and not already in cache + if !inCache && rc != nil && rc.APQKey != "" { + gj.apq.Set(rc.APQKey, r.query) } + // if not production then save to allow list if !gj.prod { - err := gj.saveToAllowList(res.actionJSON, query, res.ns) - if err != nil { - return res, err + if err = gj.saveToAllowList(resp.qc, vars, resp.res.ns); err != nil { + return } } - return res, err + return } // GraphQLByName is similar to the GraphQL function except that queries saved // in 
the queries folder can directly be used by their filename. -func (g *GraphJin) GraphQLByName( - c context.Context, +func (g *GraphJin) GraphQLByName(c context.Context, name string, vars json.RawMessage, - rc *ReqConfig) (*Result, error) { + rc *ReqConfig) (res *Result, err error) { gj := g.Load().(*graphjin) @@ -383,83 +355,89 @@ func (g *GraphJin) GraphQLByName( item, err := gj.allowList.GetByName(name, gj.prod) if err != nil { - return nil, err + err = fmt.Errorf("%w: %s", err, name) + return } - op := qcode.GetQTypeByName(item.Operation) - query := item.Query - return gj.graphQLWithOpName(c1, op, name, query, vars, rc) + r := gj.newGraphqlReq(rc, "", name, nil, vars) + r.Set(item) + + res, err = gj.queryWithResult(c1, r) + return } -func (gj *graphjin) graphQL( - c context.Context, - query string, - vars json.RawMessage, - rc *ReqConfig) (*Result, error) { +type graphqlReq struct { + ns string + op qcode.QType + name string + query []byte + vars json.RawMessage + aschema json.RawMessage + rc *ReqConfig +} - var op qcode.QType - var name string +type graphqlResp struct { + res *Result + qc *qcode.QCode +} - if h, err := graph.FastParse(query); err == nil { - name = h.Name - op = qcode.GetQTypeByName(h.Operation) - } else { - return nil, err +func (gj *graphjin) newGraphqlReq(rc *ReqConfig, + op string, + name string, + query []byte, + vars json.RawMessage) (r graphqlReq) { + + r = graphqlReq{ + op: qcode.GetQTypeByName(op), + name: name, + query: query, + vars: vars, } - if gj.prod && !gj.conf.DisableAllowList { - item, err := gj.allowList.GetByName(name, gj.prod) - if err != nil { - return nil, err - } - op = qcode.GetQTypeByName(item.Operation) - query = item.Query + if rc != nil && rc.ns != nil { + r.ns = *rc.ns + } else { + r.ns = gj.namespace } - return gj.graphQLWithOpName(c, op, name, query, vars, rc) + return } -func (gj *graphjin) graphQLWithOpName( - c context.Context, - op qcode.QType, - name string, - query string, - vars json.RawMessage, - rc 
*ReqConfig) (*Result, error) { +func (r *graphqlReq) Set(item allow.Item) { + r.ns = item.Namespace + r.op = qcode.GetQTypeByName(item.Operation) + r.name = item.Name + r.query = item.Query + r.aschema = item.Vars +} - ns := gj.namespace - if rc != nil && rc.ns != nil { - ns = *rc.ns - } +func (gj *graphjin) queryWithResult(c context.Context, r graphqlReq) ( + res *Result, err error) { + resp, err := gj.query(c, r) + return resp.res, err +} - ct := &gcontext{ - gj: gj, - rc: rc, - ns: ns, - op: op, - name: name, - } +func (gj *graphjin) query(c context.Context, r graphqlReq) ( + resp graphqlResp, err error) { - res := &Result{ - ns: ns, - op: op, - name: name, + resp.res = &Result{ + ns: r.ns, + op: r.op, + name: r.name, } - if !gj.prod && name == "IntrospectionQuery" { - v, err := gj.introspection(query) - if err != nil { - return res, err - } - res.Data = v - return res, nil + if !gj.prod && r.name == "IntrospectionQuery" { + resp.res.Data, err = gj.introspection(r.query) + return } - if ct.op == qcode.QTSubscription { - return res, errors.New("use 'core.Subscribe' for subscriptions") + if r.op == qcode.QTSubscription { + err = errors.New("use 'core.Subscribe' for subscriptions") + return } - if ct.op == qcode.QTMutation && gj.schema.DBType() == "mysql" { - return res, errors.New("mysql: mutations not supported") + if r.op == qcode.QTMutation && gj.schema.DBType() == "mysql" { + err = errors.New("mysql: mutations not supported") + return } var role string @@ -475,29 +453,22 @@ func (gj *graphjin) graphQLWithOpName( } } - qr := queryReq{ - ns: ct.ns, - op: ct.op, - name: ct.name, - query: []byte(query), - vars: vars, - } + s := newGState(gj, r, role) - qres, err := ct.execQuery(c, qr, role) + err = s.compileAndExecuteWrapper(c) if err != nil { - res.Errors = []Error{{Message: err.Error()}} + resp.res.Errors = []Error{{Message: err.Error()}} } - res.actionJSON = qres.actionVar() - res.sql = qres.sql() - res.cacheControl = qres.cacheHeader() + resp.qc = s.qcode() + 
resp.res.sql = s.sql() + resp.res.cacheControl = s.cacheHeader() - res.Data = json.RawMessage(qres.data) - res.Hash = qres.dhash - res.role = qres.role - res.Vars = vars - - return res, err + resp.res.Vars = r.vars + resp.res.Data = json.RawMessage(s.data) + resp.res.Hash = s.dhash + resp.res.role = s.role + return } // Reload redoes database discover and reinitializes GraphJin. @@ -516,15 +487,6 @@ func (g *GraphJin) IsProd() bool { return gj.prod } -func Upgrade(configPath string) error { - fs := fs.NewOsFSWithBase(configPath) - al, err := allow.New(nil, fs, false) - if err != nil { - return fmt.Errorf("failed to initialize allow list: %w", err) - } - return al.Upgrade() -} - type Header struct { Type OpType Name string diff --git a/core/apq.go b/core/apq.go index 03b1a132..2c7124bc 100644 --- a/core/apq.go +++ b/core/apq.go @@ -4,10 +4,6 @@ import ( lru "github.com/hashicorp/golang-lru" ) -type apqInfo struct { - query string -} - type apqCache struct { cache *lru.TwoQueueCache } @@ -17,13 +13,14 @@ func (gj *graphjin) initAPQCache() (err error) { return } -func (c apqCache) Get(ns, key string) (info apqInfo, fromCache bool) { - if v, ok := c.cache.Get((ns + key)); ok { - return v.(apqInfo), true +func (c apqCache) Get(key string) (val []byte, fromCache bool) { + if v, ok := c.cache.Get(key); ok { + val = v.([]byte) + fromCache = true } return } -func (c apqCache) Set(ns, key string, val apqInfo) { - c.cache.Add((ns + key), val) +func (c apqCache) Set(key string, val []byte) { + c.cache.Add(key, val) } diff --git a/core/args.go b/core/args.go index 458b8cc2..72a6cc96 100644 --- a/core/args.go +++ b/core/args.go @@ -21,7 +21,6 @@ type args struct { func (gj *graphjin) argList(c context.Context, md psql.Metadata, vars []byte, - pf []byte, rc *ReqConfig) (args, error) { ar := args{cindx: -1} diff --git a/core/build.go b/core/build.go deleted file mode 100644 index ad1dc5ec..00000000 --- a/core/build.go +++ /dev/null @@ -1,133 +0,0 @@ -package core - -import ( - 
"bytes" - "encoding/json" - "fmt" - "sync" - - "github.com/dosco/graphjin/v2/core/internal/psql" - "github.com/dosco/graphjin/v2/core/internal/qcode" - "github.com/dosco/graphjin/v2/core/internal/valid" -) - -type queryComp struct { - sync.Once - qr queryReq - st stmt -} - -type stmt struct { - role *Role - qc *qcode.QCode - md psql.Metadata - va *valid.Validate - sql string -} - -func (gj *graphjin) compileQuery(qr queryReq, role string) (*queryComp, error) { - var err error - qcomp := &queryComp{qr: qr} - - if !gj.prod || gj.conf.DisableAllowList { - userVars := make(map[string]json.RawMessage) - - if len(qr.vars) != 0 { - if err := json.Unmarshal(qr.vars, &userVars); err != nil { - return nil, fmt.Errorf("variables: %w", err) - } - } - - qcomp.st, err = gj.compileQueryForRole(qr, userVars, role) - if err != nil { - return nil, err - } - - } else { - // In production mode enforce the allow list and - // compile and cache the result else compile each time - // the allowlist queries are already loaded at init. 
- // if qcomp, err = gj.getQuery(qr, role); err != nil { - // return nil, err - // } - if qcomp, err = gj.compileQueryForRoleOnce(qcomp, role); err != nil { - return nil, err - } - - // Overwrite allow list vars with user vars - qcomp.qr.vars = qr.vars - qcomp.qr.ns = qr.ns - } - return qcomp, err -} - -func (gj *graphjin) compileQueryForRoleOnce(qcomp *queryComp, role string) (*queryComp, error) { - var err error - - qr := qcomp.qr - val, loaded := gj.queries.LoadOrStore((qr.ns + qr.name + role), qcomp) - if loaded { - return val.(*queryComp), nil - } - - qcomp.Do(func() { - var vars1 map[string]json.RawMessage - - if len(qcomp.qr.vars) != 0 { - err = json.Unmarshal(qcomp.qr.vars, &vars1) - } - - if err == nil { - qcomp.st, err = gj.compileQueryForRole(qcomp.qr, vars1, role) - } - }) - if err != nil { - return nil, err - } - return qcomp, nil -} - -func (gj *graphjin) compileQueryForRole( - qr queryReq, vm map[string]json.RawMessage, role string) (stmt, error) { - - var st stmt - var err error - var ok bool - - if st.role, ok = gj.roles[role]; !ok { - return st, fmt.Errorf(`roles '%s' not defined in c.gj.config`, role) - } - - if st.qc, err = gj.qc.Compile(qr.query, vm, st.role.Name, qr.ns); err != nil { - return st, err - } - - var w bytes.Buffer - - if st.md, err = gj.pc.Compile(&w, st.qc); err != nil { - return st, err - } - - if st.qc.Validation.Source != "" { - vc, ok := gj.validatorMap[st.qc.Validation.Type] - if !ok { - return st, fmt.Errorf("no validator found for '%s'", st.qc.Validation.Type) - } - ve, err := vc.CompileValidation(st.qc.Validation.Source) - if err != nil { - return st, err - } - st.qc.Validation.VE = ve - st.qc.Validation.Exists = true - } - - if st.qc.Script.Name != "" { - if err := gj.loadScript(st.qc); err != nil { - return st, err - } - } - - st.va = valid.New() - st.sql = w.String() - return st, nil -} diff --git a/core/config.go b/core/config.go index 954f8d89..da438deb 100644 --- a/core/config.go +++ b/core/config.go @@ -20,10 
+20,14 @@ type Config struct { // Is used to encrypt opaque values such as the cursor. Auto-generated when not set SecretKey string `mapstructure:"secret_key" json:"secret_key" yaml:"secret_key" jsonschema:"title=Secret Key"` - // When set to true it disables the allow list workflow and all queries are - // always compiled even in production (Warning possible security concern) + // When set to true it disables the allow list workflow DisableAllowList bool `mapstructure:"disable_allow_list" json:"disable_allow_list" yaml:"disable_allow_list" jsonschema:"title=Disable Allow List,default=false"` + // When set to true a database schema file will be generated in dev mode and + // used in production mode. Auto database discovery will be disabled + // in production mode. + EnableSchema bool `mapstructure:"enable_schema" json:"enable_schema" yaml:"enable_schema" jsonschema:"title=Enable Schema,default=false"` + // Forces the database session variable 'user.id' to be set to the user id SetUserID bool `mapstructure:"set_user_id" json:"set_user_id" yaml:"set_user_id" jsonschema:"title=Set User ID,default=false"` diff --git a/core/core.go b/core/core.go index 6d9dc95a..5ca298f8 100644 --- a/core/core.go +++ b/core/core.go @@ -1,19 +1,20 @@ package core import ( + "bytes" "context" - "crypto/sha256" "database/sql" "database/sql/driver" "encoding/json" "errors" "fmt" - "strconv" "github.com/avast/retry-go" + "github.com/dosco/graphjin/v2/core/internal/allow" "github.com/dosco/graphjin/v2/core/internal/psql" "github.com/dosco/graphjin/v2/core/internal/qcode" "github.com/dosco/graphjin/v2/core/internal/sdata" + "github.com/dosco/graphjin/v2/internal/jsn" ) var decPrefix = []byte(`__gj/enc:`) @@ -53,21 +54,6 @@ const ( // Duration time.Duration `json:"duration"` // } -type gcontext struct { - gj *graphjin - op qcode.QType - rc *ReqConfig - ns string - name string -} - -type queryResp struct { - qc *queryComp - role string - data []byte - dhash [sha256.Size]byte -} - func (gj 
*graphjin) initDiscover() error { switch gj.conf.DBType { case "": @@ -81,22 +67,54 @@ func (gj *graphjin) initDiscover() error { if err := gj._initDiscover(); err != nil { return fmt.Errorf("%s: %w", gj.dbtype, err) } + return nil } -func (gj *graphjin) _initDiscover() error { - var err error +func (gj *graphjin) _initDiscover() (err error) { + if gj.prod && gj.conf.EnableSchema { + b, err := gj.fs.ReadFile("db.schema") + if err != nil { + return err + } + ds, err := qcode.ParseSchema(b) + if err != nil { + return err + } + gj.dbinfo = sdata.NewDBInfo(ds.Type, + ds.Version, + ds.Schema, + "", + ds.Columns, + ds.Functions, + gj.conf.Blocklist) + } // If gj.dbinfo is not null then it's probably set - // for tests - if gj.dbinfo == nil { - gj.dbinfo, err = sdata.GetDBInfo( - gj.db, - gj.dbtype, - gj.conf.Blocklist) + // for tests or the schema file is being used + if gj.dbinfo != nil { + return } - return err + gj.dbinfo, err = sdata.GetDBInfo( + gj.db, + gj.dbtype, + gj.conf.Blocklist) + if err != nil { + return + } + + if !gj.prod && gj.conf.EnableSchema { + var buf bytes.Buffer + if err := writeSchema(gj.dbinfo, &buf); err != nil { + return err + } + err = gj.fs.CreateFile("db.schema", buf.Bytes()) + if err != nil { + return + } + } + return } func (gj *graphjin) initSchema() error { @@ -179,269 +197,54 @@ func (gj *graphjin) initCompilers() error { return nil } -func (gj *graphjin) executeRoleQuery(ctx context.Context, +func (gj *graphjin) executeRoleQuery(c context.Context, conn *sql.Conn, - vars []byte, - pf []byte, - rc *ReqConfig) (string, error) { + vars json.RawMessage, + rc *ReqConfig) (role string, err error) { - var role string - var ar args - var err error - - md := gj.roleStmtMD - - if ctx.Value(UserIDKey) == nil { - return "anon", nil + if c.Value(UserIDKey) == nil { + role = "anon" + return } - if ar, err = gj.argList(ctx, md, vars, pf, rc); err != nil { - return "", err + var ar args + if ar, err = gj.argList(c, + gj.roleStmtMD, + vars, + rc); err 
!= nil { + return } if conn == nil { - ctx1, span := gj.spanStart(ctx, "Get Connection") - err = retryOperation(ctx1, func() error { - conn, err = gj.db.Conn(ctx1) - return err + c1, span := gj.spanStart(c, "Get Connection") + defer span.End() + + err = retryOperation(c1, func() (err1 error) { + conn, err1 = gj.db.Conn(c1) + return }) if err != nil { span.Error(err) - } - span.End() - - if err != nil { - return role, err + return } defer conn.Close() } - ctx1, span := gj.spanStart(ctx, "Execute Role Query") + c1, span := gj.spanStart(c, "Execute Role Query") defer span.End() - err = retryOperation(ctx1, func() error { + err = retryOperation(c1, func() (err1 error) { return conn. - QueryRowContext(ctx1, gj.roleStmt, ar.values...). + QueryRowContext(c1, gj.roleStmt, ar.values...). Scan(&role) }) - if err != nil { span.Error(err) - return role, err + return } span.SetAttributesString(stringAttr{"role", role}) - return role, err -} - -func (c *gcontext) execQuery(ctx context.Context, qr queryReq, role string) (queryResp, error) { - var res queryResp - var err error - - if res, err = c.resolveSQL(ctx, qr, role); err != nil { - return res, err - } - - if c.gj.conf.Debug { - c.debugLog(&res.qc.st) - } - - qc := res.qc.st.qc - - if len(res.data) == 0 { - return res, nil - } - - if qc.Remotes != 0 { - if res, err = c.execRemoteJoin(ctx, res); err != nil { - return res, err - } - } - - if qc.Script.Exists && qc.Script.HasRespFn() { - res.data, err = c.scriptCallResp(ctx, qc, res.data, res.role) - } - - return res, err -} - -func (c *gcontext) resolveSQL(ctx context.Context, qr queryReq, role string) (queryResp, error) { - var conn *sql.Conn - var err error - - res := queryResp{role: role} - - ctx1, span := c.gj.spanStart(ctx, "Get Connection") - err = retryOperation(ctx1, func() error { - conn, err = c.gj.db.Conn(ctx1) - return err - }) - if err != nil { - span.Error(err) - } - span.End() - - if err != nil { - return res, err - } - defer conn.Close() - - if 
c.gj.conf.SetUserID { - ctx1, span = c.gj.spanStart(ctx, "Set Local User ID") - err = retryOperation(ctx1, func() error { - return c.setLocalUserID(ctx1, conn) - }) - if err != nil { - span.Error(err) - } - span.End() - - if err != nil { - return res, err - } - } - - if v := ctx.Value(UserRoleKey); v != nil { - res.role = v.(string) - } else if c.gj.abacEnabled { - res.role, err = c.gj.executeRoleQuery(ctx, conn, qr.vars, c.gj.pf, c.rc) - } - - if err != nil { - return res, err - } - - qcomp, err := c.gj.compileQuery(qr, res.role) - if err != nil { - return res, err - } - res.qc = qcomp - - return c.resolveCompiledQuery(ctx, conn, qcomp, res) -} - -func (c *gcontext) resolveCompiledQuery( - ctx context.Context, - conn *sql.Conn, - qcomp *queryComp, - res queryResp) ( - queryResp, error) { - - // From here on use qcomp. for everything including accessing qr since it contains updated values of the latter. This code needs some refactoring - - if err := c.validateAndUpdateVars(ctx, qcomp, &res); err != nil { - return res, err - } - - args, err := c.gj.argList(ctx, qcomp.st.md, qcomp.qr.vars, c.gj.pf, c.rc) - if err != nil { - return res, err - } - - ctx1, span := c.gj.spanStart(ctx, "Execute Query") - defer span.End() - - err = retryOperation(ctx1, func() error { - return conn. - QueryRowContext(ctx1, qcomp.st.sql, args.values...). 
- Scan(&res.data) - }) - - if err != nil && err != sql.ErrNoRows { - span.Error(err) - } - - if span.IsRecording() { - span.SetAttributesString( - stringAttr{"query.namespace", res.qc.qr.ns}, - stringAttr{"query.operation", qcomp.st.qc.Type.String()}, - stringAttr{"query.name", qcomp.st.qc.Name}, - stringAttr{"query.role", qcomp.st.role.Name}) - } - - if err == sql.ErrNoRows { - return res, nil - } else if err != nil { - return res, err - } - - res.dhash = sha256.Sum256(res.data) - - res.data, err = encryptValues(res.data, - c.gj.pf, decPrefix, res.dhash[:], c.gj.encKey) - if err != nil { - return res, err - } - - return res, nil -} - -func (c *gcontext) validateAndUpdateVars(ctx context.Context, qcomp *queryComp, res *queryResp) error { - var vars map[string]interface{} - - qc := qcomp.st.qc - qr := qcomp.qr - - if qc == nil { - return nil - } - - if len(qr.vars) != 0 || qc.Script.Name != "" { - vars = make(map[string]interface{}) - } - - if len(qr.vars) != 0 { - if err := json.Unmarshal(qr.vars, &vars); err != nil { - return err - } - } - - if qc.Validation.Exists { - if err := qc.Validation.VE.Validate(qr.vars); err != nil { - return err - } - } - - if qc.Consts != nil { - errs := qcomp.st.va.ValidateMap(ctx, vars, qc.Consts) - if !c.gj.prod && len(errs) != 0 { - for k, v := range errs { - c.gj.log.Printf("validation failed: $%s: %s", k, v.Error()) - } - } - - if len(errs) != 0 { - return errors.New("validation failed") - } - } - - if qc.Script.Exists && qc.Script.HasReqFn() { - v, err := c.scriptCallReq(ctx, qc, vars, qcomp.st.role.Name) - if len(v) != 0 { - qcomp.qr.vars = v - } else if err != nil { - return err - } - } - return nil -} - -func (c *gcontext) setLocalUserID(ctx context.Context, conn *sql.Conn) error { - var err error - - if v := ctx.Value(UserIDKey); v == nil { - return nil - } else { - switch v1 := v.(type) { - case string: - _, err = conn.ExecContext(ctx, `SET SESSION "user.id" = '`+v1+`'`) - - case int: - _, err = conn.ExecContext(ctx, `SET 
SESSION "user.id" = `+strconv.Itoa(v1)) - } - } - - return err + return } func (r *Result) Operation() OpType { @@ -481,7 +284,7 @@ func (r *Result) CacheControl() string { return r.cacheControl } -// func (c *gcontext) addTrace(sel []qcode.Select, id int32, st time.Time) { +// func (c *gstate) addTrace(sel []qcode.Select, id int32, st time.Time) { // et := time.Now() // du := et.Sub(st) @@ -524,16 +327,19 @@ func (r *Result) CacheControl() string { // append(c.res.Extensions.Tracing.Execution.Resolvers, tr) // } -func (c *gcontext) debugLog(st *stmt) { - if st == nil || st.qc == nil { +func (s *gstate) debugLogStmt() { + st := s.cs.st + + if st.qc == nil { return } + for _, sel := range st.qc.Selects { if sel.SkipRender == qcode.SkipTypeUserNeeded { - c.gj.log.Printf("Field skipped, requires $user_id or table not added to anon role: %s", sel.FieldName) + s.gj.log.Printf("Field skipped, requires $user_id or table not added to anon role: %s", sel.FieldName) } if sel.SkipRender == qcode.SkipTypeBlocked { - c.gj.log.Printf("Field skipped, blocked: %s", sel.FieldName) + s.gj.log.Printf("Field skipped, blocked: %s", sel.FieldName) } } } @@ -542,39 +348,40 @@ func retryIfDBError(err error) bool { return (err == driver.ErrBadConn) } -func (gj *graphjin) saveToAllowList(actionVar json.RawMessage, query, namespace string) error { +func (gj *graphjin) saveToAllowList(qc *qcode.QCode, vars json.RawMessage, ns string) (err error) { if gj.conf.DisableAllowList { return nil } - return gj.allowList.Set(actionVar, query, namespace) -} + item := allow.Item{ + Namespace: ns, + Name: qc.Name, + Query: qc.Query, + Fragments: make([]allow.Fragment, len(qc.Fragments)), + } -func (gj *graphjin) spanStart(c context.Context, name string) (context.Context, span) { - return gj.tracer.Start(c, name) -} + if qc.ActionVar != "" { + var buf bytes.Buffer + if err = jsn.Clear(&buf, []byte(vars)); err != nil { + return + } -func (qres *queryResp) sql() string { - if qres.qc != nil { - return 
qres.qc.st.sql + v := json.RawMessage(buf.Bytes()) + item.Vars, err = json.MarshalIndent(v, "", " ") + if err != nil { + return + } } - return "" -} -func (qres *queryResp) actionVar() json.RawMessage { - if qcomp := qres.qc; qcomp != nil { - if v, ok := qcomp.st.qc.Vars[qcomp.st.qc.ActionVar]; ok { - return v - } + for i, f := range qc.Fragments { + item.Fragments[i] = allow.Fragment{Name: f.Name, Value: f.Value} } - return nil + + return gj.allowList.Set(item) } -func (qres *queryResp) cacheHeader() string { - if qres.qc != nil && qres.qc.st.qc != nil { - return qres.qc.st.qc.Cache.Header - } - return "" +func (gj *graphjin) spanStart(c context.Context, name string) (context.Context, span) { + return gj.tracer.Start(c, name) } func retryOperation(c context.Context, fn func() error) error { diff --git a/core/core_test.go b/core/core_test.go index 0821b83b..3267b92b 100644 --- a/core/core_test.go +++ b/core/core_test.go @@ -210,6 +210,41 @@ func TestAllowListWithNamespace(t *testing.T) { assert.ErrorIs(t, err, allow.ErrUnknownGraphQLQuery) } +func TestEnableSchema(t *testing.T) { + gql := ` + fragment Product on products { + id + name + } + query getProducts { + products(id: 2) { + ...Product + } + }` + + dir, err := os.MkdirTemp("", "test") + assert.NoError(t, err) + defer os.RemoveAll(dir) + + fs := fs.NewOsFSWithBase(dir) + + conf1 := newConfig(&core.Config{DBType: dbType, EnableSchema: true}) + gj1, err := core.NewGraphJin(conf1, db, core.OptionSetFS(fs)) + assert.NoError(t, err) + + res1, err := gj1.GraphQL(context.Background(), gql, nil, nil) + assert.NoError(t, err) + assert.Equal(t, stdJSON(res1.Data), `{"products":{"id":2,"name":"Product 2"}}`) + + conf2 := newConfig(&core.Config{DBType: dbType, EnableSchema: true, Production: true}) + gj2, err := core.NewGraphJin(conf2, db, core.OptionSetFS(fs)) + assert.NoError(t, err) + + res2, err := gj2.GraphQL(context.Background(), gql, nil, nil) + assert.NoError(t, err) + assert.Equal(t, stdJSON(res2.Data), 
`{"products":{"id":2,"name":"Product 2"}}`) +} + func TestConfigReuse(t *testing.T) { gql := `query { products(id: 2) { @@ -280,7 +315,7 @@ func TestParallelRuns(t *testing.T) { for n := 0; n < 10; n++ { conf := newConfig(&core.Config{ DBType: dbType, - Production: true, + Production: false, DisableAllowList: true, Tables: []core.Table{ {Name: "me", Table: "users"}, @@ -304,7 +339,6 @@ func TestParallelRuns(t *testing.T) { if err != nil { return fmt.Errorf("%d: %w", x, err) } - // fmt.Println(x, ">", string(res.Data)) } return nil }) diff --git a/core/gstate.go b/core/gstate.go new file mode 100644 index 00000000..fec48c71 --- /dev/null +++ b/core/gstate.go @@ -0,0 +1,372 @@ +package core + +import ( + "bytes" + "context" + "crypto/sha256" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strconv" + "sync" + + "github.com/dosco/graphjin/v2/core/internal/psql" + "github.com/dosco/graphjin/v2/core/internal/qcode" + "github.com/dosco/graphjin/v2/core/internal/valid" + plugin "github.com/dosco/graphjin/v2/plugin" +) + +type gstate struct { + gj *graphjin + r graphqlReq + cs *cstate + data []byte + dhash [sha256.Size]byte + role string +} + +type cstate struct { + sync.Once + st stmt + err error +} + +type stmt struct { + role string + roc *Role + qc *qcode.QCode + md psql.Metadata + va *valid.Validate + sql string +} + +func newGState(gj *graphjin, r graphqlReq, role string) (s gstate) { + s.gj = gj + s.r = r + s.role = role + return +} + +func (s *gstate) compile() (err error) { + if !s.gj.prod { + err = s.compileQueryForRole() + return + + } + + // In production mode and compile and cache the result + // In production mode the query is derived from the allow list + err = s.compileQueryForRoleOnce() + return +} + +func (s *gstate) compileQueryForRoleOnce() (err error) { + k := (s.r.ns + s.r.name + s.role) + + val, loaded := s.gj.queries.LoadOrStore(k, &cstate{}) + s.cs = val.(*cstate) + err = s.cs.err + + if loaded { + return + } + + s.cs.Do(func() { + err = 
s.compileQueryForRole() + s.cs.err = err + }) + return +} + +func (s *gstate) compileQueryForRole() (err error) { + st := stmt{role: s.role} + + var ok bool + if st.roc, ok = s.gj.roles[s.role]; !ok { + err = fmt.Errorf(`roles '%s' not defined in c.gj.config`, s.role) + return + } + + var vars json.RawMessage + if len(s.r.aschema) != 0 { + vars = s.r.aschema + } else { + vars = s.r.vars + } + + if st.qc, err = s.gj.qc.Compile( + s.r.query, + vars, + s.role, + s.r.ns); err != nil { + return + } + + var w bytes.Buffer + + if st.md, err = s.gj.pc.Compile(&w, st.qc); err != nil { + return + } + + if st.qc.Validation.Source != "" { + vc, ok := s.gj.validatorMap[st.qc.Validation.Type] + if !ok { + err = fmt.Errorf("no validator found for '%s'", st.qc.Validation.Type) + return + } + + var ve plugin.ValidationExecuter + ve, err = vc.CompileValidation(st.qc.Validation.Source) + if err != nil { + return + } + st.qc.Validation.VE = ve + st.qc.Validation.Exists = true + } + + if st.qc.Script.Name != "" { + if err = s.gj.loadScript(st.qc); err != nil { + return + } + } + + st.va = valid.New() + st.sql = w.String() + + if s.cs == nil { + s.cs = &cstate{st: st} + } else { + // s.cs.r = s.r + s.cs.st = st + } + + return +} + +func (s *gstate) compileAndExecuteWrapper(c context.Context) (err error) { + if err = s.compileAndExecute(c); err != nil { + return + } + + if s.gj.conf.Debug { + s.debugLogStmt() + } + + if len(s.data) == 0 { + return + } + + cs := s.cs + + if cs.st.qc.Remotes != 0 { + if err = s.execRemoteJoin(c); err != nil { + return + } + } + + qc := cs.st.qc + + if qc.Script.Exists && qc.Script.HasRespFn() { + err = s.scriptCallResp(c) + } + return +} + +func (s *gstate) compileAndExecute(c context.Context) (err error) { + var conn *sql.Conn + + // get a new database connection + c1, span1 := s.gj.spanStart(c, "Get Connection") + defer span1.End() + + err = retryOperation(c1, func() (err1 error) { + conn, err1 = s.gj.db.Conn(c1) + return + }) + if err != nil { + 
span1.Error(err) + return + } + defer conn.Close() + + // set the local user id on the connection if needed + if s.gj.conf.SetUserID { + c1, span2 := s.gj.spanStart(c, "Set Local User ID") + defer span2.End() + + err = retryOperation(c1, func() (err1 error) { + return s.setLocalUserID(c1, conn) + }) + if err != nil { + span2.Error(err) + return + } + } + + // get the role from context or using the role_query + if v := c.Value(UserRoleKey); v != nil { + s.role = v.(string) + } else if s.gj.abacEnabled { + err = s.executeRoleQuery(c, conn) + } + if err != nil { + return + } + + // compile query for the role + if err = s.compile(); err != nil { + return + } + err = s.execute(c, conn) + return +} + +func (s *gstate) execute(c context.Context, conn *sql.Conn) (err error) { + if err = s.validateAndUpdateVars(c); err != nil { + return + } + + var args args + if args, err = s.argList(c); err != nil { + return + } + + cs := s.cs + + c1, span := s.gj.spanStart(c, "Execute Query") + defer span.End() + + err = retryOperation(c1, func() (err1 error) { + return conn. + QueryRowContext(c1, cs.st.sql, args.values...). 
+ Scan(&s.data) + }) + + if err != nil && err != sql.ErrNoRows { + span.Error(err) + } + + if span.IsRecording() { + span.SetAttributesString( + stringAttr{"query.namespace", s.r.ns}, + stringAttr{"query.operation", cs.st.qc.Type.String()}, + stringAttr{"query.name", cs.st.qc.Name}, + stringAttr{"query.role", cs.st.role}) + } + + if err == sql.ErrNoRows { + err = nil + } + if err != nil { + return + } + + s.dhash = sha256.Sum256(s.data) + + s.data, err = encryptValues(s.data, + s.gj.pf, decPrefix, s.dhash[:], s.gj.encKey) + + return +} + +func (s *gstate) executeRoleQuery(c context.Context, conn *sql.Conn) (err error) { + s.role, err = s.gj.executeRoleQuery(c, conn, s.r.vars, s.r.rc) + return +} + +func (s *gstate) argList(c context.Context) (args args, err error) { + args, err = s.gj.argList(c, s.cs.st.md, s.r.vars, s.r.rc) + return +} + +func (s *gstate) argListVars(c context.Context, vars json.RawMessage) ( + args args, err error) { + args, err = s.gj.argList(c, s.cs.st.md, vars, s.r.rc) + return +} + +func (s *gstate) setLocalUserID(c context.Context, conn *sql.Conn) (err error) { + if v := c.Value(UserIDKey); v == nil { + return nil + } else { + switch v1 := v.(type) { + case string: + _, err = conn.ExecContext(c, `SET SESSION "user.id" = '`+v1+`'`) + + case int: + _, err = conn.ExecContext(c, `SET SESSION "user.id" = `+strconv.Itoa(v1)) + } + } + return +} + +func (s *gstate) validateAndUpdateVars(c context.Context) (err error) { + var vars map[string]interface{} + + cs := s.cs + qc := cs.st.qc + + if qc == nil { + return nil + } + + if qc.Consts != nil || (qc.Script.Exists && qc.Script.HasReqFn()) { + vars = make(map[string]interface{}) + + if len(s.r.vars) != 0 { + if err := json.Unmarshal(s.r.vars, &vars); err != nil { + return err + } + } + } + + if qc.Validation.Exists { + if err := qc.Validation.VE.Validate(s.r.vars); err != nil { + return err + } + } + + if qc.Consts != nil { + errs := cs.st.va.ValidateMap(c, vars, qc.Consts) + if !s.gj.prod && 
len(errs) != 0 { + for k, v := range errs { + s.gj.log.Printf("validation failed: $%s: %s", k, v.Error()) + } + } + + if len(errs) != 0 { + return errors.New("validation failed") + } + } + + if qc.Script.Exists && qc.Script.HasReqFn() { + var v []byte + if v, err = s.scriptCallReq(c, qc, vars, s.role); err != nil { + return + } + s.r.vars = v + } + return +} + +func (s *gstate) sql() string { + if s.cs != nil && s.cs.st.qc != nil { + return s.cs.st.sql + } + return "" +} + +func (s *gstate) cacheHeader() string { + if s.cs != nil && s.cs.st.qc != nil { + return s.cs.st.qc.Cache.Header + } + return "" +} + +func (s *gstate) qcode() *qcode.QCode { + if s.cs != nil && s.cs.st.qc != nil { + return s.cs.st.qc + } + return nil +} diff --git a/core/insert_test.go b/core/insert_test.go index bb223164..c40b5670 100644 --- a/core/insert_test.go +++ b/core/insert_test.go @@ -6,8 +6,12 @@ import ( "context" "encoding/json" "fmt" + "os" + "testing" "github.com/dosco/graphjin/v2/core" + "github.com/dosco/graphjin/v2/plugin/fs" + "github.com/stretchr/testify/assert" ) func Example_insert() { @@ -662,3 +666,71 @@ func Example_insertIntoRecursiveRelationshipAndConnectTable2() { } // Output: {"comments":{"commenter":{"id":3},"comments":[{"id":6}],"id":5004,"product":{"id":26}}} } + +func TestAllowListWithMutations(t *testing.T) { + gql := ` + mutation getProducts { + users(insert: $data) { + id + } + }` + + dir, err := os.MkdirTemp("", "test") + assert.NoError(t, err) + defer os.RemoveAll(dir) + + fs := fs.NewOsFSWithBase(dir) + err = fs.CreateDir("queries") + assert.NoError(t, err) + + conf1 := newConfig(&core.Config{DBType: dbType, DisableAllowList: false}) + gj1, err := core.NewGraphJin(conf1, db, core.OptionSetFS(fs)) + assert.NoError(t, err) + + vars1 := json.RawMessage(`{ + "data": { + "id": 90011, + "email": "user90011@test.com", + "full_name": "User 90011" + } + }`) + + exp1 := `{"users": [{"id": 90011}]}` + + res1, err := gj1.GraphQL(context.Background(), gql, vars1, nil) + 
assert.NoError(t, err) + assert.Equal(t, exp1, string(res1.Data)) + + conf2 := newConfig(&core.Config{DBType: dbType, Production: true}) + gj2, err := core.NewGraphJin(conf2, db, core.OptionSetFS(fs)) + assert.NoError(t, err) + + vars2 := json.RawMessage(`{ + "data": { + "id": 90012, + "email": "user90012@test.com", + "full_name": "User 90012" + } + }`) + + exp2 := `{"users": [{"id": 90012}]}` + + res2, err := gj2.GraphQL(context.Background(), gql, vars2, nil) + assert.NoError(t, err) + assert.Equal(t, exp2, string(res2.Data)) + + vars3 := json.RawMessage(`{ + "data": { + "id": 90013, + "email": "user90013@test.com", + "full_name": "User 90013", + "stripe_id": "payment_id_90013" + } + }`) + + exp3 := `{"users": [{"id": 90013}]}` + + res3, err := gj2.GraphQL(context.Background(), gql, vars3, nil) + assert.NoError(t, err) + assert.Equal(t, exp3, string(res3.Data)) +} diff --git a/core/internal/allow/allow.go b/core/internal/allow/allow.go index 0423a6a1..7bf82028 100644 --- a/core/internal/allow/allow.go +++ b/core/internal/allow/allow.go @@ -7,28 +7,15 @@ import ( "fmt" _log "log" "path/filepath" - "strconv" "strings" - "text/scanner" - "gopkg.in/yaml.v3" - - "github.com/chirino/graphql/schema" "github.com/dosco/graphjin/v2/core/internal/graph" - "github.com/dosco/graphjin/v2/internal/jsn" "github.com/dosco/graphjin/v2/plugin" lru "github.com/hashicorp/golang-lru" ) var ErrUnknownGraphQLQuery = errors.New("unknown graphql query") -const ( - expComment = iota + 1 - expVar - expQuery - expFrag -) - const ( queryPath = "/queries" fragmentPath = "/fragments" @@ -36,16 +23,16 @@ const ( type Item struct { Namespace string - Name string Operation string - Query string - Vars string - frags []Frag + Name string + Vars json.RawMessage + Query []byte + Fragments []Fragment } -type Frag struct { +type Fragment struct { Name string - Value string + Value []byte } type List struct { @@ -81,7 +68,7 @@ func New(log *_log.Logger, fs plugin.FS, readOnly bool) (al *List, err error) 
{ if !ok { break } - err = al.save(v, false) + err = al.save(v) if err != nil && log != nil { log.Println("WRN allow list save:", err) } @@ -91,52 +78,19 @@ func New(log *_log.Logger, fs plugin.FS, readOnly bool) (al *List, err error) { return al, err } -func (al *List) Set(vars json.RawMessage, query string, namespace string) error { +func (al *List) Set(item Item) error { if al.saveChan == nil { return errors.New("allow list is read-only") } - if query == "" { + if len(item.Query) == 0 { return errors.New("empty query") } - item, err := parseQuery(query) - if err != nil { - return err - } - - item.Namespace = namespace - item.Vars = string(vars) al.saveChan <- item return nil } -func (al *List) Upgrade() (err error) { - files, err := al.fs.ReadDir(queryPath) - if err != nil { - return fmt.Errorf("%w (%s)", err, queryPath) - } - - for _, f := range files { - if f.IsDir() { - continue - } - ext := filepath.Ext(f.Name()) - if ext != ".yaml" && ext != ".yml" { - continue - } - item, err := al.Get(filepath.Join(queryPath, f.Name())) - if err != nil { - return err - } - - if err := al.save(item, false); err != nil { - return err - } - } - return -} - func (al *List) GetByName(name string, useCache bool) (item Item, err error) { if useCache { if v, ok := al.cache.Get(name); ok { @@ -145,344 +99,116 @@ func (al *List) GetByName(name string, useCache bool) (item Item, err error) { } } - fpath := filepath.Join(queryPath, name) - exts := []string{".gql", ".graphql", ".yml", ".yaml"} - for _, ext := range exts { - if item, err = al.Get((fpath + ext)); err == nil { - break - } else if err != plugin.ErrNotFound { - return item, err - } - } + fp := filepath.Join(queryPath, name) + var ok bool - if useCache && err == nil { - al.cache.Add(name, item) + if ok, err = al.fs.Exists((fp + ".gql")); err != nil { + return + } else if ok { + item, err = al.get(queryPath, name, ".gql", useCache) + return } - return -} - -var errUnknownFileType = errors.New("not a graphql file") - -func 
(al *List) Get(filePath string) (item Item, err error) { - switch filepath.Ext(filePath) { - case ".gql", ".graphql": - return itemFromGQL(al.fs, filePath) - case ".yml", ".yaml": - return itemFromYaml(al.fs, filePath) - default: - return item, errUnknownFileType + if ok, err = al.fs.Exists((fp + ".graphql")); err != nil { + return + } else if ok { + item, err = al.get(queryPath, name, ".graphql", useCache) + } else { + err = ErrUnknownGraphQLQuery } + return } -func itemFromYaml(fs plugin.FS, filePath string) (Item, error) { - var item Item +func (al *List) get(queryPath, name, ext string, useCache bool) (item Item, err error) { + queryNS, queryName := splitName(name) - b, err := fs.ReadFile(filePath) + var query []byte + query, err = readGQL(al.fs, filepath.Join(queryPath, (name+ext))) if err != nil { - return item, err - } - - if err := yaml.Unmarshal(b, &item); err != nil { - return item, err - } - - h, err := graph.FastParse(item.Query) - if err != nil { - return item, err - } - item.Operation = h.Operation - - qi, err := parseQuery(item.Query) - if err != nil { - return item, err - } - - for _, f := range qi.frags { - b, err := fs.ReadFile(filepath.Join(fragmentPath, f.Name)) - if err != nil { - return item, err - } - item.frags = append(item.frags, Frag{Name: f.Name, Value: string(b)}) - } - - return item, nil -} - -func itemFromGQL(fs plugin.FS, filePath string) (item Item, err error) { - fn := filepath.Base(filePath) - fn = strings.TrimSuffix(fn, filepath.Ext(fn)) - queryNS, queryName := splitName(fn) - - if queryName == "" { - return item, fmt.Errorf("invalid filename: %s", filePath) + return } - query, err := parseGQLFile(fs, filePath) + var h graph.FPInfo + h, err = graph.FastParseBytes(query) if err != nil { - return item, err + return } - h, err := graph.FastParse(query) - if err != nil { - return item, err + vars, err1 := al.fs.ReadFile(filepath.Join(queryPath, (name + ".json"))) + if err1 != nil && err1 != plugin.ErrNotFound { + err = err1; return } item.Namespace 
= queryNS item.Operation = h.Operation item.Name = queryName item.Query = query - return item, nil -} - -func parseQuery(b string) (Item, error) { - var s scanner.Scanner - s.Init(strings.NewReader(b)) - s.Mode ^= scanner.SkipComments - - var op, sp scanner.Position - var item Item - var err error - - st := expComment - period := 0 - frags := make(map[string]struct{}) - - for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() { - txt := s.TokenText() - - switch { - case strings.HasPrefix(txt, "/*"): - v := b[sp.Offset:s.Pos().Offset] - item, err = setValue(st, v, item) - sp = s.Pos() - - case strings.HasPrefix(txt, "variables"): - v := b[sp.Offset:s.Pos().Offset] - item, err = setValue(st, v, item) - sp = s.Pos() - st = expVar - - case isGraphQL(txt): - v := b[sp.Offset:s.Pos().Offset] - item, err = setValue(st, v, item) - sp = op - st = expQuery - - case strings.HasPrefix(txt, "fragment"): - v := b[sp.Offset:s.Pos().Offset] - item, err = setValue(st, v, item) - sp = op - st = expFrag - case txt == "@": - exp := []string{"json", "(", "schema", ":"} - if ok := expTokens(&s, exp); !ok { - continue - } - s.Scan() - txt = s.TokenText() - if txt == ":" { - s.Scan() - txt = s.TokenText() - } - if txt == "" { - continue - } - vars, err := strconv.Unquote(txt) - if err != nil { - return item, err - } - item.Vars = strings.TrimSpace(vars) - default: - if period == 3 && txt != "." { - frags[txt] = struct{}{} - } - if period != 3 && txt == "." 
{ - period++ - } else { - period = 0 - } - } - - if err != nil { - return item, err - } - - op = s.Pos() - } - - if st == expQuery || st == expFrag { - v := b[sp.Offset:s.Pos().Offset] - item, err = setValue(st, v, item) - } - - if err != nil { - return item, err - } - - for k := range frags { - item.frags = append(item.frags, Frag{Name: k}) - } - return item, nil -} + item.Vars = vars -func expTokens(s *scanner.Scanner, exp []string) (ok bool) { - for _, v := range exp { - if tok := s.Scan(); tok == scanner.EOF { - return - } - txt := s.TokenText() - if txt != v { - return - } - } - return true -} - -func setValue(st int, v string, item Item) (Item, error) { - val := func() string { - return strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1]) - } - switch st { - case expVar: - item.Vars = val() - - case expQuery: - item.Query = val() - - case expFrag: - f := Frag{Value: val()} - f.Name = fragmentName(f.Value) - item.frags = append(item.frags, f) + if useCache { + al.cache.Add(name, item) } - - return item, nil + return } -func (al *List) save(item Item, safe bool) error { - var buf bytes.Buffer - var err error - - qd := &schema.QueryDocument{} - if err := qd.Parse(item.Query); err != nil { - return err - } - - qvars := strings.TrimSpace(item.Vars) - if qvars != "" && qvars != "{}" { - if err := jsn.Clear(&buf, []byte(qvars)); err != nil { - return err - } - - vj := json.RawMessage(buf.Bytes()) - if vj, err = json.MarshalIndent(vj, "", " "); err != nil { - return err - } - buf.Reset() - - d := schema.Directive{ - Name: "json", - Args: schema.ArgumentList{{ - Name: "schema", - Value: schema.ToLiteral(string(vj)), - }}, - } - qd.Operations[0].Directives = append(qd.Operations[0].Directives, &d) - // Bug in chirino/graphql forces us to add the space after the query name - qd.Operations[0].Name = qd.Operations[0].Name + " " - } - qd.WriteTo(&buf) - - item.Name = strings.TrimSpace(qd.Operations[0].Name) +func (al *List) save(item Item) (err error) { + item.Name = 
strings.TrimSpace(item.Name) if item.Name == "" { - return errors.New("no query name defined: only named queries are saved to the allow list") + err = errors.New("no query name defined: only named queries are saved to the allow list") + return } - - return al.saveItem( - item.Namespace, - item.Name, - buf.String(), - item.frags, - safe) + return al.saveItem(item) } -func (al *List) saveItem( - ns, name, content string, frags []Frag, safe bool) error { - - var qfn string - if ns != "" { - qfn = ns + "." + name + ".gql" +func (al *List) saveItem(item Item) (err error) { + var queryFile string + if item.Namespace != "" { + queryFile = item.Namespace + "." + item.Name } else { - qfn = name + ".gql" + queryFile = item.Name } - var gqlContent bytes.Buffer - fmap := make(map[string]struct{}) + fmap := make(map[string]struct{}, len(item.Fragments)) + var buf bytes.Buffer - for _, fv := range frags { - var fn string - if ns != "" { - fn = ns + "." + fv.Name + for _, f := range item.Fragments { + var fragFile string + if item.Namespace != "" { + fragFile = item.Namespace + "." 
+ f.Name } else { - fn = fv.Name + fragFile = f.Name } - fn += ".gql" - if _, ok := fmap[fn]; !ok { - fh := fmt.Sprintf(`#import "./fragments/%s"`, fn) - gqlContent.WriteString(fh) - gqlContent.WriteRune('\n') - fmap[fn] = struct{}{} + if _, ok := fmap[fragFile]; !ok { + fh := fmt.Sprintf(`#import "./fragments/%s"`, fragFile) + buf.WriteString(fh) + buf.WriteRune('\n') + fmap[fragFile] = struct{}{} } - fragFile := filepath.Join(queryPath, "fragments", fn) - if safe { - if ok, err := al.fs.Exists(fragFile); ok { - continue - } else if err != nil { - return err - } - } - - err := al.fs.CreateFile(fragFile, []byte(fv.Value)) + ff := filepath.Join(queryPath, "fragments", (fragFile + ".gql")) + err = al.fs.CreateFile(ff, []byte(f.Value)) if err != nil { - return err + return } } - if gqlContent.Len() != 0 { - gqlContent.WriteRune('\n') - } - gqlContent.WriteString(content) - - queryFile := filepath.Join(queryPath, qfn) - if safe { - if ok, err := al.fs.Exists(queryFile); ok { - return nil - } else if err != nil { - return err - } + if buf.Len() != 0 { + buf.WriteRune('\n') } + buf.Write(bytes.TrimSpace(item.Query)) - err := al.fs.CreateFile(queryFile, gqlContent.Bytes()) + qf := filepath.Join(queryPath, (queryFile + ".gql")) + err = al.fs.CreateFile(qf, bytes.TrimSpace(buf.Bytes())) if err != nil { - return err + return } - return nil -} -// func (al *List) fetchFragment(namespace, name string) (string, error) { -// var fn string -// if namespace != "" { -// fn = namespace + "." 
+ name -// } else { -// fn = name -// } -// v, err := al.fs.ReadFile(filepath.Join(fragmentPath, fn)) -// if err != nil { -// return "", err -// } -// return string(v), err -// } + if len(item.Vars) != 0 { + jf := filepath.Join(queryPath, (queryFile + ".json")) + err = al.fs.CreateFile(jf, bytes.TrimSpace(item.Vars)) + } + return +} func splitName(v string) (string, string) { i := strings.LastIndex(v, ".") diff --git a/core/internal/allow/allow_test.go b/core/internal/allow/allow_test.go deleted file mode 100644 index 71ae7c35..00000000 --- a/core/internal/allow/allow_test.go +++ /dev/null @@ -1,277 +0,0 @@ -package allow - -import ( - "testing" - - "github.com/dosco/graphjin/v2/core/internal/graph" -) - -func TestGQLName1(t *testing.T) { - var q = ` - query { - products( - distinct: [price] - where: { id: { and: { greater_or_equals: 20, lt: 28 } } } - ) { id name } }` - - h, err := graph.FastParse(q) - if err != nil { - t.Fatal(err) - } - - if h.Name != "" { - t.Fatal("Name should be empty, not ", h.Name) - } -} - -func TestGQLName2(t *testing.T) { - var q = ` - query hakuna_matata - - { - products( - distinct: [price] - where: { id: { and: { greater_or_equals: 20, lt: 28 } } } - ) { - id - name - } - }` - - h, err := graph.FastParse(q) - if err != nil { - t.Fatal(err) - } - - if h.Name != "hakuna_matata" { - t.Fatal("Name should be 'hakuna_matata', not ", h.Name) - } -} - -func TestGQLName3(t *testing.T) { - var q = ` - mutation means{ users { id } }` - - // var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } ` - - h, err := graph.FastParse(q) - if err != nil { - t.Fatal(err) - } - - if h.Name != "means" { - t.Fatal("Name should be 'means', not ", h.Name) - } -} - -func TestGQLName4(t *testing.T) { - var q = ` - query no_worries - users { - id - } - }` - - h, err := graph.FastParse(q) - if err != nil { - t.Fatal(err) - } - - if h.Name != 
"no_worries" { - t.Fatal("Name should be 'no_worries', not ", h.Name) - } -} - -func TestGQLName5(t *testing.T) { - var q = ` - { - users { - id - } - }` - - h, err := graph.FastParse(q) - if err != nil { - t.Fatal(err) - } - - if h.Name != "" { - t.Fatal("Name should be empty, not ", h.Name) - } -} - -func TestParse1(t *testing.T) { - var al = ` - # Hello world - - variables { - "data": { - "slug": "", - "body": "", - "post": { - "connect": { - "slug": "" - } - } - } - } - - mutation createComment { - comment(insert: $data) { - slug - body - createdAt: created_at - totalVotes: cached_votes_total - totalReplies: cached_replies_total - vote: comment_vote(where: {user_id: {eq: $user_id}}) { - created_at - __typename - } - author: user { - slug - firstName: first_name - lastName: last_name - pictureURL: picture_url - bio - __typename - } - __typename - } - } - - # Query named createPost - - query createPost { - post(insert: $data) { - slug - body - published - createdAt: created_at - totalVotes: cached_votes_total - totalComments: cached_comments_total - vote: post_vote(where: {user_id: {eq: $user_id}}) { - created_at - __typename - } - author: user { - slug - firstName: first_name - lastName: last_name - pictureURL: picture_url - bio - __typename - } - __typename - } - }` - - _, err := parseQuery(al) - if err != nil { - t.Fatal(err) - } -} - -func TestParse2(t *testing.T) { - var al = ` - /* Hello world */ - - variables { - "data": { - "slug": "", - "body": "", - "post": { - "connect": { - "slug": "" - } - } - } - } - - mutation createComment { - comment(insert: $data) { - slug - body - createdAt: created_at - totalVotes: cached_votes_total - totalReplies: cached_replies_total - vote: comment_vote(where: {user_id: {eq: $user_id}}) { - created_at - __typename - } - author: user { - slug - firstName: first_name - lastName: last_name - pictureURL: picture_url - bio - __typename - } - __typename - } - } - - /* - Query named createPost - */ - - variables { - "data": { - 
"thread": { - "connect": { - "slug": "" - } - }, - "slug": "", - "published": false, - "body": "" - } - } - - query createPost { - post(insert: $data) { - slug - body - published - createdAt: created_at - totalVotes: cached_votes_total - totalComments: cached_comments_total - vote: post_vote(where: {user_id: {eq: $user_id}}) { - created_at - __typename - } - author: user { - slug - firstName: first_name - lastName: last_name - pictureURL: picture_url - bio - __typename - } - __typename - } - }` - - _, err := parseQuery(al) - if err != nil { - t.Fatal(err) - } -} - -func TestParse3(t *testing.T) { - var query = ` - mutation createCommentAndProduct @json(schema::"{\n \"data\": {\n \"body\": \"\",\n \"created_at\": \"\",\n \"updated_at\": \"\",\n \"product\": {\n \"connect\": {\n \"id\": 0.0\n }\n }\n }\n}") { - comment(insert:$data) { - id - product { - id - name - } - } - } - ` - - _, err := parseQuery(query) - if err != nil { - t.Fatal(err) - } -} diff --git a/core/internal/allow/deprecate.go b/core/internal/allow/deprecate.go new file mode 100644 index 00000000..bb0ac28d --- /dev/null +++ b/core/internal/allow/deprecate.go @@ -0,0 +1,180 @@ +package allow + +/* +import ( + "path/filepath" + "strconv" + "strings" + "text/scanner" + + "github.com/dosco/graphjin/v2/core/internal/graph" + "github.com/dosco/graphjin/v2/plugin" + "gopkg.in/yaml.v2" +) + +const ( + expComment = iota + 1 + expVar + expQuery + expFrag +) + +func itemFromYaml(fs plugin.FS, filePath string) (Item, error) { + var item Item + + b, err := fs.ReadFile(filePath) + if err != nil { + return item, err + } + + if err := yaml.Unmarshal(b, &item); err != nil { + return item, err + } + + h, err := graph.FastParse(item.Query) + if err != nil { + return item, err + } + item.Operation = h.Operation + + qi, err := parseQuery(item.Query) + if err != nil { + return item, err + } + + for _, f := range qi.frags { + b, err := fs.ReadFile(filepath.Join(fragmentPath, f.Name)) + if err != nil { + return item, err + 
} + item.frags = append(item.frags, Frag{Name: f.Name, Value: string(b)}) + } + + return item, nil +} + +func parseQuery(b string) (Item, error) { + var s scanner.Scanner + s.Init(strings.NewReader(b)) + s.Mode ^= scanner.SkipComments + + var op, sp scanner.Position + var item Item + var err error + + st := expComment + period := 0 + frags := make(map[string]struct{}) + + for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() { + txt := s.TokenText() + + switch { + case strings.HasPrefix(txt, "/*"): + v := b[sp.Offset:s.Pos().Offset] + item, err = setValue(st, v, item) + sp = s.Pos() + + case strings.HasPrefix(txt, "variables"): + v := b[sp.Offset:s.Pos().Offset] + item, err = setValue(st, v, item) + sp = s.Pos() + st = expVar + + case isGraphQL(txt): + v := b[sp.Offset:s.Pos().Offset] + item, err = setValue(st, v, item) + sp = op + st = expQuery + + case strings.HasPrefix(txt, "fragment"): + v := b[sp.Offset:s.Pos().Offset] + item, err = setValue(st, v, item) + sp = op + st = expFrag + + case txt == "@": + exp := []string{"json", "(", "schema", ":"} + if ok := expTokens(&s, exp); !ok { + continue + } + s.Scan() + txt = s.TokenText() + if txt == ":" { + s.Scan() + txt = s.TokenText() + } + if txt == "" { + continue + } + vars, err := strconv.Unquote(txt) + if err != nil { + return item, err + } + item.Vars = strings.TrimSpace(vars) + default: + if period == 3 && txt != "." { + frags[txt] = struct{}{} + } + if period != 3 && txt == "." 
{ + period++ + } else { + period = 0 + } + } + + if err != nil { + return item, err + } + + op = s.Pos() + } + + if st == expQuery || st == expFrag { + v := b[sp.Offset:s.Pos().Offset] + item, err = setValue(st, v, item) + } + + if err != nil { + return item, err + } + + for k := range frags { + item.frags = append(item.frags, Frag{Name: k}) + } + return item, nil +} + +func expTokens(s *scanner.Scanner, exp []string) (ok bool) { + for _, v := range exp { + if tok := s.Scan(); tok == scanner.EOF { + return + } + txt := s.TokenText() + if txt != v { + return + } + } + return true +} + +func setValue(st int, v string, item Item) (Item, error) { + val := func() string { + return strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1]) + } + switch st { + case expVar: + item.Vars = val() + + case expQuery: + item.Query = val() + + case expFrag: + f := Frag{Value: val()} + f.Name = fragmentName(f.Value) + item.frags = append(item.frags, f) + } + + return item, nil +} +*/ diff --git a/core/internal/allow/gql.go b/core/internal/allow/gql.go index 597425d6..25005974 100644 --- a/core/internal/allow/gql.go +++ b/core/internal/allow/gql.go @@ -3,46 +3,49 @@ package allow import ( "bufio" "bytes" + "io" "path/filepath" "regexp" - "strings" "github.com/dosco/graphjin/v2/plugin" ) var incRe = regexp.MustCompile(`(?m)#import \"(.+)\"`) -func parseGQLFile(fs plugin.FS, fname string) (string, error) { - var sb strings.Builder +func readGQL(fs plugin.FS, fname string) ([]byte, error) { + var b bytes.Buffer - if err := parseGQL(fs, fname, &sb); err == plugin.ErrNotFound { - return "", ErrUnknownGraphQLQuery + if err := parseGQL(fs, fname, &b); err == plugin.ErrNotFound { + return nil, ErrUnknownGraphQLQuery } else if err != nil { - return "", err + return nil, err } - return sb.String(), nil + return b.Bytes(), nil } -func parseGQL(fs plugin.FS, fname string, sb *strings.Builder) error { +func parseGQL(fs plugin.FS, fname string, r io.Writer) (err error) { b, err := fs.ReadFile(fname) 
if err != nil { return err } - s := bufio.NewScanner(bytes.NewReader(b)) for s.Scan() { m := incRe.FindStringSubmatch(s.Text()) if len(m) == 0 { - sb.Write(s.Bytes()) + r.Write(s.Bytes()) //nolint: errcheck continue } - fn := filepath.Join(filepath.Dir(fname), m[1]) - if err := parseGQL(fs, fn, sb); err != nil { + incFile := m[1] + if filepath.Ext(incFile) == "" { + incFile += ".gql" + } + + fn := filepath.Join(filepath.Dir(fname), incFile) + if err := parseGQL(fs, fn, r); err != nil { return err } } - - return nil + return } diff --git a/core/internal/allow/util.go b/core/internal/allow/util.go deleted file mode 100644 index d71d730a..00000000 --- a/core/internal/allow/util.go +++ /dev/null @@ -1,36 +0,0 @@ -package allow - -import ( - "strings" -) - -func fragmentName(b string) string { - state, s := 0, 0 - bl := len(b) - - for i := 0; i < bl; i++ { - switch { - case state == 2 && !isValidNameChar(b[i]): - return b[s:i] - case state == 1 && b[i] == '{': - return "" - case state == 1 && isValidNameChar(b[i]): - s = i - state = 2 - case i != 0 && b[i] == ' ' && (b[i-1] == 'n' || b[i-1] == 'y' || b[i-1] == 't'): - state = 1 - } - } - - return "" -} - -func isValidNameChar(c byte) bool { - return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' -} - -func isGraphQL(s string) bool { - return strings.HasPrefix(s, "query") || - strings.HasPrefix(s, "mutation") || - strings.HasPrefix(s, "subscription") -} diff --git a/core/internal/graph/gen_string.go b/core/internal/graph/gen_string.go index bfad4a8a..498391a5 100644 --- a/core/internal/graph/gen_string.go +++ b/core/internal/graph/gen_string.go @@ -21,17 +21,18 @@ func _() { _ = x[itemObjClose-10] _ = x[itemColon-11] _ = x[itemEquals-12] - _ = x[itemDirective-13] - _ = x[itemVariable-14] - _ = x[itemSpread-15] - _ = x[itemNumberVal-16] - _ = x[itemStringVal-17] - _ = x[itemBoolVal-18] + _ = x[itemRequired-13] + _ = x[itemDirective-14] + _ = x[itemVariable-15] + _ = x[itemSpread-16] 
+ _ = x[itemNumberVal-17] + _ = x[itemStringVal-18] + _ = x[itemBoolVal-19] } -const _MType_name = "errorend of filelabel\"on\"punctuation !():=[]{|}()[]{}:=@(directive)$variable...numberstringboolean" +const _MType_name = "errorend of filelabel\"on\"punctuation !()[]{}:=()[]{}:=!@(directive)$variable...numberstringboolean" -var _MType_index = [...]uint8{0, 5, 16, 21, 25, 47, 48, 49, 50, 51, 52, 53, 54, 55, 67, 76, 79, 85, 91, 98} +var _MType_index = [...]uint8{0, 5, 16, 21, 25, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 67, 76, 79, 85, 91, 98} func (i MType) String() string { if i < 0 || i >= MType(len(_MType_index)-1) { diff --git a/core/internal/graph/lex.go b/core/internal/graph/lex.go index 652f390b..bdb481cd 100644 --- a/core/internal/graph/lex.go +++ b/core/internal/graph/lex.go @@ -9,6 +9,7 @@ import ( ) var ( + typeToken = []byte("type") queryToken = []byte("query") mutationToken = []byte("mutation") fragmentToken = []byte("fragment") @@ -45,7 +46,7 @@ const ( itemEOF // end of file itemName // label itemOn // "on" - itemPunctuator // punctuation !():=[]{|} + itemPunctuator // punctuation !()[]{}:= itemArgsOpen // ( itemArgsClose // ) itemListOpen // [ @@ -54,6 +55,7 @@ const ( itemObjClose // } itemColon // : itemEquals // = + itemRequired // ! itemDirective // @(directive) itemVariable // $variable itemSpread // ... 
@@ -62,8 +64,9 @@ const ( itemBoolVal // boolean ) -// !$():=@[]{|} +// !()[]{}:= var punctuators = map[rune]MType{ + '!': itemRequired, '{': itemObjOpen, '}': itemObjClose, '[': itemListOpen, @@ -248,6 +251,9 @@ func lexRoot(l *lexer) stateFn { case r == '@': l.ignore() l.emit(itemDirective) + case r == '!': + l.ignore() + l.emit(itemRequired) case r == '$': l.ignore() if l.acceptAlphaNum() { @@ -257,8 +263,6 @@ func lexRoot(l *lexer) stateFn { case contains(l.current(), punctuatorToken): if item, ok := punctuators[r]; ok { l.emit(item) - } else { - l.emit(itemPunctuator) } case r == '"' || r == '\'': l.backup() diff --git a/core/internal/graph/parse.go b/core/internal/graph/parse.go index cbe92850..bce5f9eb 100644 --- a/core/internal/graph/parse.go +++ b/core/internal/graph/parse.go @@ -51,12 +51,15 @@ type Operation struct { Directives []Directive Fields []Field fieldsA [10]Field + Query []byte + Frags []Fragment } type Fragment struct { Name string On string Fields []Field + Value []byte } type Field struct { @@ -114,10 +117,8 @@ type Parser struct { err error } -func Parse(gql []byte) (Operation, error) { +func Parse(gql []byte) (op Operation, err error) { var l lexer - var op Operation - var err error if len(gql) == 0 { return op, errors.New("empty query") @@ -134,7 +135,7 @@ func Parse(gql []byte) (Operation, error) { } op.Fields = op.fieldsA[:0] - s := -1 + qs := -1 qf := false for { @@ -145,108 +146,125 @@ func Parse(gql []byte) (Operation, error) { if p.peekVal(fragmentToken) { p.ignore() - if _, err := p.parseFragment(); err != nil { - return op, err + if _, err = p.parseFragment(); err != nil { + return } } else { if !qf && (p.peekVal(queryToken, mutationToken, fragmentToken, subscriptionToken) || p.peek(itemObjOpen)) { - s = p.pos + qs = p.pos qf = true } p.ignore() } } - p.reset(s) + p.reset(qs) if op, err = p.parseOp(); err != nil { return op, err } + op.Frags = make([]Fragment, 0, len(p.frags)) + for _, f := range p.frags { + op.Frags = 
append(op.Frags, f) + } + for i, f := range op.Fields { if f.ParentID == -1 && len(f.Args) == 0 && len(f.Children) == 0 { op.Fields[i].Type = FieldKeyword } } - return op, nil + return } -func (p *Parser) parseFragment() (Fragment, error) { - var err error - var frag Fragment +func (p *Parser) parseFragment() (frag Fragment, err error) { + s := p.curr().pos if p.peek(itemName) { frag.Name = p.val(p.next()) } else { - return frag, errors.New("fragment: missing name") + err = errors.New("fragment: missing name") + return } if p.peek(itemOn) { p.ignore() } else { - return frag, errors.New("fragment: missing 'on' keyword") + err = errors.New("fragment: missing 'on' keyword") + return } if p.peek(itemName) { frag.On = p.vall(p.next()) } else { - return frag, errors.New("fragment: missing table name after 'on' keyword") + err = errors.New("fragment: missing table name after 'on' keyword") + return } if p.peek(itemObjOpen) { p.ignore() } else { - return frag, fmt.Errorf("fragment: expecting a '{', got: %s", p.next()) + err = fmt.Errorf("fragment: expecting a '{', got: %s", p.next()) + return } frag.Fields, err = p.parseFields(frag.Fields) if err != nil { - return frag, fmt.Errorf("fragment: %v", err) + err = fmt.Errorf("fragment: %v", err) + return } - if p.frags == nil { - p.frags = make(map[string]Fragment) + if p.peek(itemObjClose) { + p.ignore() } - p.frags[frag.Name] = frag + e := p.curr().pos + 1 + frag.Value = p.input[s:e] + + if p.frags == nil { + p.frags = map[string]Fragment{ + frag.Name: frag, + } + } else { + p.frags[frag.Name] = frag + } return frag, nil } func (p *Parser) parseOp() (Operation, error) { var err error - var typeSet bool var op Operation - if p.peekVal(queryToken, mutationToken, subscriptionToken) { - if err = p.parseOpTypeAndArgs(&op); err != nil { - return op, fmt.Errorf("%s: %v", op.Type, err) - } - typeSet = true + s := p.curr().pos + 1 + + if !p.peekVal(queryToken, mutationToken, subscriptionToken) { + return op, fmt.Errorf("expecting a 
query, mutation or subscription, got: %s", p.peekNext()) } - if p.peek(itemObjOpen) { - p.ignore() - if !typeSet { - op.Type = OpQuery - } + if err = p.parseOpTypeAndArgs(&op); err != nil { + return op, fmt.Errorf("%s: %v", op.Type, err) + } - for { - if p.peek(itemEOF) || p.peekVal(fragmentToken) { - p.ignore() - break - } + if !p.peek(itemObjOpen) { + return op, p.tokErr("{") + } + p.ignore() - op.Fields, err = p.parseFields(op.Fields) - if err != nil { - return op, fmt.Errorf("%s: %v", op.Type, err) - } - } - } else { - return op, fmt.Errorf("expecting a query, mutation or subscription, got: %s", p.peekNext()) + op.Fields, err = p.parseFields(op.Fields) + if err != nil { + return op, fmt.Errorf("%s: %v", op.Type, err) } + + if p.peek(itemObjClose) { + p.ignore() + } + + e := p.curr().pos + 1 + op.Query = p.input[s:e] + return op, nil } @@ -313,7 +331,7 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) { st := NewStack() if !p.peek(itemName, itemSpread) { - return nil, fmt.Errorf("unexpected token: %s", p.peekNext()) + return nil, p.tokErr(`1 field name or ...Fragment`) } for { @@ -323,9 +341,8 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) { } if p.peek(itemObjClose) { - p.ignore() - if st.Len() != 0 { + p.ignore() st.Pop() continue } else { @@ -475,6 +492,8 @@ func (p *Parser) parseFragmentFields(st *Stack, fields []Field) ([]Field, error) func (p *Parser) parseField(f *Field) error { var err error + + // hold onto name to while we check if its an alias v := p.next() if p.peek(itemColon) { @@ -756,6 +775,13 @@ func (p *Parser) peekVal(values ...[]byte) bool { return false } +func (p *Parser) curr() item { + if p.pos == -1 { + return item{} + } + return p.items[p.pos] +} + func (p *Parser) next() item { n := p.pos + 1 if n >= len(p.items) { diff --git a/core/internal/graph/schema.go b/core/internal/graph/schema.go new file mode 100644 index 00000000..cf792935 --- /dev/null +++ b/core/internal/graph/schema.go @@ -0,0 +1,185 @@ 
+package graph + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "strings" +) + +type Schema struct { + Type string + Version string + Schema string + Types []Type +} + +type Type struct { + Name string + Directives []Directive + Fields []TField +} + +type TField struct { + Name string + Type string + Required bool + List bool + Directives []Directive +} + +func ParseSchema(schema []byte) (s Schema, err error) { + var l lexer + + if len(schema) == 0 { + err = errors.New("empty schema") + return + } + + hp := `# dbinfo:` + + r := bufio.NewReader(bytes.NewReader(schema)) + for { + var line []byte + if line, _, err = r.ReadLine(); err != nil { + return + } + h := string(line) + if h != "" && strings.HasPrefix(h, hp) { + v := strings.SplitN(h[len(hp):], ",", 3) + if len(v) >= 1 { + s.Type = v[0] + } + if len(v) >= 2 { + s.Version = v[1] + } + if len(v) >= 3 { + s.Schema = v[2] + } + break + } else if h != "" { + break + } + } + + if l, err = lex(schema); err != nil { + return + } + + p := Parser{ + input: l.input, + pos: -1, + items: l.items, + } + + for { + var t Type + + if p.peek(itemEOF) { + return + } + + if t, err = p.parseType(); err != nil { + return + } + s.Types = append(s.Types, t) + } +} + +func (p *Parser) parseType() (t Type, err error) { + if !p.peekVal(typeToken) { + err = p.tokErr(`type`) + return + } + p.ignore() + + if !p.peek(itemName) { + err = p.tokErr(`type name`) + return + } + t.Name = p.val(p.next()) + + for p.peek(itemDirective) { + p.ignore() + if t.Directives, err = p.parseDirective(t.Directives); err != nil { + return + } + } + + if !p.peek(itemObjOpen) { + err = p.tokErr(`{`) + return + } + p.ignore() + + for { + if p.peek(itemEOF) { + err = p.eofErr(`type ` + t.Name) + return + } + + if p.peek(itemObjClose) { + p.ignore() + return + } + + var f TField + + if !p.peek(itemName) { + err = p.tokErr(`field name`) + return + } + f.Name = p.val(p.next()) + + if !p.peek(itemColon) { + err = p.tokErr(`:`) + return + } + p.ignore() + + if 
p.peek(itemListOpen) { + p.ignore() + f.List = true + } + + if !p.peek(itemName) { + err = p.tokErr(`field type`) + return + } + f.Type = p.val(p.next()) + + if f.List { + if !p.peek(itemListClose) { + err = p.tokErr(`]`) + return + } + p.ignore() + } + + if p.peek(itemRequired) { + p.ignore() + f.Required = true + } + + for p.peek(itemDirective) { + p.ignore() + if f.Directives, err = p.parseDirective(f.Directives); err != nil { + return + } + } + t.Fields = append(t.Fields, f) + } +} + +func (p *Parser) tokErr(exp string) error { + item := p.items[p.pos+1] + return fmt.Errorf("unexpected token '%s', expecting '%s' (line: %d, pos: %d)", + string(item.val), exp, item.line, item.pos) +} + +func (p *Parser) eofErr(tok string) error { + item := p.items[p.pos+1] + return fmt.Errorf("invalid %[1]s: end reached before %[1]s was closed (line: %d, pos: %d)", + tok, item.line, item.pos) +} diff --git a/core/internal/graph/utils.go b/core/internal/graph/utils.go index b763ecc7..50470266 100644 --- a/core/internal/graph/utils.go +++ b/core/internal/graph/utils.go @@ -1,7 +1,9 @@ package graph import ( + "bytes" "errors" + "io" "strings" "text/scanner" ) @@ -15,8 +17,19 @@ func FastParse(gql string) (h FPInfo, err error) { if gql == "" { return h, errors.New("query missing or empty") } + return fastParse(strings.NewReader(gql)) +} + +func FastParseBytes(gql []byte) (h FPInfo, err error) { + if len(gql) == 0 { + return h, errors.New("query missing or empty") + } + return fastParse(bytes.NewReader(gql)) +} + +func fastParse(r io.Reader) (h FPInfo, err error) { var s scanner.Scanner - s.Init(strings.NewReader(gql)) + s.Init(r) s.Whitespace ^= 1 << '\n' // don't skip new lines comment := false diff --git a/core/internal/psql/mutate.go b/core/internal/psql/mutate.go index 6b13374a..5d49a028 100644 --- a/core/internal/psql/mutate.go +++ b/core/internal/psql/mutate.go @@ -523,7 +523,7 @@ func (c *compilerContext) renderMutateToRecordSet(m qcode.Mutate, n int) { if n != 0 { 
c.w.WriteString(`, `) } - if m.IsArray { + if m.Array { c.w.WriteString(`json_to_recordset`) } else { c.w.WriteString(`json_to_record`) diff --git a/core/internal/psql/psql_test.go b/core/internal/psql/psql_test.go index 1e784755..eb311aa1 100644 --- a/core/internal/psql/psql_test.go +++ b/core/internal/psql/psql_test.go @@ -1,6 +1,7 @@ package psql_test import ( + "encoding/json" "errors" "log" "os" @@ -136,19 +137,39 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func compileGQLToPSQL(t *testing.T, gql string, vars qcode.Variables, role string) { - if err := _compileGQLToPSQL(t, gql, vars, role); err != nil { +func compileGQLToPSQL(t *testing.T, gql string, + vars map[string]json.RawMessage, + role string) { + + var v json.RawMessage + var err error + + if v, err = json.Marshal(vars); err != nil { + t.Error(err) + } + + if err := _compileGQLToPSQL(t, gql, v, role); err != nil { t.Error(err) } } -func compileGQLToPSQLExpectErr(t *testing.T, gql string, vars qcode.Variables, role string) { - if err := _compileGQLToPSQL(t, gql, vars, role); err == nil { +func compileGQLToPSQLExpectErr(t *testing.T, gql string, + vars map[string]json.RawMessage, + role string) { + + var v json.RawMessage + var err error + + if v, err = json.Marshal(v); err != nil { + t.Error(err) + } + + if err := _compileGQLToPSQL(t, gql, v, role); err == nil { t.Error(errors.New("we were expecting an error")) } } -func _compileGQLToPSQL(t *testing.T, gql string, vars qcode.Variables, role string) error { +func _compileGQLToPSQL(t *testing.T, gql string, vars json.RawMessage, role string) error { for i := 0; i < 1000; i++ { qc, err := qcompile.Compile([]byte(gql), vars, role, "") if err != nil { diff --git a/core/internal/psql/query.go b/core/internal/psql/query.go index 6129a614..3f3539d1 100644 --- a/core/internal/psql/query.go +++ b/core/internal/psql/query.go @@ -288,7 +288,6 @@ func (c *compilerContext) renderSelect(sel *qcode.Select) { } } } - c.w.WriteString(`AS json `) // We manually 
insert the cursor values into row we're building outside diff --git a/core/internal/psql/query_test.go b/core/internal/psql/query_test.go index b8925695..4a3c3ac5 100644 --- a/core/internal/psql/query_test.go +++ b/core/internal/psql/query_test.go @@ -151,7 +151,7 @@ func withNestedWhere(t *testing.T) { }` vars := map[string]json.RawMessage{ - "email": json.RawMessage(`test@test.com`), + "email": json.RawMessage(`"test@test.com"`), } compileGQLToPSQL(t, gql, vars, "user") diff --git a/core/internal/qcode/exp.go b/core/internal/qcode/exp.go index f87950d0..71012eed 100644 --- a/core/internal/qcode/exp.go +++ b/core/internal/qcode/exp.go @@ -10,7 +10,7 @@ import ( ) func (co *Compiler) compileArgObj(edge string, - ti sdata.DBTable, st *util.StackInf, arg *graph.Arg, selID int32) (*Exp, bool, error) { + ti sdata.DBTable, st *util.StackInf, arg graph.Arg, selID int32) (*Exp, bool, error) { if arg.Val.Type != graph.NodeObj { return nil, false, fmt.Errorf("expecting an object") } @@ -370,10 +370,10 @@ func (ast *aexpst) processOpAndVal(av aexp, ex *Exp, node *graph.Node) (bool, er case "is_null": ex.Op = OpIsNull ex.Right.Val = node.Val - case "null_eq", "ndis", "not_distinct": + case "ndis", "not_distinct": ex.Op = OpNotDistinct ex.Right.Val = node.Val - case "null_neq", "dis", "distinct": + case "dis", "distinct": ex.Op = OpDistinct ex.Right.Val = node.Val default: diff --git a/core/internal/qcode/fields.go b/core/internal/qcode/fields.go index c5776a31..556666db 100644 --- a/core/internal/qcode/fields.go +++ b/core/internal/qcode/fields.go @@ -144,7 +144,7 @@ func (co *Compiler) compileChildColumns( return nil } -func (co *Compiler) compileFuncTableArg(sel *Select, arg *graph.Arg) error { +func (co *Compiler) compileFuncTableArg(sel *Select, arg graph.Arg) error { fn := sel.Ti.Func input, err := fn.GetInput(arg.Name) if err != nil { diff --git a/core/internal/qcode/mutate.go b/core/internal/qcode/mutate.go index ad0fa490..1da91ee7 100644 --- 
a/core/internal/qcode/mutate.go +++ b/core/internal/qcode/mutate.go @@ -211,9 +211,9 @@ func (co *Compiler) compileMutation(qc *QCode, role string) error { } type mData struct { - Data *graph.Node - IsJSON bool - IsArray bool + Data *graph.Node + IsJSON bool + Array bool } func parseDataValue(qc *QCode, actionVal *graph.Node, isJSON bool) (mData, error) { @@ -418,7 +418,7 @@ func (co *Compiler) processNestedMutations(ms *mState, m *Mutate, data *graph.No func (co *Compiler) processList(m Mutate) []Mutate { if m.IsJSON { - m.IsArray = m.Data.Type == graph.NodeList + m.Array = m.Data.Type == graph.NodeList m.Data = m.Data.Children[0] return []Mutate{m} } @@ -427,7 +427,7 @@ func (co *Compiler) processList(m Mutate) []Mutate { for i := range m.Data.Children { m1 := m m1.Data = m.Data.Children[i] - m1.IsArray = m1.Data.Type == graph.NodeList + m1.Array = m1.Data.Type == graph.NodeList m1.ID += int32(i) mList = append(mList, m1) } diff --git a/core/internal/qcode/qcode.go b/core/internal/qcode/qcode.go index d238baf9..ab50f0ce 100644 --- a/core/internal/qcode/qcode.go +++ b/core/internal/qcode/qcode.go @@ -75,6 +75,13 @@ type QCode struct { Script Script Validation Validation Typename bool + Query []byte + Fragments []Fragment +} + +type Fragment struct { + Name string + Value []byte } type Select struct { @@ -341,29 +348,47 @@ func NewCompiler(s *sdata.DBSchema, c Config) (*Compiler, error) { } func (co *Compiler) Compile( - query []byte, vars Variables, role, namespace string) (*QCode, error) { - var err error + query []byte, vars json.RawMessage, role, namespace string) (qc *QCode, err error) { - op, err := graph.Parse(query) + var op graph.Operation + op, err = graph.Parse(query) if err != nil { - return nil, err + return + } + + qc = &QCode{ + Name: op.Name, + SType: QTQuery, + Schema: co.s, + Query: op.Query, + Fragments: make([]Fragment, len(op.Frags)), + } + + if len(vars) != 0 { + qc.Vars = make(map[string]json.RawMessage) + + if err := json.Unmarshal(vars, 
&qc.Vars); err != nil { + return nil, fmt.Errorf("variables: %w", err) + } + } + + for i, f := range op.Frags { + qc.Fragments[i] = Fragment{Name: f.Name, Value: f.Value} } - qc := QCode{Name: op.Name, SType: QTQuery, Schema: co.s, Vars: vars} qc.Roots = qc.rootsA[:0] qc.Type = GetQType(op.Type) - if err := co.compileQuery(&qc, &op, role); err != nil { - return nil, err + if err = co.compileQuery(qc, &op, role); err != nil { + return } if qc.Type == QTMutation { - if err := co.compileMutation(&qc, role); err != nil { - return nil, err + if err = co.compileMutation(qc, role); err != nil { + return } } - - return &qc, nil + return } func (co *Compiler) compileQuery(qc *QCode, op *graph.Operation, role string) error { @@ -624,10 +649,10 @@ func (co *Compiler) setRelFilters(qc *QCode, sel *Select) { switch rel.Type { case sdata.RelOneToOne, sdata.RelOneToMany: - setFilter(&sel.Where, buildFilter(rel, pid)) + addAndFilter(&sel.Where, buildFilter(rel, pid)) case sdata.RelEmbedded: - setFilter(&sel.Where, buildFilter(rel, pid)) + addAndFilter(&sel.Where, buildFilter(rel, pid)) case sdata.RelPolymorphic: pid = qc.Selects[sel.ParentID].ParentID @@ -647,7 +672,7 @@ func (co *Compiler) setRelFilters(qc *QCode, sel *Select) { ex2.Right.Val = sel.Ti.Name ex.Children = []*Exp{ex1, ex2} - setFilter(&sel.Where, ex) + addAndFilter(&sel.Where, ex) case sdata.RelRecursive: rcte := "__rcte_" + rel.Right.Ti.Name @@ -735,7 +760,7 @@ func (co *Compiler) setRelFilters(qc *QCode, sel *Select) { } ex.Children = []*Exp{ex1, ex2, ex3} - setFilter(&sel.Where, ex) + addAndFilter(&sel.Where, ex) } } @@ -891,8 +916,7 @@ func (co *Compiler) addSeekPredicate(sel *Select) { or.Children = append(or.Children, and) } } - - setFilter(&sel.Where, or) + addAndFilter(&sel.Where, or) } func addFilters(qc *QCode, where *Filter, trv trval) bool { @@ -902,7 +926,7 @@ func addFilters(qc *QCode, where *Filter, trv trval) bool { case OpFalse: where.Exp = fil default: - setFilter(where, fil) + addAndFilter(where, 
fil) } return userNeeded } @@ -1024,9 +1048,7 @@ func (co *Compiler) compileSelectorDirectives2(qc *QCode, sel *Select, dirs []gr func (co *Compiler) compileArgs(sel *Select, args []graph.Arg, role string) error { var err error - for i := range args { - arg := &args[i] - + for _, arg := range args { switch arg.Name { case "id": err = co.compileArgID(sel, arg) @@ -1142,17 +1164,11 @@ func (co *Compiler) setMutationType(qc *QCode, op *graph.Operation, role string) } func (co *Compiler) compileDirectiveSchema(sel *Select, d *graph.Directive) error { - if len(d.Args) == 0 { - return fmt.Errorf("required argument 'name' missing") + arg, err := getArg(d.Args, "name", []graph.ParserType{graph.NodeStr}, true) + if err == nil { + sel.Schema = arg.Val.Val } - arg := d.Args[0] - - if ifNotArg(arg, graph.NodeStr) { - return argTypeErr("string") - } - - sel.Schema = arg.Val.Val - return nil + return err } func (co *Compiler) compileSelectDirectiveSkipInclude(skip bool, sel *Select, d *graph.Directive, role string) (err error) { @@ -1230,7 +1246,7 @@ func (co *Compiler) compileSkipIncludeFilter( selID int32, fil *Filter, arg graph.Arg, - role string) error { + role string) (err error) { if ifArg(arg, graph.NodeVar) { var ex *Exp @@ -1241,15 +1257,22 @@ func (co *Compiler) compileSkipIncludeFilter( } ex.Right.ValType = ValVar ex.Right.Val = arg.Val.Val - setFilter(fil, ex) - return nil + addAndFilter(fil, ex) + return } if ifArg(arg, graph.NodeObj) { + var ex *Exp + ex, err = co.compileFilter(sel, selID, arg, role) + if err != nil { + return + } if skip { - setFilter(fil, newExpOp(OpNot)) + addNotFilter(fil, ex) + } else { + addAndFilter(fil, ex) } - return co.compileAndSetFilter(sel, selID, fil, &arg, role) + return } return argErr("if", "variable or filter expression") } @@ -1508,7 +1531,12 @@ func (co *Compiler) compileDirectiveValidation(qc *QCode, d *graph.Directive) er return nil } -func (co *Compiler) compileArgFind(sel *Select, arg *graph.Arg) error { +func (co *Compiler) 
compileArgFind(sel *Select, arg graph.Arg) error { + err := validateArg(arg, []graph.ParserType{graph.NodeStr}) + if err != nil { + return err + } + // Only allow on recursive relationship selectors if sel.Rel.Type != sdata.RelRecursive { return fmt.Errorf("selector '%s' is not recursive", sel.FieldName) @@ -1520,18 +1548,16 @@ func (co *Compiler) compileArgFind(sel *Select, arg *graph.Arg) error { return nil } -func (co *Compiler) compileArgID(sel *Select, arg *graph.Arg) error { - node := arg.Val - +func (co *Compiler) compileArgID(sel *Select, arg graph.Arg) error { if sel.ParentID != -1 { return fmt.Errorf("can only be specified at the query root") } - if node.Type != graph.NodeNum && - node.Type != graph.NodeStr && - node.Type != graph.NodeVar { - return argTypeErr("number, string or variable") + err := validateArg(arg, []graph.ParserType{graph.NodeNum, graph.NodeStr, graph.NodeVar}) + if err != nil { + return err } + node := arg.Val if sel.Ti.PrimaryCol.Name == "" { return fmt.Errorf("no primary key column defined for '%s'", sel.Table) @@ -1563,7 +1589,7 @@ func (co *Compiler) compileArgID(sel *Select, arg *graph.Arg) error { return nil } -func (co *Compiler) compileArgSearch(sel *Select, arg *graph.Arg) error { +func (co *Compiler) compileArgSearch(sel *Select, arg graph.Arg) error { if len(sel.Ti.FullText) == 0 { switch co.s.DBType() { case "mysql": @@ -1573,8 +1599,9 @@ func (co *Compiler) compileArgSearch(sel *Select, arg *graph.Arg) error { } } - if arg.Val.Type != graph.NodeVar { - return argTypeErr("variable") + err := validateArg(arg, []graph.ParserType{graph.NodeStr, graph.NodeVar}) + if err != nil { + return err } ex := newExpOp(OpTsQuery) @@ -1582,22 +1609,32 @@ func (co *Compiler) compileArgSearch(sel *Select, arg *graph.Arg) error { ex.Right.Val = arg.Val.Val sel.addIArg(Arg{Name: arg.Name, Val: arg.Val.Val}) - setFilter(&sel.Where, ex) + addAndFilter(&sel.Where, ex) return nil } -func (co *Compiler) compileArgWhere(sel *Select, arg *graph.Arg, 
role string) error { - return co.compileAndSetFilter(sel, -1, &sel.Where, arg, role) -} +func (co *Compiler) compileArgWhere(sel *Select, arg graph.Arg, role string) (err error) { + err = validateArg(arg, []graph.ParserType{graph.NodeObj}) + if err != nil { + return + } -func (co *Compiler) compileArgOrderBy(sel *Select, arg *graph.Arg) error { - node := arg.Val + var ex *Exp + ex, err = co.compileFilter(sel, -1, arg, role) + if err != nil { + return + } + addAndFilter(&sel.Where, ex) + return +} - if node.Type != graph.NodeObj && - node.Type != graph.NodeVar { - return argTypeErr("object or variable") +func (co *Compiler) compileArgOrderBy(sel *Select, arg graph.Arg) error { + err := validateArg(arg, []graph.ParserType{graph.NodeObj, graph.NodeVar}) + if err != nil { + return err } + node := arg.Val cm := make(map[string]struct{}) for _, ob := range sel.OrderBy { @@ -1761,22 +1798,24 @@ func compileOrderBy(sel *Select, return nil } -func (co *Compiler) compileArgArgs(sel *Select, arg *graph.Arg) error { +func (co *Compiler) compileArgArgs(sel *Select, arg graph.Arg) error { if sel.Ti.Type != "function" { return fmt.Errorf("'%s' is not a db function", sel.Ti.Name) } + err := validateArg(arg, []graph.ParserType{graph.NodeList}) + if err != nil { + return err + } + fn := sel.Ti.Func + if len(fn.Inputs) == 0 { return fmt.Errorf("db function '%s' does not have any arguments", sel.Ti.Name) } node := arg.Val - if node.Type != graph.NodeList { - return argErr("args", "list") - } - for i, n := range node.Children { var err error a := Arg{DType: fn.Inputs[i].Type} @@ -1819,12 +1858,12 @@ func toOrder(val string) (Order, error) { } } -func (co *Compiler) compileArgDistinctOn(sel *Select, arg *graph.Arg) error { - node := arg.Val - - if node.Type != graph.NodeList && node.Type != graph.NodeStr { - return fmt.Errorf("expecting a list of strings or just a string") +func (co *Compiler) compileArgDistinctOn(sel *Select, arg graph.Arg) error { + err := validateArg(arg, 
[]graph.ParserType{graph.NodeList, graph.NodeStr}) + if err != nil { + return err } + node := arg.Val if node.Type == graph.NodeStr { if col, err := sel.Ti.GetColumn(node.Val); err == nil { @@ -1855,12 +1894,12 @@ func (co *Compiler) compileArgDistinctOn(sel *Select, arg *graph.Arg) error { return nil } -func (co *Compiler) compileArgLimit(sel *Select, arg *graph.Arg) error { - node := arg.Val - - if node.Type != graph.NodeNum && node.Type != graph.NodeVar { - return argTypeErr("number or variable") +func (co *Compiler) compileArgLimit(sel *Select, arg graph.Arg) error { + err := validateArg(arg, []graph.ParserType{graph.NodeNum, graph.NodeVar}) + if err != nil { + return err } + node := arg.Val switch node.Type { case graph.NodeNum: @@ -1879,12 +1918,12 @@ func (co *Compiler) compileArgLimit(sel *Select, arg *graph.Arg) error { return nil } -func (co *Compiler) compileArgOffset(sel *Select, arg *graph.Arg) error { - node := arg.Val - - if node.Type != graph.NodeNum && node.Type != graph.NodeVar { - return argTypeErr("number or variable") +func (co *Compiler) compileArgOffset(sel *Select, arg graph.Arg) error { + err := validateArg(arg, []graph.ParserType{graph.NodeNum, graph.NodeVar}) + if err != nil { + return err } + node := arg.Val switch node.Type { case graph.NodeNum: @@ -1903,7 +1942,7 @@ func (co *Compiler) compileArgOffset(sel *Select, arg *graph.Arg) error { return nil } -func (co *Compiler) compileArgFirstLast(sel *Select, arg *graph.Arg, order Order) error { +func (co *Compiler) compileArgFirstLast(sel *Select, arg graph.Arg, order Order) error { if err := co.compileArgLimit(sel, arg); err != nil { return err } @@ -1916,10 +1955,14 @@ func (co *Compiler) compileArgFirstLast(sel *Select, arg *graph.Arg, order Order return nil } -func (co *Compiler) compileArgAfterBefore(sel *Select, arg *graph.Arg, pt PagingType) error { +func (co *Compiler) compileArgAfterBefore(sel *Select, arg graph.Arg, pt PagingType) error { + err := validateArg(arg, 
[]graph.ParserType{graph.NodeVar}) + if err != nil { + return err + } node := arg.Val - if node.Type != graph.NodeVar || node.Val != "cursor" { + if node.Val != "cursor" { return fmt.Errorf("value for argument '%s' must be a variable named $cursor", arg.Name) } sel.Paging.Type = pt @@ -1947,21 +1990,22 @@ func (co *Compiler) setOrderByColName(ti sdata.DBTable, ob *OrderBy, node *graph return nil } -func (co *Compiler) compileAndSetFilter(sel *Select, selID int32, fil *Filter, arg *graph.Arg, role string) error { +func (co *Compiler) compileFilter(sel *Select, selID int32, arg graph.Arg, role string) (ex *Exp, err error) { st := util.NewStackInf() - ex, nu, err := co.compileArgObj(sel.Table, sel.Ti, st, arg, selID) + var nu bool + + ex, nu, err = co.compileArgObj(sel.Table, sel.Ti, st, arg, selID) if err != nil { - return err + return } if nu && role == "anon" { sel.SkipRender = SkipTypeUserNeeded } - setFilter(fil, ex) - return nil + return } -func setFilter(fil *Filter, ex *Exp) { +func addAndFilter(fil *Filter, ex *Exp) { if fil.Exp == nil { fil.Exp = ex return @@ -1971,14 +2015,30 @@ func setFilter(fil *Filter, ex *Exp) { // add a new `and` exp and hook the above saved exp pointer a child // we don't want to modify an exp object thats common (from filter config) - if ow.Op != OpAnd && ow.Op != OpOr && ow.Op != OpNot { - fil.Exp = newExpOp(OpAnd) - fil.Exp.Children = fil.Exp.childrenA[:2] - fil.Exp.Children[0] = ex - fil.Exp.Children[1] = ow - } else { - fil.Exp.Children = append(fil.Exp.Children, ex) + fil.Exp = newExpOp(OpAnd) + fil.Exp.Children = fil.Exp.childrenA[:2] + fil.Exp.Children[0] = ex + fil.Exp.Children[1] = ow +} + +func addNotFilter(fil *Filter, ex *Exp) { + ex1 := newExpOp(OpNot) + ex1.Children = ex1.childrenA[:1] + ex1.Children[0] = ex + + if fil.Exp == nil { + fil.Exp = ex1 + return } + // save exiting exp pointer (could be a common one from filter config) + ow := fil.Exp + + // add a new `and` exp and hook the above saved exp pointer a child + 
// we don't want to modify an exp object thats common (from filter config) + fil.Exp = newExpOp(OpAnd) + fil.Exp.Children = fil.Exp.childrenA[:2] + fil.Exp.Children[0] = ex1 + fil.Exp.Children[1] = ow } func compileFilter(s *sdata.DBSchema, ti sdata.DBTable, filter []string, isJSON bool) (*Exp, bool, error) { @@ -2053,6 +2113,36 @@ func compileFilter(s *sdata.DBSchema, ti sdata.DBTable, filter []string, isJSON // return b.String() // } +func getArg(args []graph.Arg, name string, validTypes []graph.ParserType, + required bool) (arg graph.Arg, err error) { + for _, a := range args { + if a.Name != name { + continue + } + if err = validateArg(a, validTypes); err != nil { + return + } + return a, nil + } + if required { + err = fmt.Errorf("required argument '%s' missing", name) + } + return +} + +func validateArg(arg graph.Arg, validTypes []graph.ParserType) (err error) { + for _, vt := range validTypes { + if arg.Val.Type == vt { + return + } + } + return argErr(arg.Name, argTypes(validTypes)) +} + +func argExists(arg graph.Arg) bool { + return arg.Val != nil +} + func ifArgList(arg graph.Arg, lty graph.ParserType) bool { return arg.Val.Type == graph.NodeList && len(arg.Val.Children) != 0 && @@ -2067,20 +2157,25 @@ func ifNotArg(arg graph.Arg, ty graph.ParserType) bool { return arg.Val.Type != ty } -// func ifArgVal(arg graph.Arg, val string) bool { -// return arg.Val.Val == val -// } - func ifNotArgVal(arg graph.Arg, val string) bool { return arg.Val.Val != val } -func argErr(name, ty string) error { - return fmt.Errorf("value for argument '%s' must be a %s", name, ty) +func argTypes(types []graph.ParserType) string { + var sb strings.Builder + lastIndex := len(types) - 1 + for i, t := range types { + if i == lastIndex { + sb.WriteString(" or " + t.String()) + } else if i != 0 { + sb.WriteString(", " + t.String()) + } + } + return sb.String() } -func argTypeErr(ty string) error { - return fmt.Errorf("value must be a %s", ty) +func argErr(name, ty string) error { + 
return fmt.Errorf("value for argument '%s' must be a %s", name, ty) } func dbArgErr(name, ty, db string) error { @@ -2108,55 +2203,3 @@ func (s *Script) HasReqFn() bool { func (s *Script) HasRespFn() bool { return s.SC.HasResponseFn() } - -/* -func (qc *QCode) getVar(name string, vt ValType) (string, error) { - val, ok := qc.Vars[name] - if !ok { - return "", fmt.Errorf("variable '%s' not defined", name) - } - k := string(val) - if k == "null" { - return "", nil - } - switch vt { - case ValStr: - if k != "" && k[0] == '"' { - return k[1:(len(k) - 1)], nil - } - case ValNum: - if k != "" && ((k[0] >= '0' && k[0] <= '9') || k[0] == '-') { - return k, nil - } - case ValBool: - if strings.EqualFold(k, "true") || strings.EqualFold(k, "false") { - return k, nil - } - case ValList: - if k != "" && k[0] == '[' { - return k, nil - } - case ValObj: - if k != "" && k[0] == '{' { - return k, nil - } - } - - var vts string - switch vt { - case ValStr: - vts = "string" - case ValNum: - vts = "number" - case ValBool: - vts = "boolean" - case ValList: - vts = "list" - case ValObj: - vts = "object" - } - - return "", fmt.Errorf("variable '%s' must be a %s and not '%s'", - name, vts, k) -} -*/ diff --git a/core/internal/qcode/qcode_test.go b/core/internal/qcode/qcode_test.go index c91c037f..4275df96 100644 --- a/core/internal/qcode/qcode_test.go +++ b/core/internal/qcode/qcode_test.go @@ -78,9 +78,8 @@ func TestCompile3(t *testing.T) { return } - vars := map[string]json.RawMessage{ - "data": json.RawMessage(` { "name": "my_name", "description": "my_desc" }`), - } + vars := json.RawMessage(` + { "data": { "name": "my_name", "description": "my_desc" } }`) _, err = qc.Compile([]byte(` mutation { @@ -102,10 +101,10 @@ func TestCompile4(t *testing.T) { } }` - vars := map[string]json.RawMessage{ - "email": json.RawMessage(`"reannagreenholt@orn.com"`), - "full_name": json.RawMessage(`"Flo Barton"`), - } + vars := json.RawMessage(`{ + "email": "reannagreenholt@orn.com", + "full_name": "Flo 
Barton" + }`) qc, _ := qcode.NewCompiler(dbs, qcode.Config{}) _, err := qc.Compile([]byte(gql), vars, "user", "") diff --git a/core/internal/qcode/schema.go b/core/internal/qcode/schema.go new file mode 100644 index 00000000..ec029d57 --- /dev/null +++ b/core/internal/qcode/schema.go @@ -0,0 +1,235 @@ +package qcode + +import ( + "fmt" + "strconv" + "strings" + "unicode" + + "github.com/dosco/graphjin/v2/core/internal/graph" + "github.com/dosco/graphjin/v2/core/internal/sdata" +) + +type Schema struct { + Type string + Version int + Schema string + Columns []sdata.DBColumn + Functions []sdata.DBFunction +} + +func ParseSchema(b []byte) (ds Schema, err error) { + var s graph.Schema + s, err = graph.ParseSchema(b) + if err != nil { + return + } + ds.Type = s.Type + ds.Schema = s.Schema + + if v, err1 := strconv.Atoi(s.Version); err == nil { + ds.Version = v + } else if s.Version != "" && err1 != nil { + err = err1 + return + } + + for _, t := range s.Types { + var ti typeInfo + + ti, err = parseTypeDirectives(t.Directives) + if err != nil { + err = fmt.Errorf("%s: %w", t.Name, err) + return + } + if ti.Schema == "" { + ti.Schema = s.Schema + } + + if ti.ReturnType != "" { + df := sdata.DBFunction{ + Schema: ti.Schema, + Name: t.Name, + Type: ti.ReturnType, + } + if err = parseTFieldsFunction(&df, t.Fields); err != nil { + break + } + ds.Functions = append(ds.Functions, df) + + } else { + var cols []sdata.DBColumn + cols, err = parseTFieldsColumns(ti.Schema, t.Name, t.Fields) + if err != nil { + break + } + ds.Columns = append(ds.Columns, cols...) 
+ } + if err != nil { + err = fmt.Errorf("%s: %w", t.Name, err) + } + } + return +} + +func parseTFieldsColumns(tableSchema, tableName string, fields []graph.TField) ( + cols []sdata.DBColumn, err error) { + var dir tfieldInfo + for i, f := range fields { + dir, err = parseTFieldDirectives(f.Type, f.Directives) + if err != nil { + return + } + col := sdata.DBColumn{ + ID: int32(i), + Schema: tableSchema, + Table: tableName, + Name: f.Name, + Type: pascalToSnakeSpace(f.Type), + Array: f.List, + NotNull: f.Required, + PrimaryKey: dir.ID, + UniqueKey: dir.Unique, + Blocked: dir.Blocked, + FKeySchema: dir.RelatedSchema, + FKeyTable: dir.RelatedType, + FKeyCol: dir.RelatedField, + } + cols = append(cols, col) + } + return +} + +func parseTFieldsFunction(fn *sdata.DBFunction, fields []graph.TField) ( + err error) { + for i, f := range fields { + var dir tfieldInfo + dir, err = parseTFieldDirectives(f.Type, f.Directives) + if err != nil { + return + } + p := sdata.DBFuncParam{ + ID: i, + Name: f.Name, + Type: pascalToSnakeSpace(f.Type), + } + switch { + case dir.Input: + fn.Inputs = append(fn.Inputs, p) + case dir.Output: + fn.Outputs = append(fn.Outputs, p) + default: + err = fmt.Errorf("%s: @input or @output directive required", p.Name) + return + } + } + return +} + +type typeInfo struct { + Schema string + ReturnType string +} + +func parseTypeDirectives(dir []graph.Directive) (ti typeInfo, err error) { + for _, d := range dir { + var arg graph.Arg + switch d.Name { + case "schema": + arg, err = getArg(d.Args, "name", + []graph.ParserType{graph.NodeStr, graph.NodeLabel}, true) + if err != nil { + break + } + ti.Schema = arg.Val.Val + + case "function": + arg, err = getArg(d.Args, "return_type", + []graph.ParserType{graph.NodeStr, graph.NodeLabel}, true) + if err != nil { + break + } + ti.ReturnType = arg.Val.Val + } + if err != nil { + err = fmt.Errorf("type: %w", err) + return + } + } + return +} + +type tfieldInfo struct { + ID bool + Unique bool + Blocked bool + 
RelatedType string + RelatedField string + RelatedSchema string + Input bool + Output bool +} + +func parseTFieldDirectives(ft string, dir []graph.Directive) (tfi tfieldInfo, err error) { + for _, d := range dir { + var arg graph.Arg + switch d.Name { + case "id": + tfi.ID = true + + case "unique": + tfi.Unique = true + + case "blocked": + tfi.Blocked = true + + case "relation": + arg, err = getArg(d.Args, "type", + []graph.ParserType{graph.NodeStr, graph.NodeLabel}, true) + if err != nil { + break + } + tfi.RelatedType = arg.Val.Val + + arg, err = getArg(d.Args, "field", + []graph.ParserType{graph.NodeStr, graph.NodeLabel}, + (ft != "Json")) + if err != nil { + break + } + if argExists(arg) { + tfi.RelatedField = arg.Val.Val + } + + arg, err = getArg(d.Args, "schema", + []graph.ParserType{graph.NodeStr, graph.NodeLabel}, false) + if err != nil { + break + } + if argExists(arg) { + tfi.RelatedSchema = arg.Val.Val + } + case "input": + tfi.Input = true + + case "output": + tfi.Output = true + } + if err != nil { + err = fmt.Errorf("type field: %w", err) + return + } + } + return +} + +func pascalToSnakeSpace(s string) string { + var result string + for i, r := range s { + if i > 0 && unicode.IsUpper(r) { + result += " " + } + result += strings.ToLower(string(r)) + } + return result +} diff --git a/core/internal/sdata/sql/mysql_columns.sql b/core/internal/sdata/sql/mysql_columns.sql index 7a096cd9..a5933743 100644 --- a/core/internal/sdata/sql/mysql_columns.sql +++ b/core/internal/sdata/sql/mysql_columns.sql @@ -1,67 +1,88 @@ -SELECT - col.table_schema as "schema", +SELECT col.table_schema as "schema", col.table_name as "table", col.column_name as "column", col.data_type as "type", - (CASE - WHEN col.is_nullable = 'YES' THEN TRUE - ELSE FALSE - END) AS not_null, + ( + CASE + WHEN col.is_nullable = 'YES' THEN TRUE + ELSE FALSE + END + ) AS not_null, false AS primary_key, false AS unique_key, - (CASE - WHEN col.data_type = 'ARRAY' THEN TRUE - ELSE FALSE - END) AS 
is_array, - (CASE - WHEN stat.index_type = 'FULLTEXT' THEN TRUE - ELSE FALSE - END) AS full_text, + ( + CASE + WHEN col.data_type = 'ARRAY' THEN TRUE + ELSE FALSE + END + ) AS is_array, + ( + CASE + WHEN stat.index_type = 'FULLTEXT' THEN TRUE + ELSE FALSE + END + ) AS full_text, '' AS foreignkey_schema, '' AS foreignkey_table, '' AS foreignkey_column -FROM - information_schema.columns col -LEFT JOIN information_schema.statistics stat ON col.table_schema = stat.table_schema +FROM information_schema.columns col + LEFT JOIN information_schema.statistics stat ON col.table_schema = stat.table_schema AND col.table_name = stat.table_name - AND col.column_name = stat.column_name - AND stat.index_type = 'FULLTEXT' -WHERE - col.table_schema NOT IN ('_graphjin', 'information_schema', 'performance_schema', 'mysql', 'sys') -UNION -SELECT - kcu.table_schema as "schema", + AND col.column_name = stat.column_name + AND stat.index_type = 'FULLTEXT' +WHERE col.table_schema NOT IN ( + '_graphjin', + 'information_schema', + 'performance_schema', + 'mysql', + 'sys' + ) +UNION +SELECT kcu.table_schema as "schema", kcu.table_name as "table", kcu.column_name as "column", '' as "type", false AS not_null, - (CASE - WHEN tc.constraint_type = 'PRIMARY KEY' THEN TRUE - ELSE FALSE - END) AS primary_key, - (CASE - WHEN tc.constraint_type = 'UNIQUE' THEN TRUE - ELSE FALSE - END) AS unique_key, + ( + CASE + WHEN tc.constraint_type = 'PRIMARY KEY' THEN TRUE + ELSE FALSE + END + ) AS primary_key, + ( + CASE + WHEN tc.constraint_type = 'UNIQUE' THEN TRUE + ELSE FALSE + END + ) AS unique_key, false AS is_array, false AS full_text, - (CASE - WHEN tc.constraint_type = 'FOREIGN KEY' THEN kcu.referenced_table_schema - ELSE '' - END) AS foreignkey_schema, - (CASE - WHEN tc.constraint_type = 'FOREIGN KEY' THEN kcu.referenced_table_name - ELSE '' - END) AS foreignkey_table, - (CASE - WHEN tc.constraint_type = 'FOREIGN KEY' THEN kcu.referenced_column_name - ELSE '' - END) AS foreignkey_column -FROM - 
information_schema.key_column_usage kcu -JOIN - information_schema.table_constraints tc ON kcu.table_schema = tc.table_schema + ( + CASE + WHEN tc.constraint_type = 'FOREIGN KEY' THEN kcu.referenced_table_schema + ELSE '' + END + ) AS foreignkey_schema, + ( + CASE + WHEN tc.constraint_type = 'FOREIGN KEY' THEN kcu.referenced_table_name + ELSE '' + END + ) AS foreignkey_table, + ( + CASE + WHEN tc.constraint_type = 'FOREIGN KEY' THEN kcu.referenced_column_name + ELSE '' + END + ) AS foreignkey_column +FROM information_schema.key_column_usage kcu + JOIN information_schema.table_constraints tc ON kcu.table_schema = tc.table_schema AND kcu.table_name = tc.table_name - AND kcu.constraint_name = tc.constraint_name -WHERE - kcu.constraint_schema NOT IN ('_graphjin', 'information_schema', 'performance_schema', 'mysql', 'sys'); \ No newline at end of file + AND kcu.constraint_name = tc.constraint_name +WHERE kcu.constraint_schema NOT IN ( + '_graphjin', + 'information_schema', + 'performance_schema', + 'mysql', + 'sys' + ); \ No newline at end of file diff --git a/core/internal/sdata/sql/mysql_functions.sql b/core/internal/sdata/sql/mysql_functions.sql index f0b592f1..802cb65a 100644 --- a/core/internal/sdata/sql/mysql_functions.sql +++ b/core/internal/sdata/sql/mysql_functions.sql @@ -1,17 +1,27 @@ -SELECT - r.specific_name as func_id, - r.routine_schema as func_schema, - r.routine_name as func_name, - (CASE WHEN r.data_type = 'USER-DEFINED' THEN 'record' ELSE r.data_type END) as data_type, - p.ordinal_position as param_id, +SELECT r.specific_name as func_id, + r.routine_schema as func_schema, + r.routine_name as func_name, + ( + CASE + WHEN r.data_type = 'USER-DEFINED' THEN 'record' + ELSE r.data_type + END + ) as data_type, + p.ordinal_position as param_id, COALESCE(p.parameter_name, '') as param_name, - p.data_type as param_type, - COALESCE(p.parameter_mode, '') as param_kind -FROM - information_schema.routines r -RIGHT JOIN - information_schema.parameters p - ON 
(r.specific_name = p.specific_name AND r.specific_name = p.specific_name) -WHERE - r.routine_type = 'FUNCTION' - AND p.specific_schema NOT IN ('_graphjin', 'information_schema', 'performance_schema', 'mysql', 'sys'); \ No newline at end of file + p.data_type as param_type, + COALESCE(p.parameter_mode, '') as param_kind +FROM information_schema.routines r + RIGHT JOIN information_schema.parameters p ON ( + r.specific_name = p.specific_name + AND r.specific_name = p.specific_name + ) +WHERE r.routine_type = 'FUNCTION' + AND r.data_type != 'void' + AND p.specific_schema NOT IN ( + '_graphjin', + 'information_schema', + 'performance_schema', + 'mysql', + 'sys' + ); \ No newline at end of file diff --git a/core/internal/sdata/sql/postgres_columns.sql b/core/internal/sdata/sql/postgres_columns.sql index 3908229d..766306cd 100644 --- a/core/internal/sdata/sql/postgres_columns.sql +++ b/core/internal/sdata/sql/postgres_columns.sql @@ -1,52 +1,79 @@ -SELECT - n.nspname as "schema", +SELECT n.nspname as "schema", c.relname as "table", - f.attname AS "column", - pg_catalog.format_type(f.atttypid,f.atttypmod) AS "type", - f.attnotnull AS not_null, - (CASE - WHEN co.contype = ('p'::char) THEN true - ELSE false - END) AS primary_key, - (CASE - WHEN co.contype = ('u'::char) THEN true - ELSE false - END) AS unique_key, - (CASE - WHEN f.attndims != 0 THEN true - WHEN right(pg_catalog.format_type(f.atttypid,f.atttypmod), 2) = '[]' THEN true - ELSE false - END) AS is_array, - (CASE - WHEN pg_catalog.format_type(f.atttypid,f.atttypmod) = 'tsvector' THEN TRUE - ELSE FALSE - END) AS full_text, - (CASE - WHEN co.contype = ('f'::char) - THEN (SELECT n.nspname FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE c.oid = co.confrelid) - ELSE ''::text - END) AS foreignkey_schema, - (CASE - WHEN co.contype = ('f'::char) - THEN (SELECT relname FROM pg_class WHERE oid = co.confrelid) - ELSE ''::text - END) AS foreignkey_table, - (CASE - WHEN co.contype = ('f'::char) - THEN 
(SELECT f.attname FROM pg_attribute f WHERE f.attnum = co.confkey[1] and f.attrelid = co.confrelid) - ELSE ''::text - END) AS foreignkey_column -FROM - pg_attribute f - JOIN pg_class c ON c.oid = f.attrelid - LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum - LEFT JOIN pg_namespace n ON n.oid = c.relnamespace - LEFT JOIN pg_constraint co ON co.conrelid = c.oid AND f.attnum = ANY (co.conkey) -WHERE - c.relkind IN ('r', 'v', 'm', 'f', 'p') - AND n.nspname NOT IN ('_graphjin', 'information_schema', 'pg_catalog') + f.attname AS "column", + pg_catalog.format_type(f.atttypid, f.atttypmod) AS "type", + f.attnotnull AS not_null, + ( + CASE + WHEN co.contype = ('p'::char) THEN true + ELSE false + END + ) AS primary_key, + ( + CASE + WHEN co.contype = ('u'::char) THEN true + ELSE false + END + ) AS unique_key, + ( + CASE + WHEN f.attndims != 0 THEN true + WHEN right( + pg_catalog.format_type(f.atttypid, f.atttypmod), + 2 + ) = '[]' THEN true + ELSE false + END + ) AS is_array, + ( + CASE + WHEN pg_catalog.format_type(f.atttypid, f.atttypmod) = 'tsvector' THEN TRUE + ELSE FALSE + END + ) AS full_text, + ( + CASE + WHEN co.contype = ('f'::char) THEN ( + SELECT n.nspname + FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE c.oid = co.confrelid + ) + ELSE ''::text + END + ) AS foreignkey_schema, + ( + CASE + WHEN co.contype = ('f'::char) THEN ( + SELECT relname + FROM pg_class + WHERE oid = co.confrelid + ) + ELSE ''::text + END + ) AS foreignkey_table, + ( + CASE + WHEN co.contype = ('f'::char) THEN ( + SELECT f.attname + FROM pg_attribute f + WHERE f.attnum = co.confkey [1] + and f.attrelid = co.confrelid + ) + ELSE ''::text + END + ) AS foreignkey_column +FROM pg_attribute f + JOIN pg_class c ON c.oid = f.attrelid + LEFT JOIN pg_attrdef d ON d.adrelid = c.oid + AND d.adnum = f.attnum + LEFT JOIN pg_namespace n ON n.oid = c.relnamespace + LEFT JOIN pg_constraint co ON co.conrelid = c.oid + AND f.attnum = ANY (co.conkey) +WHERE c.relkind 
IN ('r', 'v', 'm', 'f', 'p') + AND n.nspname NOT IN ('_graphjin', 'information_schema', 'pg_catalog') AND c.relname != 'schema_version' AND f.attnum > 0 AND f.attisdropped = false -ORDER BY - f.attrelid, f.attnum ASC; \ No newline at end of file +ORDER BY f.attrelid, + f.attnum ASC; \ No newline at end of file diff --git a/core/internal/sdata/sql/postgres_functions.sql b/core/internal/sdata/sql/postgres_functions.sql index 062376ce..6b3a2fe8 100644 --- a/core/internal/sdata/sql/postgres_functions.sql +++ b/core/internal/sdata/sql/postgres_functions.sql @@ -1,35 +1,55 @@ -SELECT - r.specific_name as func_id, - r.routine_schema as func_schema, - r.routine_name as func_name, - (CASE WHEN r.data_type = 'USER-DEFINED' THEN 'record' ELSE r.data_type END) as data_type, - p.ordinal_position as param_id, +SELECT r.specific_name as func_id, + r.routine_schema as func_schema, + r.routine_name as func_name, + ( + CASE + WHEN r.data_type = 'USER-DEFINED' THEN 'record' + ELSE r.data_type + END + ) as data_type, + p.ordinal_position as param_id, COALESCE(p.parameter_name, '') as param_name, - p.data_type as param_type, - COALESCE(p.parameter_mode, '') as param_kind -FROM - information_schema.routines r -RIGHT JOIN - information_schema.parameters p - ON (r.specific_name = p.specific_name AND r.specific_name = p.specific_name) -WHERE - r.routine_type = 'FUNCTION' - AND r.specific_schema NOT IN ('_graphjin', 'information_schema', 'performance_schema', 'pg_catalog', 'mysql', 'sys') + p.data_type as param_type, + COALESCE(p.parameter_mode, '') as param_kind +FROM information_schema.routines r + RIGHT JOIN information_schema.parameters p ON ( + r.specific_name = p.specific_name + AND r.specific_name = p.specific_name + ) +WHERE r.routine_type = 'FUNCTION' + AND r.data_type != 'void' + AND r.specific_schema NOT IN ( + '_graphjin', + 'information_schema', + 'performance_schema', + 'pg_catalog', + 'mysql', + 'sys' + ) UNION -SELECT - r.specific_name as func_id, - r.routine_schema as 
func_schema, - r.routine_name as func_name, - 'record' as data_type, - a.ordinal_position as param_id, - COALESCE(a.attribute_name, CAST(a.ordinal_position as CHAR(3))) as param_name, - a.data_type as param_type, - 'OUT' as param_kind -FROM - information_schema.routines r -RIGHT JOIN - information_schema.attributes a - ON (r.data_type = 'USER-DEFINED' AND a.udt_schema = r.type_udt_schema AND a.udt_name = r.type_udt_name) -WHERE - r.routine_type = 'FUNCTION' - AND r.specific_schema NOT IN ('_graphjin', 'information_schema', 'performance_schema', 'pg_catalog', 'mysql', 'sys'); +SELECT r.specific_name as func_id, + r.routine_schema as func_schema, + r.routine_name as func_name, + 'record' as data_type, + a.ordinal_position as param_id, + COALESCE( + a.attribute_name, + CAST(a.ordinal_position as CHAR(3)) + ) as param_name, + a.data_type as param_type, + 'OUT' as param_kind +FROM information_schema.routines r + RIGHT JOIN information_schema.attributes a ON ( + r.data_type = 'USER-DEFINED' + AND a.udt_schema = r.type_udt_schema + AND a.udt_name = r.type_udt_name + ) +WHERE r.routine_type = 'FUNCTION' + AND r.specific_schema NOT IN ( + '_graphjin', + 'information_schema', + 'performance_schema', + 'pg_catalog', + 'mysql', + 'sys' + ); \ No newline at end of file diff --git a/core/internal/sdata/sql/postgres_info.sql b/core/internal/sdata/sql/postgres_info.sql index 9680c854..90fe36a3 100644 --- a/core/internal/sdata/sql/postgres_info.sql +++ b/core/internal/sdata/sql/postgres_info.sql @@ -1,4 +1,3 @@ -SELECT - CAST(current_setting('server_version_num') AS integer) as db_version, +SELECT CAST(current_setting('server_version_num') AS integer) as db_version, current_schema() as db_schema, current_database() as db_name; \ No newline at end of file diff --git a/core/internal/sdata/strings.go b/core/internal/sdata/strings.go index a8687134..956ac112 100644 --- a/core/internal/sdata/strings.go +++ b/core/internal/sdata/strings.go @@ -31,9 +31,9 @@ func (fn DBFunction) String() 
string { for _, v := range fn.Inputs { if v.Name == "" { - sb.WriteString(fmt.Sprintf("%d: %v [array:%t]", v.ID, v.Type, v.IsArray)) + sb.WriteString(fmt.Sprintf("%d: %v [array:%t]", v.ID, v.Type, v.Array)) } else { - sb.WriteString(fmt.Sprintf("%s: %v [array:%t]", v.Name, v.Type, v.IsArray)) + sb.WriteString(fmt.Sprintf("%s: %v [array:%t]", v.Name, v.Type, v.Array)) } } @@ -41,9 +41,9 @@ func (fn DBFunction) String() string { for _, v := range fn.Outputs { if v.Name == "" { - sb.WriteString(fmt.Sprintf("%d: %v [array:%t]", v.ID, v.Type, v.IsArray)) + sb.WriteString(fmt.Sprintf("%d: %v [array:%t]", v.ID, v.Type, v.Array)) } else { - sb.WriteString(fmt.Sprintf("%s: %v [array:%t]", v.Name, v.Type, v.IsArray)) + sb.WriteString(fmt.Sprintf("%s: %v [array:%t]", v.Name, v.Type, v.Array)) } } diff --git a/core/internal/sdata/tables.go b/core/internal/sdata/tables.go index 52600c2e..22e6f14e 100644 --- a/core/internal/sdata/tables.go +++ b/core/internal/sdata/tables.go @@ -18,7 +18,7 @@ type DBInfo struct { Tables []DBTable Functions []DBFunction - VTables []VirtualTable + VTables []VirtualTable `json:"-"` colMap map[string]int tableMap map[string]int hash int @@ -44,10 +44,6 @@ type VirtualTable struct { FKeyColumn string } -type st struct { - schema, table string -} - func GetDBInfo( db *sql.DB, dbType string, @@ -92,18 +88,6 @@ func GetDBInfo( return nil, err } - h := fnv.New128() - hv := fmt.Sprintf("%s%d%s%s", dbType, dbVersion, dbSchema, dbName) - h.Write([]byte(hv)) - - for _, c := range cols { - h.Write([]byte(c.String())) - } - - for _, fn := range funcs { - h.Write([]byte(fn.String())) - } - di := NewDBInfo( dbType, dbVersion, @@ -113,8 +97,6 @@ func GetDBInfo( funcs, blockList) - di.hash = h.Size() - return di, nil } @@ -137,13 +119,17 @@ func NewDBInfo( tableMap: make(map[string]int), } + type st struct { + schema string + table string + } + tm := make(map[st][]DBColumn) - for i := range cols { - c := cols[i] + for i, c := range cols { di.colMap[(c.Schema + ":" 
+ c.Table + ":" + c.Name)] = i - k1 := st{c.Schema, c.Table} - tm[k1] = append(tm[k1], c) + k := st{c.Schema, c.Table} + tm[k] = append(tm[k], c) } for k, tcols := range tm { @@ -173,6 +159,19 @@ func NewDBInfo( di.AddTable(t) } + h := fnv.New128() + hv := fmt.Sprintf("%s%d%s%s", dbType, dbVersion, dbSchema, dbName) + h.Write([]byte(hv)) + + for _, c := range cols { + h.Write([]byte(c.String())) + } + + for _, fn := range funcs { + h.Write([]byte(fn.String())) + } + + di.hash = h.Size() return di } @@ -338,10 +337,10 @@ type DBFunction struct { } type DBFuncParam struct { - ID int - Name string - Type string - IsArray bool + ID int + Name string + Type string + Array bool } func DiscoverFunctions(db *sql.DB, dbtype string, blockList []string) ([]DBFunction, error) { @@ -386,7 +385,7 @@ func DiscoverFunctions(db *sql.DB, dbtype string, blockList []string) ([]DBFunct param := DBFuncParam{ID: pid, Name: pn, Type: pt} if strings.HasSuffix(pt, "[]") { - param.IsArray = true + param.Array = true } switch pk { diff --git a/core/internal/sdata/test_dbinfo.go b/core/internal/sdata/test_dbinfo.go index db1f324d..0b0c4808 100644 --- a/core/internal/sdata/test_dbinfo.go +++ b/core/internal/sdata/test_dbinfo.go @@ -2,60 +2,84 @@ package sdata func GetTestDBInfo() *DBInfo { columns := [][]DBColumn{ - []DBColumn{ - DBColumn{Schema: "public", Table: "customers", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, - DBColumn{Schema: "public", Table: "customers", Name: "user_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "users", FKeyCol: "id"}, - DBColumn{Schema: "public", Table: "customers", Name: "product_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "products", FKeyCol: "id"}, - DBColumn{Schema: "public", Table: "customers", Name: "vip", Type: "boolean", NotNull: true, PrimaryKey: false, UniqueKey: false}}, - []DBColumn{ - 
DBColumn{Schema: "public", Table: "users", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, - DBColumn{Schema: "public", Table: "users", Name: "full_name", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "users", Name: "phone", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "users", Name: "avatar", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "users", Name: "email", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "users", Name: "encrypted_password", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "users", Name: "reset_password_token", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "users", Name: "reset_password_sent_at", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "users", Name: "remember_created_at", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "users", Name: "created_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "users", Name: "updated_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false}}, - []DBColumn{ - DBColumn{Schema: "public", Table: "products", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, - DBColumn{Schema: "public", Table: "products", Name: "name", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "products", Name: "description", 
Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "products", Name: "price", Type: "numeric(7,2)", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "products", Name: "user_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "users", FKeyCol: "id"}, - DBColumn{Schema: "public", Table: "products", Name: "created_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "products", Name: "updated_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "products", Name: "tsv", Type: "tsvector", NotNull: false, PrimaryKey: false, UniqueKey: false, FullText: true}, - DBColumn{Schema: "public", Table: "products", Name: "tags", Type: "text[]", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "tags", FKeyCol: "slug", Array: true}, - DBColumn{Schema: "public", Table: "products", Name: "tag_count", Type: "json", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "tag_count", FKeyCol: ""}}, - []DBColumn{ - DBColumn{Schema: "public", Table: "purchases", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, - DBColumn{Schema: "public", Table: "purchases", Name: "customer_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "customers", FKeyCol: "id"}, - DBColumn{Schema: "public", Table: "purchases", Name: "product_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "products", FKeyCol: "id"}, - DBColumn{Schema: "public", Table: "purchases", Name: "sale_type", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: 
"purchases", Name: "quantity", Type: "integer", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "purchases", Name: "due_date", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "purchases", Name: "returned", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false}}, - []DBColumn{ - DBColumn{Schema: "public", Table: "tags", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, - DBColumn{Schema: "public", Table: "tags", Name: "name", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "tags", Name: "slug", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false}}, - []DBColumn{ - DBColumn{Schema: "public", Table: "tag_count", Name: "tag_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "tags", FKeyCol: "id"}, - DBColumn{Schema: "public", Table: "tag_count", Name: "count", Type: "int", NotNull: false, PrimaryKey: false, UniqueKey: false}}, - []DBColumn{ - DBColumn{Schema: "public", Table: "notifications", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, - DBColumn{Schema: "public", Table: "notifications", Name: "verb", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "notifications", Name: "subject_type", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false}, - DBColumn{Schema: "public", Table: "notifications", Name: "subject_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false}}, - []DBColumn{ - DBColumn{Schema: "public", Table: "comments", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, - DBColumn{Schema: "public", Table: "comments", Name: "product_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, 
FKeySchema: "public", FKeyTable: "products", FKeyCol: "id"}, - DBColumn{Schema: "public", Table: "comments", Name: "commenter_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "users", FKeyCol: "id"}, - DBColumn{Schema: "public", Table: "comments", Name: "reply_to_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "comments", FKeyCol: "id"}, - DBColumn{Schema: "public", Table: "comments", Name: "body", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false}}, + { + {Schema: "public", Table: "customers", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, + {Schema: "public", Table: "customers", Name: "user_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "users", FKeyCol: "id"}, + {Schema: "public", Table: "customers", Name: "product_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "products", FKeyCol: "id"}, + {Schema: "public", Table: "customers", Name: "vip", Type: "boolean", NotNull: true, PrimaryKey: false, UniqueKey: false}}, + { + {Schema: "public", Table: "users", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, + {Schema: "public", Table: "users", Name: "full_name", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "users", Name: "phone", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "users", Name: "avatar", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "users", Name: "email", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "users", Name: "encrypted_password", Type: "character varying", NotNull: true, PrimaryKey: 
false, UniqueKey: false}, + {Schema: "public", Table: "users", Name: "reset_password_token", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "users", Name: "reset_password_sent_at", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "users", Name: "remember_created_at", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "users", Name: "created_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "users", Name: "updated_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false}}, + { + {Schema: "public", Table: "products", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, + {Schema: "public", Table: "products", Name: "name", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "products", Name: "description", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "products", Name: "price", Type: "numeric(7,2)", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "products", Name: "user_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "users", FKeyCol: "id"}, + {Schema: "public", Table: "products", Name: "created_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "products", Name: "updated_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "products", Name: "tsv", Type: "tsvector", NotNull: false, PrimaryKey: false, UniqueKey: false, FullText: true}, + {Schema: "public", Table: "products", Name: "tags", Type: "text[]", 
NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "tags", FKeyCol: "slug", Array: true}, + {Schema: "public", Table: "products", Name: "tag_count", Type: "json", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "tag_count", FKeyCol: ""}}, + { + {Schema: "public", Table: "purchases", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, + {Schema: "public", Table: "purchases", Name: "customer_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "customers", FKeyCol: "id"}, + {Schema: "public", Table: "purchases", Name: "product_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "products", FKeyCol: "id"}, + {Schema: "public", Table: "purchases", Name: "sale_type", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "purchases", Name: "quantity", Type: "integer", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "purchases", Name: "due_date", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "purchases", Name: "returned", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false}}, + { + {Schema: "public", Table: "tags", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, + {Schema: "public", Table: "tags", Name: "name", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "tags", Name: "slug", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false}}, + { + {Schema: "public", Table: "tag_count", Name: "tag_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "tags", FKeyCol: "id"}, + {Schema: "public", Table: "tag_count", Name: "count", Type: 
"int", NotNull: false, PrimaryKey: false, UniqueKey: false}}, + { + {Schema: "public", Table: "notifications", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, + {Schema: "public", Table: "notifications", Name: "verb", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "notifications", Name: "subject_type", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false}, + {Schema: "public", Table: "notifications", Name: "subject_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false}}, + { + {Schema: "public", Table: "comments", Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: true}, + {Schema: "public", Table: "comments", Name: "product_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "products", FKeyCol: "id"}, + {Schema: "public", Table: "comments", Name: "commenter_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "users", FKeyCol: "id"}, + {Schema: "public", Table: "comments", Name: "reply_to_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeySchema: "public", FKeyTable: "comments", FKeyCol: "id"}, + {Schema: "public", Table: "comments", Name: "body", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false}}, + } + + fn := []DBFunction{ + { + Schema: "public", + Name: "get_top_products", + Type: "record", + Agg: false, + Inputs: []DBFuncParam{ + {ID: 1, Name: "n", Type: "integer", Array: false}}, + Outputs: []DBFuncParam{ + {ID: 2, Name: "id", Type: "bigint", Array: false}, + {ID: 3, Name: "name", Type: "bigint", Array: false}}, + }, + { + Schema: "public", + Name: "text2score", + Type: "numeric", + Agg: false, + Inputs: []DBFuncParam{ + {ID: 1, Name: "text", Type: "text", Array: false}}, + Outputs: []DBFuncParam{ + {ID: 2, Name: "score", Type: "bigint", Array: false}}, + }, } var 
cols []DBColumn @@ -73,6 +97,7 @@ func GetTestDBInfo() *DBInfo { di := NewDBInfo("", 110000, "public", "db", cols, nil, nil) di.VTables = vt + di.Functions = fn return di } diff --git a/core/introspec.go b/core/introspec.go index 6571135d..d591f034 100644 --- a/core/introspec.go +++ b/core/introspec.go @@ -123,17 +123,12 @@ type intro struct { exptNeeded map[string]bool } -// func (g *GraphJin) Introspection(query string) (*Result, error) { -// gj := g.Load().(*graphjin) -// return gj.introspection(query) -// } - -func (gj *graphjin) introspection(query string) ([]byte, error) { +func (gj *graphjin) introspection(query []byte) ([]byte, error) { engine, err := gj.newGraphQLEngine() if err != nil { return nil, err } - r := engine.ServeGraphQL(&graphql.Request{Query: query}) + r := engine.ServeGraphQL(&graphql.Request{Query: string(query)}) if err := r.Error(); err != nil { return nil, err } diff --git a/core/prepare.go b/core/prepare.go index af8485c4..37193a22 100644 --- a/core/prepare.go +++ b/core/prepare.go @@ -7,17 +7,8 @@ import ( "strings" "github.com/dosco/graphjin/v2/core/internal/allow" - "github.com/dosco/graphjin/v2/core/internal/qcode" ) -type queryReq struct { - op qcode.QType - ns string - name string - query []byte - vars []byte -} - // nolint: errcheck func (gj *graphjin) prepareRoleStmt() error { if !gj.abacEnabled { diff --git a/core/remote_join.go b/core/remote_join.go index 3bbd24c1..14f089e2 100644 --- a/core/remote_join.go +++ b/core/remote_join.go @@ -11,48 +11,48 @@ import ( "github.com/dosco/graphjin/v2/internal/jsn" ) -func (c *gcontext) execRemoteJoin(ctx context.Context, res queryResp) (queryResp, error) { - var err error - sel := res.qc.st.qc.Selects - +func (s *gstate) execRemoteJoin(c context.Context) (err error) { // fetch the field name used within the db response json // that are used to mark insertion points and the mapping between // those field names and their select objects - fids, sfmap, err := c.parentFieldIds(sel, 
res.qc.st.qc.Remotes) + var fids [][]byte + var sfmap map[string]*qcode.Select + + fids, sfmap, err = s.parentFieldIds() if err != nil { - return res, err + return } // fetch the field values of the marked insertion points // these values contain the id to be used with fetching remote data - from := jsn.Get(res.data, fids) + from := jsn.Get(s.data, fids) var to []jsn.Field if len(from) == 0 { - return res, errors.New("something wrong no remote ids found in db response") + err = errors.New("something wrong no remote ids found in db response") + return } - to, err = c.resolveRemotes(ctx, from, sel, sfmap) + to, err = s.resolveRemotes(c, from, sfmap) if err != nil { - return res, err + return } var ob bytes.Buffer - err = jsn.Replace(&ob, res.data, from, to) + err = jsn.Replace(&ob, s.data, from, to) if err != nil { - return res, err + return } - res.data = ob.Bytes() - - return res, nil + s.data = ob.Bytes() + return } -func (c *gcontext) resolveRemotes( +func (s *gstate) resolveRemotes( ctx context.Context, from []jsn.Field, - sel []qcode.Select, sfmap map[string]*qcode.Select) ([]jsn.Field, error) { + selects := s.cs.st.qc.Selects // replacement data for the marked insertion points // key and value will be replaced by whats below @@ -65,15 +65,15 @@ func (c *gcontext) resolveRemotes( for i, id := range from { // use the json key to find the related Select object - s, ok := sfmap[string(id.Key)] + sel, ok := sfmap[string(id.Key)] if !ok { return nil, fmt.Errorf("invalid remote field key") } - p := sel[s.ParentID] + p := selects[sel.ParentID] // then use the Table name in the Select and it's parent // to find the resolver to use for this relationship - r, ok := c.gj.rmap[(s.Table + p.Table)] + r, ok := s.gj.rmap[(sel.Table + p.Table)] if !ok { return nil, fmt.Errorf("no resolver found") } @@ -83,18 +83,18 @@ func (c *gcontext) resolveRemotes( return nil, fmt.Errorf("invalid remote field id") } - go func(n int, id []byte, s *qcode.Select) { + go func(n int, id []byte, 
sel *qcode.Select) { defer wg.Done() //st := time.Now() - ctx1, span := c.gj.spanStart(ctx, "Execute Remote Request") + ctx1, span := s.gj.spanStart(ctx, "Execute Remote Request") b, err := r.Fn.Resolve(ctx1, ResolverReq{ - ID: string(id), Sel: s, Log: c.gj.log, ReqConfig: c.rc}) + ID: string(id), Sel: sel, Log: s.gj.log, ReqConfig: s.r.rc}) if err != nil { - cerr = fmt.Errorf("%s: %s", s.Table, err) + cerr = fmt.Errorf("%s: %s", sel.Table, err) span.Error(cerr) } span.End() @@ -109,10 +109,10 @@ func (c *gcontext) resolveRemotes( var ob bytes.Buffer - if len(s.Fields) != 0 { - err = jsn.Filter(&ob, b, fieldsToList(s.Fields)) + if len(sel.Fields) != 0 { + err = jsn.Filter(&ob, b, fieldsToList(sel.Fields)) if err != nil { - cerr = fmt.Errorf("%s: %w", s.Table, err) + cerr = fmt.Errorf("%s: %w", sel.Table, err) return } @@ -120,16 +120,16 @@ func (c *gcontext) resolveRemotes( ob.WriteString("null") } - to[n] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()} - }(i, id, s) + to[n] = jsn.Field{Key: []byte(sel.FieldName), Value: ob.Bytes()} + }(i, id, sel) } wg.Wait() - return to, cerr } -func (c *gcontext) parentFieldIds(sel []qcode.Select, remotes int32) ( - [][]byte, map[string]*qcode.Select, error) { +func (s *gstate) parentFieldIds() ([][]byte, map[string]*qcode.Select, error) { + selects := s.cs.st.qc.Selects + remotes := s.cs.st.qc.Remotes // list of keys (and it's related value) to extract from // the db json response @@ -139,18 +139,16 @@ func (c *gcontext) parentFieldIds(sel []qcode.Select, remotes int32) ( // object sm := make(map[string]*qcode.Select, remotes) - for i := range sel { - s := &sel[i] - - if s.SkipRender != qcode.SkipTypeRemote { + for i, sel := range selects { + if sel.SkipRender != qcode.SkipTypeRemote { continue } - p := sel[s.ParentID] + p := selects[sel.ParentID] - if r, ok := c.gj.rmap[(s.Table + p.Table)]; ok { + if r, ok := s.gj.rmap[(sel.Table + p.Table)]; ok { fm = append(fm, r.IDField) - sm[string(r.IDField)] = s + 
sm[string(r.IDField)] = &selects[i] } } return fm, sm, nil diff --git a/core/schema.go b/core/schema.go new file mode 100644 index 00000000..f90660d5 --- /dev/null +++ b/core/schema.go @@ -0,0 +1,128 @@ +package core + +import ( + "fmt" + "io" + "regexp" + "strings" + "text/tabwriter" + "text/template" + + "github.com/dosco/graphjin/v2/core/internal/sdata" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +const schemaTemplate = ` +# dbinfo:{{if .Type}}{{ .Type }}{{else}}postgres{{end}},{{- .Version }},{{- .Schema }} + +{{ define "schema_directive"}} +{{- if and (ne .Schema "public") (ne .Schema "")}} @schema(name: {{ .Schema }}){{end}} +{{- end}} + +{{- define "relation_directive"}} +{{- if (ne .FKeyTable "")}} @relation(type: {{ .FKeyTable }} +{{- if (ne .FKeyCol "")}}, field: {{ .FKeyCol }}{{end -}} +{{- if and (ne .FKeySchema "public") (ne .FKeySchema "")}}, schema: {{ .FKeySchema }}{{end -}}) +{{- end}} +{{- end}} + +{{- define "function_directive"}} +{{- " @function" }} +{{- if (ne .Type "")}}(return_type: {{ .Type }}){{end}} +{{- end}} + +{{- define "column_type"}} +{{- $var := .Type|dbtype }} +{{- $type := (index $var 0)|pascal }} +{{- if .Array}}[{{ $type }}]{{else}}{{ $type }}{{end}} +{{- if .NotNull}}!{{end}} +{{- "\t" }} +{{- if ne (index $var 1) ""}} @type(args: {{ (index $var 1) | printf "%q" }}){{end}} +{{- template "relation_directive" .}} +{{- end}} + +{{- define "column"}} +{{ "\t" }} +{{- .Name }}: +{{- "\t"}} +{{- template "column_type" .}} +{{- if .PrimaryKey}} @id{{end}} +{{- if .UniqueKey}} @unique{{end}} +{{- if .FullText}} @search{{end}} +{{- if .Blocked}} @blocked{{end}} +{{- end}} + +{{- define "func_args"}} +{{ "\t" }} +{{- .Name }}: +{{- "\t"}} +{{- $var := .Type|dbtype }} +{{- (index $var 0)|pascal }} +{{- if .Array}}[]{{end}} +{{- "\t"}} +{{- if ne (index $var 1) ""}} @type_args({{ (index $var 1) }}){{end}} +{{- end -}} + +{{range .Tables -}} +type {{.Name}} +{{- template "schema_directive" .}} { +{{- range 
.Columns}}{{template "column" .}}{{end}} +} + +{{end -}} + +{{range .Functions -}} +type {{.Name}} +{{- template "schema_directive" .}} +{{- template "function_directive" .}} { +{{- range .Inputs}}{{template "func_args" .}}{{"\t"}}@input{{end}} +{{- range .Outputs}}{{template "func_args" .}}{{"\t"}}@output{{end}} +} + +{{end -}} +` + +func writeSchema(s *sdata.DBInfo, out io.Writer) (err error) { + fn := template.FuncMap{ + "pascal": toPascalCase, + "dbtype": parseDBType, + } + + tmpl, err := template. + New("schema"). + Funcs(fn). + Parse(schemaTemplate) + + if err != nil { + return err + } + + w := tabwriter.NewWriter(out, 2, 2, 2, ' ', 0) + err = tmpl.Execute(w, s) + if err != nil { + return err + } + return +} + +func toPascalCase(text string) string { + var sb strings.Builder + c := cases.Title(language.English) + for _, v := range strings.Fields(text) { + sb.WriteString(c.String(v)) + } + return sb.String() +} + +var dbTypeRe = regexp.MustCompile(`([a-zA-Z ]+)(\((.+)\))?`) + +func parseDBType(name string) (res [2]string, err error) { + v := dbTypeRe.FindStringSubmatch(name) + if len(v) == 4 { + res = [2]string{v[1], v[3]} + } else { + err = fmt.Errorf("invalid db type: %s", name) + } + return +} diff --git a/core/schema_test.go b/core/schema_test.go new file mode 100644 index 00000000..edd8e3e7 --- /dev/null +++ b/core/schema_test.go @@ -0,0 +1,37 @@ +package core + +import ( + "bytes" + "fmt" + "testing" + + "github.com/dosco/graphjin/v2/core/internal/qcode" + "github.com/dosco/graphjin/v2/core/internal/sdata" +) + +func TestCreateSchema(t *testing.T) { + var buf bytes.Buffer + + di1 := sdata.GetTestDBInfo() + if err := writeSchema(di1, &buf); err != nil { + t.Fatal(err) + } + + ds, err := qcode.ParseSchema(buf.Bytes()) + if err != nil { + t.Fatal(err) + } + + di2 := sdata.NewDBInfo(ds.Type, + ds.Version, + ds.Schema, + "", + ds.Columns, + ds.Functions, + nil) + + if di1.Hash() != di2.Hash() { + t.Fatal(fmt.Errorf("schema hashes do not match: expected %d got 
%d", + di1.Hash(), di2.Hash())) + } +} diff --git a/core/script.go b/core/script.go index 2f979e4c..43a507a8 100644 --- a/core/script.go +++ b/core/script.go @@ -52,9 +52,11 @@ func (gj *graphjin) readScriptSource(name string) (string, error) { return string(src), nil } -func (c *gcontext) scriptCallReq(ctx context.Context, qc *qcode.QCode, - vars map[string]interface{}, role string) ( - []byte, error) { +func (s *gstate) scriptCallReq(ctx context.Context, + qc *qcode.QCode, + vars map[string]interface{}, + role string) ([]byte, error) { + defer func() { // nolint: errcheck recover() @@ -65,8 +67,8 @@ func (c *gcontext) scriptCallReq(ctx context.Context, qc *qcode.QCode, userID = v } - ctx1, span := c.gj.spanStart(ctx, "Execute Request Script") - gfn := c.newGraphQLFunc(ctx1, role) + ctx1, span := s.gj.spanStart(ctx, "Execute Request Script") + gfn := s.newGraphQLFunc(ctx1) val := qc.Script.SC.RequestFn(ctx1, vars, role, userID, gfn) if val == nil { @@ -79,22 +81,21 @@ func (c *gcontext) scriptCallReq(ctx context.Context, qc *qcode.QCode, return json.Marshal(val) } -func (c *gcontext) scriptCallResp(ctx context.Context, qc *qcode.QCode, - data []byte, role string) (_ []byte, err error) { +func (s *gstate) scriptCallResp(c context.Context) (err error) { defer func() { // nolint: errcheck recover() }() rj := make(map[string]interface{}) - if len(data) != 0 { - if err := json.Unmarshal(data, &rj); err != nil { - return nil, err + if len(s.data) != 0 { + if err = json.Unmarshal(s.data, &rj); err != nil { + return } } var userID interface{} - if v := ctx.Value(UserIDKey); v != nil { + if v := c.Value(UserIDKey); v != nil { userID = v } @@ -103,21 +104,22 @@ func (c *gcontext) scriptCallResp(ctx context.Context, qc *qcode.QCode, recover() }() - ctx1, span := c.gj.spanStart(ctx, "Execute Response Script") - gfn := c.newGraphQLFunc(ctx1, role) + c1, span := s.gj.spanStart(c, "Execute Response Script") + gfn := s.newGraphQLFunc(c1) - val := qc.Script.SC.ReponseFn(ctx1, rj, 
role, userID, gfn) + val := s.cs.st.qc.Script.SC.ReponseFn(c1, rj, s.role, userID, gfn) if val == nil { - err := errors.New("error excuting script") + err = errors.New("error excuting script") span.Error(err) - return data, nil + return } span.End() - return json.Marshal(val) + s.data, err = json.Marshal(val) + return } -func (c *gcontext) newGraphQLFunc(ctx context.Context, role string) func(string, map[string]interface{}, map[string]string) map[string]interface{} { +func (s *gstate) newGraphQLFunc(c context.Context) func(string, map[string]interface{}, map[string]string) map[string]interface{} { return func( query string, vars map[string]interface{}, @@ -128,43 +130,33 @@ func (c *gcontext) newGraphQLFunc(ctx context.Context, role string) func(string, if err != nil { panic(err) } - op := qcode.GetQTypeByName(h.Operation) - name := h.Name - qreq := queryReq{ - op: op, - name: name, - query: []byte(query), - } - - ct := gcontext{ - gj: c.gj, - rc: c.rc, - op: op, - name: name, - } + r := s.gj.newGraphqlReq(s.r.rc, + h.Operation, + h.Name, + []byte(query), + nil) if len(vars) != 0 { - if qreq.vars, err = json.Marshal(vars); err != nil { + if r.vars, err = json.Marshal(vars); err != nil { panic(fmt.Errorf("variables: %s", err)) } } - var r1 string + s := newGState(s.gj, r, s.role) - if v, ok := opt["role"]; ok && len(v) != 0 { - r1 = v - } else { - r1 = role + if v, ok := opt["role"]; ok && v != "" { + s.role = v } - qres, err := ct.execQuery(ctx, qreq, r1) + err = s.compileAndExecuteWrapper(c) + if err != nil { panic(err) } jres := make(map[string]interface{}) - if err = json.Unmarshal(qres.data, &jres); err != nil { + if err = json.Unmarshal(s.data, &jres); err != nil { panic(fmt.Errorf("json: %s", err)) } diff --git a/core/subs.go b/core/subs.go index 4b04a3d7..da0fcab2 100644 --- a/core/subs.go +++ b/core/subs.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "github.com/dosco/graphjin/v2/core/internal/allow" "github.com/dosco/graphjin/v2/core/internal/graph" 
"github.com/dosco/graphjin/v2/core/internal/qcode" "github.com/rs/xid" @@ -29,11 +30,9 @@ var ( ) type sub struct { - ns string - name string - role string - qc *queryComp - js json.RawMessage + k string + s gstate + js json.RawMessage add chan *Member del chan *Member @@ -85,40 +84,33 @@ func (g *GraphJin) Subscribe( c context.Context, query string, vars json.RawMessage, - rc *ReqConfig) (*Member, error) { - var err error + rc *ReqConfig) (m *Member, err error) { - h, err := graph.FastParse(query) - if err != nil { - return nil, err + // get the name, query vars + var h graph.FPInfo + if h, err = graph.FastParse(query); err != nil { + return } - op := qcode.GetQTypeByName(h.Operation) - name := h.Name gj := g.Load().(*graphjin) - if gj.prod && !gj.conf.DisableAllowList { - item, err := gj.allowList.GetByName(name, gj.prod) - if err != nil { - return nil, err - } - op = qcode.GetQTypeByName(item.Operation) - query = item.Query - } - - m, err := gj.subscribeWithOpName(c, op, name, query, vars, rc) - if err != nil { - return nil, err - } + // create the request object + r := gj.newGraphqlReq(rc, "subscription", h.Name, nil, vars) - if !gj.prod { - err := gj.saveToAllowList(nil, query, m.ns) + // if prod fetch query from allow list + if gj.prod { + var item allow.Item + item, err = gj.allowList.GetByName(h.Name, gj.prod) if err != nil { - return nil, err + return } + r.Set(item) + } else { + r.query = []byte(query) } - return m, err + m, err = gj.subscribe(c, r) + return } // SubscribeByName is similar to the Subscribe function except that queries saved @@ -127,38 +119,35 @@ func (g *GraphJin) SubscribeByName( c context.Context, name string, vars json.RawMessage, - rc *ReqConfig) (*Member, error) { + rc *ReqConfig) (m *Member, err error) { gj := g.Load().(*graphjin) - item, err := gj.allowList.GetByName(name, gj.prod) + + var item allow.Item + item, err = gj.allowList.GetByName(name, gj.prod) if err != nil { - return nil, err + return } - op := 
qcode.GetQTypeByName(item.Operation) - query := item.Query + r := gj.newGraphqlReq(rc, "subscription", name, nil, vars) + r.Set(item) - return gj.subscribeWithOpName(c, op, name, query, vars, rc) + m, err = gj.subscribe(c, r) + return } -func (gj *graphjin) subscribeWithOpName( - c context.Context, - op qcode.QType, - name string, - query string, - vars json.RawMessage, - rc *ReqConfig) (*Member, error) { +func (gj *graphjin) subscribe(c context.Context, r graphqlReq) ( + m *Member, err error) { - if op != qcode.QTSubscription { + if r.op != qcode.QTSubscription { return nil, errors.New("subscription: not a subscription query") } - if name == "" { - h := sha256.Sum256([]byte(query)) - name = hex.EncodeToString(h[:]) + if r.name == "" { + h := sha256.Sum256([]byte(r.query)) + r.name = hex.EncodeToString(h[:]) } var role string - var err error if v, ok := c.Value(UserRoleKey).(string); ok { role = v @@ -172,37 +161,33 @@ func (gj *graphjin) subscribeWithOpName( } if role == "user" && gj.abacEnabled { - if role, err = gj.executeRoleQuery(c, nil, vars, gj.pf, rc); err != nil { - return nil, err + role, err = gj.executeRoleQuery(c, nil, r.vars, r.rc) + if err != nil { + return } } - ns := gj.namespace - if rc != nil && rc.ns != nil { - ns = *rc.ns - } - - v, _ := gj.subs.LoadOrStore((ns + name + role), &sub{ - ns: ns, - name: name, - role: role, + k := (r.ns + r.name + role) + v, _ := gj.subs.LoadOrStore(k, &sub{ + k: k, + s: newGState(gj, r, role), add: make(chan *Member), del: make(chan *Member), updt: make(chan mmsg, 10), }) - s := v.(*sub) + sub := v.(*sub) - s.Do(func() { - err = gj.newSub(c, s, query, vars, rc) + sub.Do(func() { + err = gj.initSub(c, sub) }) if err != nil { - gj.subs.Delete((name + role)) - return nil, err + gj.subs.Delete(k) + return } - args, err := gj.argList(c, s.qc.st.md, vars, gj.pf, rc) - if err != nil { + var args args + if args, err = sub.s.argListVars(c, r.vars); err != nil { return nil, err } @@ -214,62 +199,47 @@ func (gj *graphjin) 
subscribeWithOpName( } } - m := &Member{ - ns: ns, + m = &Member{ + ns: r.ns, id: xid.New(), Result: make(chan *Result, 10), - sub: s, + sub: sub, vl: args.values, params: params, cindx: args.cindx, } - m.mm, err = gj.subFirstQuery(s, m, params) + m.mm, err = gj.subFirstQuery(sub, m, params) if err != nil { return nil, err } - s.add <- m - - return m, nil + sub.add <- m + return } -func (gj *graphjin) newSub(c context.Context, - s *sub, query string, vars json.RawMessage, rc *ReqConfig) error { - var err error - - qr := queryReq{ - ns: s.ns, - op: qcode.QTSubscription, - name: s.name, - query: []byte(query), - vars: vars, - } - - if s.qc, err = gj.compileQuery(qr, s.role); err != nil { - return err +func (gj *graphjin) initSub(c context.Context, sub *sub) (err error) { + if err = sub.s.compile(); err != nil { + return } - if !gj.prod && !gj.conf.DisableAllowList { - err := gj.allowList.Set( - nil, - query, - qr.ns) - + if !gj.prod { + err = gj.saveToAllowList(sub.s.cs.st.qc, nil, sub.s.r.ns) if err != nil { - return err + return } } - if len(s.qc.st.md.Params()) != 0 { - s.qc.st.sql = renderSubWrap(s.qc.st, gj.schema.DBType()) + if len(sub.s.cs.st.md.Params()) != 0 { + sub.s.cs.st.sql = renderSubWrap(sub.s.cs.st, gj.schema.DBType()) } - go gj.subController(s) - return nil + go gj.subController(sub) + return } -func (gj *graphjin) subController(s *sub) { - defer gj.subs.Delete((s.name + s.role)) +func (gj *graphjin) subController(sub *sub) { + // remove subscription if controller exists + defer gj.subs.Delete(sub.k) ps := gj.conf.SubsPollDuration if ps < minPollDuration { @@ -278,26 +248,26 @@ func (gj *graphjin) subController(s *sub) { for { select { - case m := <-s.add: - if err := s.addMember(m); err != nil { + case m := <-sub.add: + if err := sub.addMember(m); err != nil { gj.log.Printf(errSubs, "add-sub", err) return } - case m := <-s.del: - s.deleteMember(m) - if len(s.ids) == 0 { + case m := <-sub.del: + sub.deleteMember(m) + if len(sub.ids) == 0 { return } - 
case msg := <-s.updt: - if err := s.updateMember(msg); err != nil { + case msg := <-sub.updt: + if err := sub.updateMember(msg); err != nil { gj.log.Printf(errSubs, "update-sub", err) return } case <-time.After(ps): - s.fanOutJobs(gj) + sub.fanOutJobs(gj) } } } @@ -393,7 +363,7 @@ func (s *sub) fanOutJobs(gj *graphjin) { } } -func (gj *graphjin) subCheckUpdates(s *sub, mv mval, start int) { +func (gj *graphjin) subCheckUpdates(sub *sub, mv mval, start int) { // Do not use the `mval` embedded inside sub since // its not thread safe use the copy `mv mval`. @@ -412,7 +382,7 @@ func (gj *graphjin) subCheckUpdates(s *sub, mv mval, start int) { end = start + (len(mv.ids) - start) } - hasParams := len(s.qc.st.md.Params()) != 0 + hasParams := len(sub.s.cs.st.md.Params()) != 0 var rows *sql.Rows var err error @@ -430,16 +400,15 @@ func (gj *graphjin) subCheckUpdates(s *sub, mv mval, start int) { params = renderJSONArray(mv.params[start:end]) } - err = retryOperation(c, func() error { + err = retryOperation(c, func() (err1 error) { if hasParams { //nolint: sqlclosecheck - rows, err = gj.db.QueryContext(c, s.qc.st.sql, params) + rows, err1 = gj.db.QueryContext(c, sub.s.cs.st.sql, params) } else { //nolint: sqlclosecheck - rows, err = gj.db.QueryContext(c, s.qc.st.sql) + rows, err1 = gj.db.QueryContext(c, sub.s.cs.st.sql) } - - return err + return }) if err != nil { @@ -461,18 +430,18 @@ func (gj *graphjin) subCheckUpdates(s *sub, mv mval, start int) { i++ if hasParams { - gj.subNotifyMember(s, mv, j, js) + gj.subNotifyMember(sub, mv, j, js) continue } for k := start; k < end; k++ { - gj.subNotifyMember(s, mv, k, js) + gj.subNotifyMember(sub, mv, k, js) } - s.js = js + sub.js = js } } -func (gj *graphjin) subFirstQuery(s *sub, m *Member, params json.RawMessage) (mmsg, error) { +func (gj *graphjin) subFirstQuery(sub *sub, m *Member, params json.RawMessage) (mmsg, error) { c := context.Background() // when params are not available we use a more optimized @@ -483,22 +452,22 @@ 
func (gj *graphjin) subFirstQuery(s *sub, m *Member, params json.RawMessage) (mm var mm mmsg var err error - if s.js != nil { - js = s.js + if sub.js != nil { + js = sub.js } else { - err := retryOperation(c, func() error { + err := retryOperation(c, func() (err1 error) { switch { case params != nil: - err = gj.db. - QueryRowContext(c, s.qc.st.sql, renderJSONArray([]json.RawMessage{params})). + err1 = gj.db. + QueryRowContext(c, sub.s.cs.st.sql, renderJSONArray([]json.RawMessage{params})). Scan(&js) default: - err = gj.db. - QueryRowContext(c, s.qc.st.sql). + err1 = gj.db. + QueryRowContext(c, sub.s.cs.st.sql). Scan(&js) } - return err + return }) if err != nil { @@ -506,7 +475,7 @@ func (gj *graphjin) subFirstQuery(s *sub, m *Member, params json.RawMessage) (mm } } - mm, err = gj.subNotifyMemberEx(s, + mm, err = gj.subNotifyMemberEx(sub, [32]byte{}, m.cindx, m.id, @@ -527,7 +496,7 @@ func (gj *graphjin) subNotifyMember(s *sub, mv mval, j int, js json.RawMessage) } } -func (gj *graphjin) subNotifyMemberEx(s *sub, +func (gj *graphjin) subNotifyMemberEx(sub *sub, dh [32]byte, cindx int, id xid.ID, rc chan *Result, js json.RawMessage, update bool) (mmsg, error) { mm := mmsg{id: id} @@ -559,14 +528,14 @@ func (gj *graphjin) subNotifyMemberEx(s *sub, } if update { - s.updt <- mm + sub.updt <- mm } res := &Result{ op: qcode.QTQuery, - name: s.name, - sql: s.qc.st.sql, - role: s.qc.st.role.Name, + name: sub.s.r.name, + sql: sub.s.cs.st.sql, + role: sub.s.cs.st.role, Data: ejs, } diff --git a/core/wasm.go b/core/wasm.go index 061b60c9..38624f0c 100644 --- a/core/wasm.go +++ b/core/wasm.go @@ -8,7 +8,7 @@ import ( "net/http" ) -func (gj *graphjin) introspection(query string) ([]byte, error) { +func (gj *graphjin) introspection(query []byte) ([]byte, error) { return nil, errors.New("introspection not supported") } diff --git a/examples/nodejs/package-lock.json b/examples/nodejs/package-lock.json index 714a73af..0776ede4 100644 --- a/examples/nodejs/package-lock.json +++ 
b/examples/nodejs/package-lock.json @@ -17,17 +17,6 @@ "path": "^0.12.7" } }, - "../..": { - "version": "2.0.12", - "hasInstallScript": true, - "license": "Apache-2.0", - "devDependencies": { - "fs-extra": "^11.1.0" - } - }, - "../../wasm": { - "extraneous": true - }, "node_modules/accepts": { "version": "1.3.8", "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.8.tgz", @@ -289,8 +278,10 @@ } }, "node_modules/graphjin": { - "resolved": "../..", - "link": true + "version": "2.0.19", + "resolved": "file:../..", + "hasInstallScript": true, + "license": "Apache-2.0" }, "node_modules/has": { "version": "1.0.3", @@ -1115,10 +1106,7 @@ } }, "graphjin": { - "version": "file:../..", - "requires": { - "fs-extra": "^11.1.0" - } + "version": "2.0.19" }, "has": { "version": "1.0.3", diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 1a18f967..0124ab7a 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -40,7 +40,6 @@ func Cmd() { rootCmd.AddCommand(initCmd()) rootCmd.AddCommand(deployCmd()) rootCmd.AddCommand(dbCmd()) - rootCmd.AddCommand(upgradeCmd()) if v := cmdSecrets(); v != nil { rootCmd.AddCommand(v) diff --git a/internal/cmd/cmd_upgrade.go b/internal/cmd/cmd_upgrade.go deleted file mode 100644 index 3019fb80..00000000 --- a/internal/cmd/cmd_upgrade.go +++ /dev/null @@ -1,23 +0,0 @@ -package cmd - -import ( - core "github.com/dosco/graphjin/v2/core" - "github.com/spf13/cobra" -) - -func upgradeCmd() *cobra.Command { - c := &cobra.Command{ - Use: "upgrade", - Short: "Upgrade a GraphJin app", - Run: cmdUpgrade, - } - return c -} - -func cmdUpgrade(cmd *cobra.Command, args []string) { - if err := core.Upgrade(cpath); err != nil { - log.Fatalf("%s", err) - } - log.Infof("please delete the .yml/.yaml query and old fragment files under %s/queries", cpath) - log.Infoln("upgrade completed!") -} diff --git a/package.json b/package.json index 6fd8bbe4..3a03f536 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "graphjin", - 
"version": "2.0.18", + "version": "2.0.19", "description": "GraphJin - Build APIs in 5 minutes with GraphQL", "type": "module", "main": "./wasm/js/graphjin.js", diff --git a/serv/health.go b/serv/health.go index 83fb3b74..1e2c6509 100644 --- a/serv/health.go +++ b/serv/health.go @@ -4,7 +4,6 @@ import ( "context" "net/http" - "go.opentelemetry.io/otel/trace" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) @@ -13,23 +12,18 @@ var healthyResponse = []byte("All's Well") func healthV1Handler(s1 *Service) http.Handler { h := func(w http.ResponseWriter, r *http.Request) { - var span trace.Span - s := s1.Load().(*service) c, cancel := context.WithTimeout(r.Context(), s.conf.DB.PingTimeout) defer cancel() - c, span = s.spanStart(c, "Health Check Request") - err := s.db.PingContext(c) - if err != nil { + c1, span := s.spanStart(c, "Health Check Request") + defer span.End() + + if err := s.db.PingContext(c1); err != nil { spanError(span, err) - } - span.End() - if err != nil { s.zlog.Error("Health Check", []zapcore.Field{zap.Error(err)}...) w.WriteHeader(http.StatusInternalServerError) - return } _, _ = w.Write(healthyResponse) diff --git a/wasm/graphjin.wasm b/wasm/graphjin.wasm index 33225054..dddbdb28 100755 Binary files a/wasm/graphjin.wasm and b/wasm/graphjin.wasm differ