Skip to content

Commit

Permalink
Fix ACL checks for CTEs (#16642)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <[email protected]>
  • Loading branch information
GuptaManan100 authored Sep 4, 2024
1 parent d276007 commit bf9d064
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 20 deletions.
89 changes: 69 additions & 20 deletions go/vt/vttablet/tabletserver/planbuilder/permission.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ func BuildPermissions(stmt sqlparser.Statement) []Permission {
case *sqlparser.Union:
permissions = buildSubqueryPermissions(node, tableacl.READER, permissions)
case *sqlparser.Insert:
permissions = buildTableExprPermissions(node.Table, tableacl.WRITER, permissions)
permissions = buildTableExprPermissions(node.Table, tableacl.WRITER, nil, permissions)
permissions = buildSubqueryPermissions(node, tableacl.READER, permissions)
case *sqlparser.Update:
permissions = buildTableExprsPermissions(node.TableExprs, tableacl.WRITER, permissions)
permissions = buildTableExprsPermissions(node.TableExprs, tableacl.WRITER, nil, permissions)
permissions = buildSubqueryPermissions(node, tableacl.READER, permissions)
case *sqlparser.Delete:
permissions = buildTableExprsPermissions(node.TableExprs, tableacl.WRITER, permissions)
permissions = buildTableExprsPermissions(node.TableExprs, tableacl.WRITER, nil, permissions)
permissions = buildSubqueryPermissions(node, tableacl.READER, permissions)
case sqlparser.DDLStatement:
for _, t := range node.AffectedTables() {
permissions = buildTableNamePermissions(t, tableacl.ADMIN, permissions)
permissions = buildTableNamePermissions(t, tableacl.ADMIN, nil, permissions)
}
case
*sqlparser.AlterMigration,
Expand All @@ -66,10 +66,10 @@ func BuildPermissions(stmt sqlparser.Statement) []Permission {
permissions = []Permission{} // TODO(shlomi) what are the correct permissions here? Table is unknown
case *sqlparser.Flush:
for _, t := range node.TableNames {
permissions = buildTableNamePermissions(t, tableacl.ADMIN, permissions)
permissions = buildTableNamePermissions(t, tableacl.ADMIN, nil, permissions)
}
case *sqlparser.Analyze:
permissions = buildTableNamePermissions(node.Table, tableacl.WRITER, permissions)
permissions = buildTableNamePermissions(node.Table, tableacl.WRITER, nil, permissions)
case *sqlparser.OtherAdmin, *sqlparser.CallProc, *sqlparser.Begin, *sqlparser.Commit, *sqlparser.Rollback,
*sqlparser.Load, *sqlparser.Savepoint, *sqlparser.Release, *sqlparser.SRollback, *sqlparser.Set, *sqlparser.Show, sqlparser.Explain,
*sqlparser.UnlockTables:
Expand All @@ -81,43 +81,92 @@ func BuildPermissions(stmt sqlparser.Statement) []Permission {
}

func buildSubqueryPermissions(stmt sqlparser.Statement, role tableacl.Role, permissions []Permission) []Permission {
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
if sel, ok := node.(*sqlparser.Select); ok {
permissions = buildTableExprsPermissions(sel.From, role, permissions)
var cteScopes [][]sqlparser.IdentifierCS
sqlparser.Rewrite(stmt, func(cursor *sqlparser.Cursor) bool {
switch node := cursor.Node().(type) {
case *sqlparser.Select:
if node.With != nil {
cteScopes = append(cteScopes, gatherCTEs(node.With))
}
var ctes []sqlparser.IdentifierCS
for _, cteScope := range cteScopes {
ctes = append(ctes, cteScope...)
}
permissions = buildTableExprsPermissions(node.From, role, ctes, permissions)
case *sqlparser.Delete:
if node.With != nil {
cteScopes = append(cteScopes, gatherCTEs(node.With))
}
case *sqlparser.Update:
if node.With != nil {
cteScopes = append(cteScopes, gatherCTEs(node.With))
}
case *sqlparser.Union:
if node.With != nil {
cteScopes = append(cteScopes, gatherCTEs(node.With))
}
}
return true, nil
}, stmt)
return true
}, func(cursor *sqlparser.Cursor) bool {
// When we encounter a With expression coming up, we should remove
// the last value from the cte scopes to ensure we none of the outer
// elements of the query see this table name.
_, isWith := cursor.Node().(*sqlparser.With)
if isWith {
cteScopes = cteScopes[:len(cteScopes)-1]
}
return true
})
return permissions
}

func buildTableExprsPermissions(node []sqlparser.TableExpr, role tableacl.Role, permissions []Permission) []Permission {
// gatherCTEs gathers the CTEs from the WITH clause.
func gatherCTEs(with *sqlparser.With) []sqlparser.IdentifierCS {
var ctes []sqlparser.IdentifierCS
for _, cte := range with.CTEs {
ctes = append(ctes, cte.ID)
}
return ctes
}

func buildTableExprsPermissions(node []sqlparser.TableExpr, role tableacl.Role, ctes []sqlparser.IdentifierCS, permissions []Permission) []Permission {
for _, node := range node {
permissions = buildTableExprPermissions(node, role, permissions)
permissions = buildTableExprPermissions(node, role, ctes, permissions)
}
return permissions
}

func buildTableExprPermissions(node sqlparser.TableExpr, role tableacl.Role, permissions []Permission) []Permission {
func buildTableExprPermissions(node sqlparser.TableExpr, role tableacl.Role, ctes []sqlparser.IdentifierCS, permissions []Permission) []Permission {
switch node := node.(type) {
case *sqlparser.AliasedTableExpr:
// An AliasedTableExpr can also be a derived table, but we should skip them here
// because the buildSubQueryPermissions walker will catch them and extract
// the corresponding table names.
if tblName, ok := node.Expr.(sqlparser.TableName); ok {
permissions = buildTableNamePermissions(tblName, role, permissions)
permissions = buildTableNamePermissions(tblName, role, ctes, permissions)
}
case *sqlparser.ParenTableExpr:
permissions = buildTableExprsPermissions(node.Exprs, role, permissions)
permissions = buildTableExprsPermissions(node.Exprs, role, ctes, permissions)
case *sqlparser.JoinTableExpr:
permissions = buildTableExprPermissions(node.LeftExpr, role, permissions)
permissions = buildTableExprPermissions(node.RightExpr, role, permissions)
permissions = buildTableExprPermissions(node.LeftExpr, role, ctes, permissions)
permissions = buildTableExprPermissions(node.RightExpr, role, ctes, permissions)
}
return permissions
}

func buildTableNamePermissions(node sqlparser.TableName, role tableacl.Role, permissions []Permission) []Permission {
func buildTableNamePermissions(node sqlparser.TableName, role tableacl.Role, ctes []sqlparser.IdentifierCS, permissions []Permission) []Permission {
tableName := node.Name.String()
// Check whether this table is a cte or not.
// If the table name is qualified, then it cannot be a cte.
if node.Qualifier.IsEmpty() {
for _, cte := range ctes {
if cte.String() == tableName {
return permissions
}
}
}
permissions = append(permissions, Permission{
TableName: node.Name.String(),
TableName: tableName,
Role: role,
})
return permissions
Expand Down
39 changes: 39 additions & 0 deletions go/vt/vttablet/tabletserver/planbuilder/permission_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,45 @@ func TestBuildPermissions(t *testing.T) {
TableName: "seq",
Role: tableacl.WRITER,
}},
}, {
input: "with t as (select count(*) as a from user) select a from t",
output: []Permission{{
TableName: "user",
Role: tableacl.READER,
}},
}, {
input: "with d as (select id, count(*) as a from user) select d.a from music join d on music.user_id = d.id group by 1",
output: []Permission{{
TableName: "music",
Role: tableacl.READER,
}, {
TableName: "user",
Role: tableacl.READER,
}},
}, {
input: "WITH t1 AS ( SELECT id FROM t2 ) SELECT * FROM t1 JOIN ks.t1 AS t3",
output: []Permission{{
TableName: "t1",
Role: tableacl.READER,
}, {
TableName: "t2",
Role: tableacl.READER,
}},
}, {
input: "WITH RECURSIVE t1 (n) AS ( SELECT id from t2 UNION ALL SELECT n + 1 FROM t1 WHERE n < 5 ) SELECT * FROM t1 JOIN t1 AS t3",
output: []Permission{{
TableName: "t2",
Role: tableacl.READER,
}},
}, {
input: "(with t1 as (select count(*) as a from user) select a from t1) union select * from t1",
output: []Permission{{
TableName: "user",
Role: tableacl.READER,
}, {
TableName: "t1",
Role: tableacl.READER,
}},
}}

for _, tcase := range tcases {
Expand Down

0 comments on commit bf9d064

Please sign in to comment.