diff --git a/core/internal/graph/parse.go b/core/internal/graph/parse.go index fc20c11d..e430c8c9 100644 --- a/core/internal/graph/parse.go +++ b/core/internal/graph/parse.go @@ -32,6 +32,7 @@ const ( NodeObj NodeList NodeVar + NodeLabel ) type FieldType int8 @@ -727,10 +728,10 @@ func (p *Parser) parseValue() (*Node, error) { node.Type = NodeStr case itemBoolVal: node.Type = NodeBool - case itemName: - node.Type = NodeStr case itemVariable: node.Type = NodeVar + case itemName: + node.Type = NodeLabel default: return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not '%s' of type '%s')", p.val(item), item._type) diff --git a/core/internal/psql/fn.go b/core/internal/psql/fn.go index 601fdf79..bfcc9f3e 100644 --- a/core/internal/psql/fn.go +++ b/core/internal/psql/fn.go @@ -91,7 +91,7 @@ func (c *compilerContext) renderFuncArgVal(a qcode.Arg) { case qcode.ArgTypeCol: c.colWithTable(a.Col.Table, a.Col.Name) case qcode.ArgTypeVar: - c.renderParam(Param{Name: a.Val, Type: a.ValType}) + c.renderParam(Param{Name: a.Val, Type: a.DType}) default: c.squoted(a.Val) } diff --git a/core/internal/qcode/fields.go b/core/internal/qcode/fields.go index 80a9e6e0..0ad90fae 100644 --- a/core/internal/qcode/fields.go +++ b/core/internal/qcode/fields.go @@ -98,7 +98,7 @@ func (co *Compiler) compileChildColumns( field.Func = fn.Func field.Args = fn.Args - if err := co.compileFuncArgs(&field, f.Args); err != nil { + if err := co.compileFuncArgs(sel, &field, f.Args); err != nil { return err } @@ -135,31 +135,34 @@ func (co *Compiler) compileFuncTableArg(sel *Select, arg *graph.Arg) error { if err != nil { return fmt.Errorf("db function %s: %w", fn.Name, err) } - a := Arg{ - Name: arg.Name, - Val: arg.Val.Val, - ValType: input.Type, - } - if arg.Val.Type == graph.NodeVar { + + a := Arg{Name: arg.Name, DType: input.Type} + + switch arg.Val.Type { + case graph.NodeLabel: + a.Type = ArgTypeCol + a.Col, err = sel.Ti.GetColumn(arg.Val.Val) + case graph.NodeVar: a.Type = ArgTypeVar + fallthrough + default: + a.Val = arg.Val.Val + } + if err != nil { + return err } - // if arg.Val.Type = graph. - // fn.Col, err = sel.Ti.GetColumn(fname[(len(fn.Name) + 1):]) - // if err != nil { - // return - // } sel.Args = append(sel.Args, a) return nil } -func (co *Compiler) compileFuncArgs(f *Field, args []graph.Arg) error { +func (co *Compiler) compileFuncArgs(sel *Select, f *Field, args []graph.Arg) error { if len(args) != 0 && len(f.Func.Inputs) == 0 { return fmt.Errorf("db function '%s' does not have any arguments", f.Func.Name) } for _, arg := range args { if arg.Name == "args" { - if err := co.compileFuncArgArgs(f, arg); err != nil { + if err := co.compileFuncArgArgs(sel, f, arg); err != nil { return err } continue @@ -168,40 +171,55 @@ func (co *Compiler) compileFuncArgs(f *Field, args []graph.Arg) error { if err != nil { return fmt.Errorf("db function %s: %w", f.Func.Name, err) } - a := Arg{ - Name: arg.Name, - Val: arg.Val.Val, - ValType: input.Type, - } - if arg.Val.Type == graph.NodeVar { + + a := Arg{Name: arg.Name, DType: input.Type} + + switch arg.Val.Type { + case graph.NodeLabel: + a.Type = ArgTypeCol + a.Col, err = sel.Ti.GetColumn(arg.Val.Val) + case graph.NodeVar: a.Type = ArgTypeVar + fallthrough + default: + a.Val = arg.Val.Val + } + if err != nil { + return err } - // if arg.Val.Type = graph. - // fn.Col, err = sel.Ti.GetColumn(fname[(len(fn.Name) + 1):]) - // if err != nil { - // return - // } f.Args = append(f.Args, a) } return nil } -func (co *Compiler) compileFuncArgArgs(f *Field, arg graph.Arg) error { +func (co *Compiler) compileFuncArgArgs(sel *Select, f *Field, arg graph.Arg) error { if len(f.Func.Inputs) == 0 { return fmt.Errorf("db function '%s' does not have any arguments", f.Func.Name) } node := arg.Val - if node.Type != graph.NodeList { return argErr("args", "list") } + var err error + for i, n := range node.Children { - a := Arg{Val: n.Val, ValType: f.Func.Inputs[i].Type} - if n.Type == graph.NodeVar { + a := Arg{DType: f.Func.Inputs[i].Type} + + switch n.Type { + case graph.NodeLabel: + a.Type = ArgTypeCol + a.Col, err = sel.Ti.GetColumn(n.Val) + case graph.NodeVar: a.Type = ArgTypeVar + a.Val = n.Val + default: + a.Val = n.Val + } + if err != nil { + return err } f.Args = append(f.Args, a) } diff --git a/core/internal/qcode/qcode.go b/core/internal/qcode/qcode.go index 0a970916..6a64d347 100644 --- a/core/internal/qcode/qcode.go +++ b/core/internal/qcode/qcode.go @@ -191,11 +191,11 @@ const ( ) type Arg struct { - Type ArgType - Name string - Val string - ValType string - Col sdata.DBColumn + Type ArgType + DType string + Name string + Val string + Col sdata.DBColumn } type OrderBy struct { @@ -1564,7 +1564,8 @@ func (co *Compiler) compileArgOrderByObj(sel *Select, parent *graph.Node, cm map // Check for type if node.Type != graph.NodeStr && node.Type != graph.NodeObj && - node.Type != graph.NodeList { + node.Type != graph.NodeList && + node.Type != graph.NodeLabel { err = fmt.Errorf("expecting a string, object or list") continue } @@ -1574,7 +1575,7 @@ func (co *Compiler) compileArgOrderByObj(sel *Select, parent *graph.Node, cm map cn := node switch node.Type { - case graph.NodeStr: + case graph.NodeStr, graph.NodeLabel: if ob.Order, err = toOrder(node.Val); err != nil { // sets the asc desc etc continue } @@ -1694,9 +1695,21 @@ func (co *Compiler) compileArgArgs(sel *Select, arg *graph.Arg) error { } for i, n := range node.Children { - a := Arg{Val: n.Val, ValType: fn.Inputs[i].Type} - if n.Type == graph.NodeVar { + var err error + a := Arg{DType: fn.Inputs[i].Type} + + switch n.Type { + case graph.NodeLabel: + a.Type = ArgTypeCol + a.Col, err = sel.Ti.GetColumn(n.Val) + case graph.NodeVar: a.Type = ArgTypeVar + fallthrough + default: + a.Val = n.Val + } + if err != nil { + return err } sel.Args = append(sel.Args, a) } diff --git a/core/query_pg_test.go b/core/query_pg_test.go index 62ee26ad..bf60359f 100644 --- a/core/query_pg_test.go +++ b/core/query_pg_test.go @@ -19,7 +19,7 @@ func Example_queryWithFunctionFields() { products(id: 51) { id name - is_hot_product(id: 51) + is_hot_product(id: id) } }`