Skip to content

Commit

Permalink
fix: role resoltion not working on subscriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Mar 7, 2021
1 parent efecda9 commit 9a9715c
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 54 deletions.
45 changes: 26 additions & 19 deletions core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,30 @@ func (gj *GraphJin) initCompilers() error {
return nil
}

func (gj *GraphJin) executeRoleQuery(c context.Context, conn *sql.Conn) (string, error) {
var role string
var ar args
var err error

if conn == nil {
if conn, err = gj.db.Conn(c); err != nil {
return role, err
}
defer conn.Close()
}

if c.Value(UserIDKey) == nil {
return "anon", nil
}

if ar, err = gj.roleQueryArgList(c); err != nil {
return "", err
}

err = conn.QueryRowContext(c, gj.roleStmt, ar.values...).Scan(&role)
return role, err
}

func (c *scontext) execQuery(query string, vars []byte, role string) (qres, error) {
res, err := c.resolveSQL(query, vars, role)
if err != nil {
Expand Down Expand Up @@ -188,7 +212,7 @@ func (c *scontext) resolveSQL(query string, vars []byte, role string) (qres, err
res.role = v.(string)

} else if c.gj.abacEnabled {
res.role, err = c.executeRoleQuery(conn)
res.role, err = c.gj.executeRoleQuery(c, conn)
}

if err != nil {
Expand Down Expand Up @@ -245,23 +269,6 @@ func (c *scontext) resolveSQL(query string, vars []byte, role string) (qres, err
return res, nil
}

func (c *scontext) executeRoleQuery(conn *sql.Conn) (string, error) {
var role string
var ar args
var err error

if c.Value(UserIDKey) == nil {
return "anon", nil
}

if ar, err = c.gj.roleQueryArgList(c); err != nil {
return "", err
}

err = conn.QueryRowContext(c, c.gj.roleStmt, ar.values...).Scan(&role)
return role, err
}

func (c *scontext) setLocalUserID(conn *sql.Conn) error {
var err error

Expand Down Expand Up @@ -355,7 +362,7 @@ func (r *Result) SQL() string {
func (c *scontext) debugLog(st *stmt) {
for _, sel := range st.qc.Selects {
if sel.SkipRender == qcode.SkipTypeUserNeeded {
c.gj.log.Printf("Field skipped: %s, Requires $user_id or table not added to anon role", sel.FieldName)
c.gj.log.Printf("Field skipped, requires $user_id or table not added to anon role: %s", sel.FieldName)
}
}
}
13 changes: 4 additions & 9 deletions core/internal/qcode/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,11 @@ func (co *Compiler) getRole(role, schema, table, field string) trval {

// For anon roles when a trval is not found return the default trval
tr, ok := co.tr[k]
if ok {
return tr
}
if role != "anon" {
tr.role = role
} else {
tr = co.c.defTrv
tr.role = role
}
tr.role = role

if !ok && role == "anon" {
return co.c.defTrv
}
return tr
}

Expand Down
13 changes: 9 additions & 4 deletions core/internal/qcode/qcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -1147,13 +1147,18 @@ func compileFilter(s *sdata.DBSchema, ti sdata.DBTable, filter []string, isJSON
// returning a nil 'f' this needs to be fixed

// TODO: Invalid where clauses such as missing op (eg. eq) also fail silently

if fl == nil {
fl = f
} else {
fl = newExpOp(OpAnd)
fl.Children = append(fl.Children, f)
if len(filter) == 1 {
fl = f
continue
} else {
fl = newExpOp(OpAnd)
}
}
fl.Children = append(fl.Children, f)
}

return fl, needsUser, nil
}

Expand Down
52 changes: 30 additions & 22 deletions core/internal/sdata/dwg.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,9 @@ type graphResult struct {
func (s *DBSchema) between(from, to []edgeInfo, through string) (*graphResult, error) {
for _, f := range from {
for _, t := range to {
res, err := s.pickPath(f, t, through)
if err != nil {
if res, err := s.pickPath(f, t, through); err != nil {
return nil, err
}
if res != nil {
} else if res != nil {
return res, nil
}
}
Expand All @@ -265,24 +263,16 @@ func (s *DBSchema) between(from, to []edgeInfo, through string) (*graphResult, e
}

func (s *DBSchema) pickPath(f, t edgeInfo, through string) (*graphResult, error) {
var err error

fn := f.nodeID
tn := t.nodeID
paths := s.rg.AllPaths(fn, tn)

if through != "" {
var npaths [][]int32
v, ok := s.tindex[(s.DBSchema() + ":" + through)]
if !ok {
return nil, ErrThoughNodeNotFound
if paths, err = s.pickThroughPath(paths, through); err != nil {
return nil, err
}
for i := range paths {
for j := range paths[i] {
if paths[i][j] == v.nodeID {
npaths = append(npaths, paths[i])
}
}
}
paths = npaths
}

for _, nodes := range paths {
Expand All @@ -302,7 +292,7 @@ func (s *DBSchema) pickPath(f, t edgeInfo, through string) (*graphResult, error)
fn := nodes[i-1]
tn := nodes[i]
lines := s.rg.GetEdges(fn, tn)
// printLines(lines)
// s.printLines(lines)

switch i {
case 1:
Expand All @@ -329,6 +319,23 @@ func (s *DBSchema) pickPath(f, t edgeInfo, through string) (*graphResult, error)
return nil, nil
}

func (s *DBSchema) pickThroughPath(paths [][]int32, through string) ([][]int32, error) {
var npaths [][]int32
v, ok := s.tindex[(s.DBSchema() + ":" + through)]
if !ok {
return nil, ErrThoughNodeNotFound
}

for i := range paths {
for j := range paths[i] {
if paths[i][j] == v.nodeID {
npaths = append(npaths, paths[i])
}
}
}
return npaths, nil
}

func pickLine(lines []util.Edge, ei edgeInfo) *util.Edge {
for _, v := range lines {
for _, eid := range ei.edgeIDs {
Expand Down Expand Up @@ -361,11 +368,12 @@ func minWeightedLine(lines []util.Edge) *util.Edge {
return line
}

// func printLines(lines []util.Edge) {
// func (s *DBSchema) printLines(lines []util.Edge) {
// for _, v := range lines {
// for lines.Next() {
// e := (lines.WeightedLine()).(TEdge)
// fmt.Printf("- (%d) %d -> %d\n", e.ID(), e.From().ID(), e.To().ID())
// e := s.ae[v.ID]
// f := s.tables[e.From]
// t := s.tables[e.To]
// fmt.Printf("- (%d) %s %d -> %s %d\n", v.ID, f.Name, e.From, t.Name, e.To)
// }
// lines.Reset()
// fmt.Println("---")
// }
6 changes: 6 additions & 0 deletions core/subs.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ func (gj *GraphJin) Subscribe(
role = "anon"
}

if role == "user" && gj.abacEnabled {
if role, err = gj.executeRoleQuery(c, nil); err != nil {
return nil, err
}
}

v, _ := gj.subs.LoadOrStore((name + role), &sub{
name: name,
role: role,
Expand Down

0 comments on commit 9a9715c

Please sign in to comment.