Skip to content

Commit

Permalink
feat: add ability to use columns in function arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Dec 2, 2022
1 parent 2ddbc0e commit b971b74
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 42 deletions.
5 changes: 3 additions & 2 deletions core/internal/graph/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ const (
NodeObj
NodeList
NodeVar
NodeLabel
)

type FieldType int8
Expand Down Expand Up @@ -727,10 +728,10 @@ func (p *Parser) parseValue() (*Node, error) {
node.Type = NodeStr
case itemBoolVal:
node.Type = NodeBool
case itemName:
node.Type = NodeStr
case itemVariable:
node.Type = NodeVar
case itemName:
node.Type = NodeLabel
default:
return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not '%s' of type '%s')", p.val(item), item._type)

Expand Down
2 changes: 1 addition & 1 deletion core/internal/psql/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (c *compilerContext) renderFuncArgVal(a qcode.Arg) {
case qcode.ArgTypeCol:
c.colWithTable(a.Col.Table, a.Col.Name)
case qcode.ArgTypeVar:
c.renderParam(Param{Name: a.Val, Type: a.ValType})
c.renderParam(Param{Name: a.Val, Type: a.DType})
default:
c.squoted(a.Val)
}
Expand Down
76 changes: 47 additions & 29 deletions core/internal/qcode/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (co *Compiler) compileChildColumns(
field.Func = fn.Func
field.Args = fn.Args

if err := co.compileFuncArgs(&field, f.Args); err != nil {
if err := co.compileFuncArgs(sel, &field, f.Args); err != nil {
return err
}

Expand Down Expand Up @@ -135,31 +135,34 @@ func (co *Compiler) compileFuncTableArg(sel *Select, arg *graph.Arg) error {
if err != nil {
return fmt.Errorf("db function %s: %w", fn.Name, err)
}
a := Arg{
Name: arg.Name,
Val: arg.Val.Val,
ValType: input.Type,
}
if arg.Val.Type == graph.NodeVar {

a := Arg{Name: arg.Name, DType: input.Type}

switch arg.Val.Type {
case graph.NodeLabel:
a.Type = ArgTypeCol
a.Col, err = sel.Ti.GetColumn(arg.Val.Val)
case graph.NodeVar:
a.Type = ArgTypeVar
fallthrough
default:
a.Val = arg.Val.Val
}
if err != nil {
return err
}
// if arg.Val.Type = graph.
// fn.Col, err = sel.Ti.GetColumn(fname[(len(fn.Name) + 1):])
// if err != nil {
// return
// }
sel.Args = append(sel.Args, a)
return nil
}

func (co *Compiler) compileFuncArgs(f *Field, args []graph.Arg) error {
func (co *Compiler) compileFuncArgs(sel *Select, f *Field, args []graph.Arg) error {
if len(args) != 0 && len(f.Func.Inputs) == 0 {
return fmt.Errorf("db function '%s' does not have any arguments", f.Func.Name)
}

for _, arg := range args {
if arg.Name == "args" {
if err := co.compileFuncArgArgs(f, arg); err != nil {
if err := co.compileFuncArgArgs(sel, f, arg); err != nil {
return err
}
continue
Expand All @@ -168,40 +171,55 @@ func (co *Compiler) compileFuncArgs(f *Field, args []graph.Arg) error {
if err != nil {
return fmt.Errorf("db function %s: %w", f.Func.Name, err)
}
a := Arg{
Name: arg.Name,
Val: arg.Val.Val,
ValType: input.Type,
}
if arg.Val.Type == graph.NodeVar {

a := Arg{Name: arg.Name, DType: input.Type}

switch arg.Val.Type {
case graph.NodeLabel:
a.Type = ArgTypeCol
a.Col, err = sel.Ti.GetColumn(arg.Val.Val)
case graph.NodeVar:
a.Type = ArgTypeVar
fallthrough
default:
a.Val = arg.Val.Val
}
if err != nil {
return err
}
// if arg.Val.Type = graph.
// fn.Col, err = sel.Ti.GetColumn(fname[(len(fn.Name) + 1):])
// if err != nil {
// return
// }
f.Args = append(f.Args, a)
}

return nil
}

func (co *Compiler) compileFuncArgArgs(f *Field, arg graph.Arg) error {
func (co *Compiler) compileFuncArgArgs(sel *Select, f *Field, arg graph.Arg) error {
if len(f.Func.Inputs) == 0 {
return fmt.Errorf("db function '%s' does not have any arguments", f.Func.Name)
}

node := arg.Val

if node.Type != graph.NodeList {
return argErr("args", "list")
}

var err error

for i, n := range node.Children {
a := Arg{Val: n.Val, ValType: f.Func.Inputs[i].Type}
if n.Type == graph.NodeVar {
a := Arg{DType: f.Func.Inputs[i].Type}

switch n.Type {
case graph.NodeLabel:
a.Type = ArgTypeCol
a.Col, err = sel.Ti.GetColumn(n.Val)
case graph.NodeVar:
a.Type = ArgTypeVar
a.Val = n.Val
default:
a.Val = n.Val
}
if err != nil {
return err
}
f.Args = append(f.Args, a)
}
Expand Down
31 changes: 22 additions & 9 deletions core/internal/qcode/qcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,11 @@ const (
)

type Arg struct {
Type ArgType
Name string
Val string
ValType string
Col sdata.DBColumn
Type ArgType
DType string
Name string
Val string
Col sdata.DBColumn
}

type OrderBy struct {
Expand Down Expand Up @@ -1564,7 +1564,8 @@ func (co *Compiler) compileArgOrderByObj(sel *Select, parent *graph.Node, cm map
// Check for type
if node.Type != graph.NodeStr &&
node.Type != graph.NodeObj &&
node.Type != graph.NodeList {
node.Type != graph.NodeList &&
node.Type != graph.NodeLabel {
err = fmt.Errorf("expecting a string, object or list")
continue
}
Expand All @@ -1574,7 +1575,7 @@ func (co *Compiler) compileArgOrderByObj(sel *Select, parent *graph.Node, cm map
cn := node

switch node.Type {
case graph.NodeStr:
case graph.NodeStr, graph.NodeLabel:
if ob.Order, err = toOrder(node.Val); err != nil { // sets the asc desc etc
continue
}
Expand Down Expand Up @@ -1694,9 +1695,21 @@ func (co *Compiler) compileArgArgs(sel *Select, arg *graph.Arg) error {
}

for i, n := range node.Children {
a := Arg{Val: n.Val, ValType: fn.Inputs[i].Type}
if n.Type == graph.NodeVar {
var err error
a := Arg{DType: fn.Inputs[i].Type}

switch n.Type {
case graph.NodeLabel:
a.Type = ArgTypeCol
a.Col, err = sel.Ti.GetColumn(n.Val)
case graph.NodeVar:
a.Type = ArgTypeVar
fallthrough
default:
a.Val = n.Val
}
if err != nil {
return err
}
sel.Args = append(sel.Args, a)
}
Expand Down
2 changes: 1 addition & 1 deletion core/query_pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func Example_queryWithFunctionFields() {
products(id: 51) {
id
name
is_hot_product(id: 51)
is_hot_product(id: id)
}
}`

Expand Down

0 comments on commit b971b74

Please sign in to comment.