diff --git a/go.mod b/go.mod index db3bae0b0a..a37ac033a6 100644 --- a/go.mod +++ b/go.mod @@ -12,9 +12,9 @@ require ( github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241104143128-c2bb78c109df github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662 - github.com/dolthub/go-mysql-server v0.18.2-0.20241115193357-2d21230229d1 + github.com/dolthub/go-mysql-server v0.18.2-0.20241119011039-4d6202a92c5f github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 - github.com/dolthub/vitess v0.0.0-20241111235433-a20a5ab9d7c9 + github.com/dolthub/vitess v0.0.0-20241119005402-6a198321d993 github.com/fatih/color v1.13.0 github.com/goccy/go-json v0.10.2 github.com/gogo/protobuf v1.3.2 diff --git a/go.sum b/go.sum index 489a3ad96d..662e283599 100644 --- a/go.sum +++ b/go.sum @@ -224,8 +224,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662 h1:aC17hZD6iwzBwwfO5M+3oBT5E5gGRiQPdn+vzpDXqIA= github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168= -github.com/dolthub/go-mysql-server v0.18.2-0.20241115193357-2d21230229d1 h1:FfUUxob0uurW8D8z25GfgEmBwL+dl1zWWkf85iCsnUI= -github.com/dolthub/go-mysql-server v0.18.2-0.20241115193357-2d21230229d1/go.mod h1:sOMQzWUvHvJECzpcUxjDgV5BR/A7U+hOh596PUO2NPI= +github.com/dolthub/go-mysql-server v0.18.2-0.20241119011039-4d6202a92c5f h1:gWnRFJyo3fuXXO80uTH+/2n+qc+0TwofvwgVQ4e49gU= +github.com/dolthub/go-mysql-server v0.18.2-0.20241119011039-4d6202a92c5f/go.mod h1:uPKS0kU0pd1l/9RVVFe4i+/cqqxxGuhnYZZzE9xwc2U= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= @@ -238,8 +238,8 @@ github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 h1:JWkKRE4 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA= github.com/dolthub/swiss v0.1.0 h1:EaGQct3AqeP/MjASHLiH6i4TAmgbG/c4rA6a1bzCOPc= github.com/dolthub/swiss v0.1.0/go.mod h1:BeucyB08Vb1G9tumVN3Vp/pyY4AMUnr9p7Rz7wJ7kAQ= -github.com/dolthub/vitess v0.0.0-20241111235433-a20a5ab9d7c9 h1:s36zDuLPuZRWC0nBCJs2Z8joP19eKEtcsIsuE8K9Kx0= -github.com/dolthub/vitess v0.0.0-20241111235433-a20a5ab9d7c9/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= +github.com/dolthub/vitess v0.0.0-20241119005402-6a198321d993 h1:MhD6jHjshx2djyUq/uZxtCyHBYAnE3WshhJDUaO9fD8= +github.com/dolthub/vitess v0.0.0-20241119005402-6a198321d993/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= diff --git a/server/ast/aliased_table_expr.go b/server/ast/aliased_table_expr.go index a12af43f2c..25373c733b 100644 --- a/server/ast/aliased_table_expr.go +++ b/server/ast/aliased_table_expr.go @@ -43,8 +43,8 @@ func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (*vitess.Al aliasExpr = tableName authInfo = vitess.AuthInformation{ AuthType: ctx.Auth().PeekAuthType(), - TargetType: auth.AuthTargetType_SingleTableIdentifier, - TargetNames: []string{tableName.SchemaQualifier.String(), tableName.Name.String()}, + TargetType: auth.AuthTargetType_TableIdentifiers, + TargetNames: []string{tableName.DbQualifier.String(), tableName.SchemaQualifier.String(), tableName.Name.String()}, } case *tree.Subquery: tableExpr, err := nodeTableExpr(ctx, expr) diff --git a/server/ast/create_schema.go b/server/ast/create_schema.go index ba982f5064..f2622adc7a 100644 --- a/server/ast/create_schema.go +++ b/server/ast/create_schema.go @@ -17,6 +17,8 @@ package ast import ( vitess "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/dolthub/doltgresql/server/auth" + "github.com/dolthub/doltgresql/postgres/parser/sem/tree" ) @@ -25,13 +27,16 @@ func nodeCreateSchema(ctx *Context, node *tree.CreateSchema) (vitess.Statement, if node == nil { return nil, nil } - return &vitess.DBDDL{ Action: "CREATE", SchemaOrDatabase: "schema", DBName: node.Schema, IfNotExists: node.IfNotExists, CharsetCollate: nil, // TODO - // TODO: AuthRole + Auth: vitess.AuthInformation{ + AuthType: auth.AuthType_CREATE, + TargetType: auth.AuthTargetType_DatabaseIdentifiers, + TargetNames: []string{""}, + }, }, nil } diff --git a/server/ast/create_table.go b/server/ast/create_table.go index 0f52614b9e..b1f9261b46 100644 --- a/server/ast/create_table.go +++ b/server/ast/create_table.go @@ -20,6 +20,7 @@ import ( vitess "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" + "github.com/dolthub/doltgresql/server/auth" ) // nodeCreateTable handles *tree.CreateTable nodes. @@ -87,6 +88,11 @@ func nodeCreateTable(ctx *Context, node *tree.CreateTable) (*vitess.DDL, error) Temporary: isTemporary, OptSelect: optSelect, OptLike: optLike, + Auth: vitess.AuthInformation{ + AuthType: auth.AuthType_CREATE, + TargetType: auth.AuthTargetType_SchemaIdentifiers, + TargetNames: []string{tableName.DbQualifier.String(), tableName.SchemaQualifier.String()}, + }, } if err = assignTableDefs(ctx, node.Defs, ddl); err != nil { return nil, err diff --git a/server/ast/drop_table.go b/server/ast/drop_table.go index ad815959e6..5e52c28fbd 100644 --- a/server/ast/drop_table.go +++ b/server/ast/drop_table.go @@ -20,6 +20,7 @@ import ( vitess "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" + "github.com/dolthub/doltgresql/server/auth" ) // nodeDropTable handles *tree.DropTable nodes. @@ -36,16 +37,24 @@ func nodeDropTable(ctx *Context, node *tree.DropTable) (*vitess.DDL, error) { return nil, fmt.Errorf("CASCADE is not yet supported") } tableNames := make([]vitess.TableName, len(node.Names)) + authTableNames := make([]string, 0, len(node.Names)*3) for i := range node.Names { var err error tableNames[i], err = nodeTableName(ctx, &node.Names[i]) if err != nil { return nil, err } + authTableNames = append(authTableNames, + tableNames[i].DbQualifier.String(), tableNames[i].SchemaQualifier.String(), tableNames[i].Name.String()) } return &vitess.DDL{ Action: vitess.DropStr, FromTables: tableNames, IfExists: node.IfExists, + Auth: vitess.AuthInformation{ + AuthType: auth.AuthType_DROPTABLE, + TargetType: auth.AuthTargetType_Ignore, + TargetNames: authTableNames, + }, }, nil } diff --git a/server/ast/grant.go b/server/ast/grant.go index c42f243509..b4d6c405ab 100644 --- a/server/ast/grant.go +++ b/server/ast/grant.go @@ -32,10 +32,12 @@ func nodeGrant(ctx *Context, node *tree.Grant) (vitess.Statement, error) { return nil, nil } var grantTable *pgnodes.GrantTable + var grantSchema *pgnodes.GrantSchema + var grantDatabase *pgnodes.GrantDatabase switch node.Targets.TargetType { case privilege.Table: - tables := make([]doltdb.TableName, len(node.Targets.Tables)) - for i, table := range node.Targets.Tables { + tables := make([]doltdb.TableName, 0, len(node.Targets.Tables)+len(node.Targets.InSchema)) + for _, table := range node.Targets.Tables { normalizedTable, err := table.NormalizeTablePattern() if err != nil { return nil, err @@ -45,24 +47,50 @@ func nodeGrant(ctx *Context, node *tree.Grant) (vitess.Statement, error) { if normalizedTable.ExplicitCatalog { return nil, fmt.Errorf("granting privileges to other databases is not yet supported") } - tables[i] = doltdb.TableName{ + tables = append(tables, doltdb.TableName{ Name: string(normalizedTable.ObjectName), Schema: string(normalizedTable.SchemaName), - } + }) case *tree.AllTablesSelector: - return nil, fmt.Errorf("selecting all tables in a schema is not yet supported") + tables = append(tables, doltdb.TableName{ + Name: "", + Schema: string(normalizedTable.SchemaName), + }) default: return nil, fmt.Errorf(`unexpected table type in GRANT: %T`, normalizedTable) } } + for _, schema := range node.Targets.InSchema { + tables = append(tables, doltdb.TableName{ + Name: "", + Schema: schema, + }) + } privileges, err := convertPrivilegeKinds(auth.PrivilegeObject_TABLE, node.Privileges) if err != nil { return nil, err } grantTable = &pgnodes.GrantTable{ - Privileges: privileges, - Tables: tables, - AllTablesInSchemas: nil, + Privileges: privileges, + Tables: tables, + } + case privilege.Schema: + privileges, err := convertPrivilegeKinds(auth.PrivilegeObject_SCHEMA, node.Privileges) + if err != nil { + return nil, err + } + grantSchema = &pgnodes.GrantSchema{ + Privileges: privileges, + Schemas: node.Targets.Names, + } + case privilege.Database: + privileges, err := convertPrivilegeKinds(auth.PrivilegeObject_DATABASE, node.Privileges) + if err != nil { + return nil, err + } + grantDatabase = &pgnodes.GrantDatabase{ + Privileges: privileges, + Databases: node.Targets.Databases.ToStrings(), } default: return nil, fmt.Errorf("this form of GRANT is not yet supported") @@ -70,6 +98,9 @@ func nodeGrant(ctx *Context, node *tree.Grant) (vitess.Statement, error) { return vitess.InjectedStatement{ Statement: &pgnodes.Grant{ GrantTable: grantTable, + GrantSchema: grantSchema, + GrantDatabase: grantDatabase, + GrantRole: nil, ToRoles: node.Grantees, WithGrantOption: node.WithGrantOption, GrantedBy: node.GrantedBy, diff --git a/server/ast/grant_role.go b/server/ast/grant_role.go index f0c1b62eb9..9c7e5523ab 100644 --- a/server/ast/grant_role.go +++ b/server/ast/grant_role.go @@ -15,7 +15,7 @@ package ast import ( - "fmt" + pgnodes "github.com/dolthub/doltgresql/server/node" vitess "github.com/dolthub/vitess/go/vt/sqlparser" @@ -27,5 +27,15 @@ func nodeGrantRole(ctx *Context, node *tree.GrantRole) (vitess.Statement, error) if node == nil { return nil, nil } - return nil, fmt.Errorf("GRANT ROLE is not yet supported") + return vitess.InjectedStatement{ + Statement: &pgnodes.Grant{ + GrantRole: &pgnodes.GrantRole{ + Groups: node.Roles.ToStrings(), + }, + ToRoles: node.Members, + WithGrantOption: len(node.WithOption) > 0, + GrantedBy: node.GrantedBy, + }, + Children: nil, + }, nil } diff --git a/server/ast/insert.go b/server/ast/insert.go index 3b87ae7d9d..6bf8407d54 100644 --- a/server/ast/insert.go +++ b/server/ast/insert.go @@ -105,8 +105,8 @@ func nodeInsert(ctx *Context, node *tree.Insert) (*vitess.Insert, error) { OnDup: onDuplicate, Auth: vitess.AuthInformation{ AuthType: auth.AuthType_INSERT, - TargetType: auth.AuthTargetType_SingleTableIdentifier, - TargetNames: []string{tableName.SchemaQualifier.String(), tableName.Name.String()}, + TargetType: auth.AuthTargetType_TableIdentifiers, + TargetNames: []string{tableName.DbQualifier.String(), tableName.SchemaQualifier.String(), tableName.Name.String()}, }, }, nil } diff --git a/server/ast/revoke.go b/server/ast/revoke.go index 84c5c47214..1bf0adc52b 100644 --- a/server/ast/revoke.go +++ b/server/ast/revoke.go @@ -32,9 +32,11 @@ func nodeRevoke(ctx *Context, node *tree.Revoke) (vitess.Statement, error) { return nil, nil } var revokeTable *pgnodes.RevokeTable + var revokeSchema *pgnodes.RevokeSchema + var revokeDatabase *pgnodes.RevokeDatabase switch node.Targets.TargetType { case privilege.Table: - tables := make([]doltdb.TableName, len(node.Targets.Tables)) + tables := make([]doltdb.TableName, len(node.Targets.Tables)+len(node.Targets.InSchema)) for i, table := range node.Targets.Tables { normalizedTable, err := table.NormalizeTablePattern() if err != nil { @@ -50,19 +52,45 @@ func nodeRevoke(ctx *Context, node *tree.Revoke) (vitess.Statement, error) { Schema: string(normalizedTable.SchemaName), } case *tree.AllTablesSelector: - return nil, fmt.Errorf("selecting all tables in a schema is not yet supported") + tables[i] = doltdb.TableName{ + Name: "", + Schema: string(normalizedTable.SchemaName), + } default: return nil, fmt.Errorf(`unexpected table type in REVOKE: %T`, normalizedTable) } } + for _, schema := range node.Targets.InSchema { + tables = append(tables, doltdb.TableName{ + Name: "", + Schema: schema, + }) + } privileges, err := convertPrivilegeKinds(auth.PrivilegeObject_TABLE, node.Privileges) if err != nil { return nil, err } revokeTable = &pgnodes.RevokeTable{ - Privileges: privileges, - Tables: tables, - AllTablesInSchemas: nil, + Privileges: privileges, + Tables: tables, + } + case privilege.Schema: + privileges, err := convertPrivilegeKinds(auth.PrivilegeObject_SCHEMA, node.Privileges) + if err != nil { + return nil, err + } + revokeSchema = &pgnodes.RevokeSchema{ + Privileges: privileges, + Schemas: node.Targets.Names, + } + case privilege.Database: + privileges, err := convertPrivilegeKinds(auth.PrivilegeObject_DATABASE, node.Privileges) + if err != nil { + return nil, err + } + revokeDatabase = &pgnodes.RevokeDatabase{ + Privileges: privileges, + Databases: node.Targets.Databases.ToStrings(), } default: return nil, fmt.Errorf("this form of REVOKE is not yet supported") @@ -70,6 +98,9 @@ func nodeRevoke(ctx *Context, node *tree.Revoke) (vitess.Statement, error) { return vitess.InjectedStatement{ Statement: &pgnodes.Revoke{ RevokeTable: revokeTable, + RevokeSchema: revokeSchema, + RevokeDatabase: revokeDatabase, + RevokeRole: nil, FromRoles: node.Grantees, GrantedBy: node.GrantedBy, GrantOptionFor: node.GrantOptionFor, diff --git a/server/ast/revoke_role.go b/server/ast/revoke_role.go index 94ca24eb6d..62b0c5f48d 100644 --- a/server/ast/revoke_role.go +++ b/server/ast/revoke_role.go @@ -15,7 +15,7 @@ package ast import ( - "fmt" + pgnodes "github.com/dolthub/doltgresql/server/node" vitess "github.com/dolthub/vitess/go/vt/sqlparser" @@ -27,5 +27,16 @@ func nodeRevokeRole(ctx *Context, node *tree.RevokeRole) (vitess.Statement, erro if node == nil { return nil, nil } - return nil, fmt.Errorf("REVOKE ROLE is not yet supported") + return vitess.InjectedStatement{ + Statement: &pgnodes.Revoke{ + RevokeRole: &pgnodes.RevokeRole{ + Groups: node.Roles.ToStrings(), + }, + FromRoles: node.Members, + GrantedBy: node.GrantedBy, + GrantOptionFor: len(node.Option) > 0, + Cascade: node.DropBehavior == tree.DropCascade, + }, + Children: nil, + }, nil } diff --git a/server/ast/table_expr.go b/server/ast/table_expr.go index ddffe0dd96..f73fdb6bd8 100644 --- a/server/ast/table_expr.go +++ b/server/ast/table_expr.go @@ -125,8 +125,8 @@ func nodeTableExpr(ctx *Context, node tree.TableExpr) (vitess.TableExpr, error) Expr: tableName, Auth: vitess.AuthInformation{ AuthType: ctx.Auth().PeekAuthType(), - TargetType: auth.AuthTargetType_SingleTableIdentifier, - TargetNames: []string{tableName.SchemaQualifier.String(), tableName.Name.String()}, + TargetType: auth.AuthTargetType_TableIdentifiers, + TargetNames: []string{tableName.DbQualifier.String(), tableName.SchemaQualifier.String(), tableName.Name.String()}, }, }, nil case *tree.TableRef: @@ -140,8 +140,8 @@ func nodeTableExpr(ctx *Context, node tree.TableExpr) (vitess.TableExpr, error) Expr: tableName, Auth: vitess.AuthInformation{ AuthType: ctx.Auth().PeekAuthType(), - TargetType: auth.AuthTargetType_SingleTableIdentifier, - TargetNames: []string{tableName.SchemaQualifier.String(), tableName.Name.String()}, + TargetType: auth.AuthTargetType_TableIdentifiers, + TargetNames: []string{tableName.DbQualifier.String(), tableName.SchemaQualifier.String(), tableName.Name.String()}, }, }, nil default: diff --git a/server/ast/truncate.go b/server/ast/truncate.go index 0101328291..1961457b0e 100644 --- a/server/ast/truncate.go +++ b/server/ast/truncate.go @@ -48,8 +48,8 @@ func nodeTruncate(ctx *Context, node *tree.Truncate) (*vitess.DDL, error) { Table: tableName, Auth: vitess.AuthInformation{ AuthType: auth.AuthType_TRUNCATE, - TargetType: auth.AuthTargetType_SingleTableIdentifier, - TargetNames: []string{tableName.SchemaQualifier.String(), tableName.Name.String()}, + TargetType: auth.AuthTargetType_TableIdentifiers, + TargetNames: []string{tableName.DbQualifier.String(), tableName.SchemaQualifier.String(), tableName.Name.String()}, }, }, nil } diff --git a/server/ast/with.go b/server/ast/with.go index 8b4551d756..9dc4edd6ee 100644 --- a/server/ast/with.go +++ b/server/ast/with.go @@ -64,7 +64,7 @@ func nodeWith(ctx *Context, node *tree.With) (*vitess.With, error) { return nil, nil } - ctes := make([]vitess.TableExpr, len(node.CTEList)) + ctes := make([]*vitess.CommonTableExpr, len(node.CTEList)) for i, cte := range node.CTEList { var err error ctes[i], err = nodeCTE(ctx, cte) diff --git a/server/auth/auth_handler.go b/server/auth/auth_handler.go index a8f364069e..ae3f0c4d8b 100644 --- a/server/auth/auth_handler.go +++ b/server/auth/auth_handler.go @@ -17,6 +17,7 @@ package auth import ( "errors" "fmt" + "strings" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/go-mysql-server/sql" @@ -100,8 +101,24 @@ func (h *AuthorizationHandler) HandleAuth(ctx *sql.Context, aqs sql.Authorizatio case AuthType_IGNORE: // This means that authorization is being handled elsewhere (such as a child or parent), and should be ignored here return nil + case AuthType_CREATE: + privileges = []Privilege{Privilege_CREATE} case AuthType_DELETE: privileges = []Privilege{Privilege_DELETE} + case AuthType_DROPTABLE: + if len(auth.TargetNames)%3 != 0 { + return fmt.Errorf("table identifiers has an unsupported count: %d", len(auth.TargetNames)) + } + for i := 0; i < len(auth.TargetNames); i += 3 { + // TODO: handle database + if id := HasOwnerAccess(OwnershipKey{ + PrivilegeObject: PrivilegeObject_TABLE, + Schema: auth.TargetNames[i+1], + Name: auth.TargetNames[i+2], + }, state.role.ID()); !id.IsValid() { + return fmt.Errorf("permission denied for table %s", auth.TargetNames[i+2]) + } + } case AuthType_INSERT: privileges = []Privilege{Privilege_INSERT} case AuthType_SELECT: @@ -122,28 +139,73 @@ func (h *AuthorizationHandler) HandleAuth(ctx *sql.Context, aqs sql.Authorizatio switch auth.TargetType { case AuthTargetType_Ignore: // This means that the AuthType did not need a TargetType, so we can safely ignore it - case AuthTargetType_SingleTableIdentifier: - schemaName, err := core.GetSchemaName(ctx, nil, auth.TargetNames[0]) - if err != nil { - return sql.ErrTableNotFound.New(auth.TargetNames[1]) + case AuthTargetType_DatabaseIdentifiers: + for _, database := range auth.TargetNames { + database = h.dbName(ctx, database) + roleDatabaseKey := DatabasePrivilegeKey{ + Role: state.role.ID(), + Name: database, + } + publicDatabaseKey := DatabasePrivilegeKey{ + Role: state.public.ID(), + Name: database, + } + for _, privilege := range privileges { + if !HasDatabasePrivilege(roleDatabaseKey, privilege) && !HasDatabasePrivilege(publicDatabaseKey, privilege) { + return fmt.Errorf("permission denied for database %s", database) + } + } } - ownerKey := OwnershipKey{ - PrivilegeObject: PrivilegeObject_TABLE, - Schema: schemaName, - Name: auth.TargetNames[1], + case AuthTargetType_SchemaIdentifiers: + if len(auth.TargetNames)%2 != 0 { + return fmt.Errorf("schema identifiers has an unsupported count: %d", len(auth.TargetNames)) } - roleTableKey := TablePrivilegeKey{ - Role: state.role.ID(), - Table: doltdb.TableName{Name: auth.TargetNames[1], Schema: schemaName}, + for i := 0; i < len(auth.TargetNames); i += 2 { + // TODO: handle database + schemaName, err := core.GetSchemaName(ctx, nil, auth.TargetNames[i+1]) + if err != nil { + // If this fails, then there's an issue with the search path. + // This will error later in the process, so we'll pass auth for now. + return nil + } + roleSchemaKey := SchemaPrivilegeKey{ + Role: state.role.ID(), + Schema: schemaName, + } + publicSchemaKey := SchemaPrivilegeKey{ + Role: state.public.ID(), + Schema: schemaName, + } + for _, privilege := range privileges { + if !HasSchemaPrivilege(roleSchemaKey, privilege) && !HasSchemaPrivilege(publicSchemaKey, privilege) { + return fmt.Errorf("permission denied for schema %s", schemaName) + } + } } - publicTableKey := TablePrivilegeKey{ - Role: state.public.ID(), - Table: doltdb.TableName{Name: auth.TargetNames[1], Schema: schemaName}, + case AuthTargetType_TableIdentifiers: + if len(auth.TargetNames)%3 != 0 { + return fmt.Errorf("table identifiers has an unsupported count: %d", len(auth.TargetNames)) } - for _, privilege := range privileges { - if !state.role.IsSuperUser && !IsOwner(ownerKey, state.role.ID()) && - !HasTablePrivilege(roleTableKey, privilege) && !HasTablePrivilege(publicTableKey, privilege) { - return fmt.Errorf("permission denied for table %s", auth.TargetNames[1]) + for i := 0; i < len(auth.TargetNames); i += 3 { + // TODO: handle database + schemaName, err := core.GetSchemaName(ctx, nil, auth.TargetNames[i+1]) + if err != nil { + // If this fails, then there's an issue with the search path. + // This will error later in the process, so we'll pass auth for now. + return nil + } + roleTableKey := TablePrivilegeKey{ + Role: state.role.ID(), + Table: doltdb.TableName{Name: auth.TargetNames[i+2], Schema: schemaName}, + } + publicTableKey := TablePrivilegeKey{ + Role: state.public.ID(), + Table: doltdb.TableName{Name: auth.TargetNames[i+2], Schema: schemaName}, + } + for _, privilege := range privileges { + if !HasTablePrivilege(roleTableKey, privilege) && !HasTablePrivilege(publicTableKey, privilege) { + return fmt.Errorf("permission denied for table %s", auth.TargetNames[i+2]) + } } } case AuthTargetType_TODO: @@ -209,3 +271,14 @@ func (h *AuthorizationHandler) CheckTable(ctx *sql.Context, aqs sql.Authorizatio // TODO: implement this return nil } + +// dbName uses the current database from the context if a database is not specified, otherwise it returns the given +// database name. +func (h *AuthorizationHandler) dbName(ctx *sql.Context, dbName string) string { + if len(dbName) == 0 { + dbName = ctx.GetCurrentDatabase() + } + // Revision databases take the form "dbname/revision", so we must split the revision from the database name + splitDbName := strings.SplitN(dbName, "/", 2) + return splitDbName[0] +} diff --git a/server/auth/auth_information.go b/server/auth/auth_information.go index 927f40f9ef..c7a30cc1da 100644 --- a/server/auth/auth_information.go +++ b/server/auth/auth_information.go @@ -21,6 +21,7 @@ const ( AuthType_CONNECT = "CONNECT" AuthType_CREATE = "CREATE" AuthType_DELETE = "DELETE" + AuthType_DROPTABLE = "DROPTABLE" AuthType_EXECUTE = "EXECUTE" AuthType_INSERT = "INSERT" AuthType_REFERENCES = "REFERENCES" @@ -35,11 +36,9 @@ const ( // These AuthTargetType_ enums are used as the TargetType in vitess.AuthInformation. const ( - AuthTargetType_Ignore = "IGNORE" - AuthTargetType_DatabaseIdentifiers = "DB_IDENTS" - AuthTargetType_Global = "GLOBAL" - AuthTargetType_MultipleTableIdentifiers = "DB_TABLE_IDENTS" - AuthTargetType_SingleTableIdentifier = "DB_TABLE_IDENT" - AuthTargetType_TableColumn = "DB_TABLE_COLUMN_IDENT" - AuthTargetType_TODO = "TODO" + AuthTargetType_Ignore = "IGNORE" + AuthTargetType_DatabaseIdentifiers = "DB_IDENTS" + AuthTargetType_SchemaIdentifiers = "DB_SCH_IDENTS" + AuthTargetType_TableIdentifiers = "DB_SCH_TABLE_IDENTS" + AuthTargetType_TODO = "TODO" ) diff --git a/server/auth/database.go b/server/auth/database.go index 3122bb8aaa..858edbfa03 100644 --- a/server/auth/database.go +++ b/server/auth/database.go @@ -36,16 +36,24 @@ var ( // Database contains all information pertaining to authorization and privileges. This is a global structure that is // shared between all branches. type Database struct { - rolesByName map[string]RoleID - rolesByID map[RoleID]Role - ownership *Ownership - tablePrivileges *TablePrivileges + rolesByName map[string]RoleID + rolesByID map[RoleID]Role + ownership *Ownership + databasePrivileges *DatabasePrivileges + schemaPrivileges *SchemaPrivileges + tablePrivileges *TablePrivileges + roleMembership *RoleMembership } // ClearDatabase clears the internal database, leaving only the default users. This is primarily for use by tests. func ClearDatabase() { clear(globalDatabase.rolesByName) clear(globalDatabase.rolesByID) + clear(globalDatabase.ownership.Data) + clear(globalDatabase.databasePrivileges.Data) + clear(globalDatabase.schemaPrivileges.Data) + clear(globalDatabase.tablePrivileges.Data) + clear(globalDatabase.roleMembership.Data) dbInitDefault() } @@ -54,7 +62,7 @@ func DropRole(name string) { if roleID, ok := globalDatabase.rolesByName[name]; ok { delete(globalDatabase.rolesByName, name) delete(globalDatabase.rolesByID, roleID) - + // TODO: remove from ownership, schema privileges, table privileges, and role membership } } @@ -99,6 +107,11 @@ func SetRole(role Role) { globalDatabase.rolesByID[role.ID()] = role } +// IsSuperUser returns whether the given role is a SUPERUSER. +func IsSuperUser(role RoleID) bool { + return globalDatabase.rolesByID[role].IsSuperUser +} + // LockRead takes an anonymous function and runs it while using a read lock. This ensures that the lock is automatically // released once the function finishes. func LockRead(f func()) { @@ -119,10 +132,13 @@ func LockWrite(f func()) { // terribly wrong. func dbInit(dEnv *env.DoltEnv) { globalDatabase = Database{ - rolesByName: make(map[string]RoleID), - rolesByID: make(map[RoleID]Role), - ownership: NewOwnership(), - tablePrivileges: NewTablePrivileges(), + rolesByName: make(map[string]RoleID), + rolesByID: make(map[RoleID]Role), + ownership: NewOwnership(), + databasePrivileges: NewDatabasePrivileges(), + schemaPrivileges: NewSchemaPrivileges(), + tablePrivileges: NewTablePrivileges(), + roleMembership: NewRoleMembership(), } globalLock = &sync.RWMutex{} if dEnv != nil { diff --git a/server/auth/database_privileges.go b/server/auth/database_privileges.go new file mode 100644 index 0000000000..1352d2df62 --- /dev/null +++ b/server/auth/database_privileges.go @@ -0,0 +1,218 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "github.com/dolthub/doltgresql/utils" +) + +// DatabasePrivileges contains the privileges given to a role on a database. +type DatabasePrivileges struct { + Data map[DatabasePrivilegeKey]DatabasePrivilegeValue +} + +// DatabasePrivilegeKey points to a specific database object. +type DatabasePrivilegeKey struct { + Role RoleID + Name string +} + +// DatabasePrivilegeValue is the value associated with the DatabasePrivilegeKey. +type DatabasePrivilegeValue struct { + Key DatabasePrivilegeKey + Privileges map[Privilege]map[GrantedPrivilege]bool +} + +// NewDatabasePrivileges returns a new *DatabasePrivileges. +func NewDatabasePrivileges() *DatabasePrivileges { + return &DatabasePrivileges{make(map[DatabasePrivilegeKey]DatabasePrivilegeValue)} +} + +// AddDatabasePrivilege adds the given database privilege to the global database. +func AddDatabasePrivilege(key DatabasePrivilegeKey, privilege GrantedPrivilege, withGrantOption bool) { + databasePrivilegeValue, ok := globalDatabase.databasePrivileges.Data[key] + if !ok { + databasePrivilegeValue = DatabasePrivilegeValue{ + Key: key, + Privileges: make(map[Privilege]map[GrantedPrivilege]bool), + } + globalDatabase.databasePrivileges.Data[key] = databasePrivilegeValue + } + privilegeMap, ok := databasePrivilegeValue.Privileges[privilege.Privilege] + if !ok { + privilegeMap = make(map[GrantedPrivilege]bool) + databasePrivilegeValue.Privileges[privilege.Privilege] = privilegeMap + } + privilegeMap[privilege] = withGrantOption +} + +// HasDatabasePrivilege checks whether the user has the given privilege on the associated database. +func HasDatabasePrivilege(key DatabasePrivilegeKey, privilege Privilege) bool { + if IsSuperUser(key.Role) || IsOwner(OwnershipKey{ + PrivilegeObject: PrivilegeObject_DATABASE, + Name: key.Name, + }, key.Role) { + return true + } + if databasePrivilegeValue, ok := globalDatabase.databasePrivileges.Data[key]; ok { + if privilegeMap, ok := databasePrivilegeValue.Privileges[privilege]; ok && len(privilegeMap) > 0 { + return true + } + } + for _, group := range GetAllGroupsWithMember(key.Role, true) { + if HasDatabasePrivilege(DatabasePrivilegeKey{ + Role: group, + Name: key.Name, + }, privilege) { + return true + } + } + return false +} + +// HasDatabasePrivilegeGrantOption checks whether the user has WITH GRANT OPTION for the given privilege on the associated +// database. Returns the role that has WITH GRANT OPTION, or an invalid role if WITH GRANT OPTION is not available. +func HasDatabasePrivilegeGrantOption(key DatabasePrivilegeKey, privilege Privilege) RoleID { + ownershipKey := OwnershipKey{ + PrivilegeObject: PrivilegeObject_DATABASE, + Name: key.Name, + } + if IsSuperUser(key.Role) { + owners := GetOwners(ownershipKey) + if len(owners) == 0 { + // This may happen if the privilege file is deleted + return key.Role + } + // Although there may be multiple owners, we'll only return the first one. + // Postgres already allows for non-determinism with multiple membership paths, so this is fine. + return owners[0] + } else if IsOwner(ownershipKey, key.Role) { + return key.Role + } + if databasePrivilegeValue, ok := globalDatabase.databasePrivileges.Data[key]; ok { + if privilegeMap, ok := databasePrivilegeValue.Privileges[privilege]; ok { + for _, withGrantOption := range privilegeMap { + if withGrantOption { + return key.Role + } + } + } + } + for _, group := range GetAllGroupsWithMember(key.Role, true) { + if returnedID := HasDatabasePrivilegeGrantOption(DatabasePrivilegeKey{ + Role: group, + Name: key.Name, + }, privilege); returnedID.IsValid() { + return returnedID + } + } + return 0 +} + +// RemoveDatabasePrivilege removes the privilege from the global database. If `grantOptionOnly` is true, then only the WITH +// GRANT OPTION portion is revoked. If `grantOptionOnly` is false, then the full privilege is removed. If the GrantedBy +// field contains a valid RoleID, then only the privilege associated with that granter is removed. Otherwise, the +// privilege is completely removed for the grantee. +func RemoveDatabasePrivilege(key DatabasePrivilegeKey, privilege GrantedPrivilege, grantOptionOnly bool) { + if databasePrivilegeValue, ok := globalDatabase.databasePrivileges.Data[key]; ok { + if privilegeMap, ok := databasePrivilegeValue.Privileges[privilege.Privilege]; ok { + if grantOptionOnly { + // This is provided when we only want to revoke the WITH GRANT OPTION, and not the privilege itself. + // If a role is provided in GRANTED BY, then we specifically delete the option associated with that role. + // If no role was given, then we'll remove WITH GRANT OPTION from all of the associated roles. + if privilege.GrantedBy.IsValid() { + if _, ok = privilegeMap[privilege]; ok { + privilegeMap[privilege] = false + } + } else { + for privilegeMapKey := range privilegeMap { + privilegeMap[privilegeMapKey] = false + } + } + } else { + // If a role is provided in GRANTED BY, then we specifically delete the privilege associated with that role. + // If no role was given, then we'll delete the privileges granted by all roles. + if privilege.GrantedBy.IsValid() { + delete(privilegeMap, privilege) + } else { + privilegeMap = nil + } + if len(privilegeMap) == 0 { + delete(databasePrivilegeValue.Privileges, privilege.Privilege) + } + } + } + if len(databasePrivilegeValue.Privileges) == 0 { + delete(globalDatabase.databasePrivileges.Data, key) + } + } +} + +// serialize writes the DatabasePrivileges to the given writer. +func (sp *DatabasePrivileges) serialize(writer *utils.Writer) { + // Version 0 + // Write the total number of values + writer.Uint64(uint64(len(sp.Data))) + for _, value := range sp.Data { + // Write the key + writer.Uint64(uint64(value.Key.Role)) + writer.String(value.Key.Name) + // Write the total number of privileges + writer.Uint64(uint64(len(value.Privileges))) + for privilege, privilegeMap := range value.Privileges { + writer.String(string(privilege)) + // Write the number of granted privileges + writer.Uint32(uint32(len(privilegeMap))) + for grantedPrivilege, withGrantOption := range privilegeMap { + writer.Uint64(uint64(grantedPrivilege.GrantedBy)) + writer.Bool(withGrantOption) + } + } + } +} + +// deserialize reads the DatabasePrivileges from the given reader. +func (sp *DatabasePrivileges) deserialize(version uint32, reader *utils.Reader) { + sp.Data = make(map[DatabasePrivilegeKey]DatabasePrivilegeValue) + switch version { + case 0: + // Read the total number of values + dataCount := reader.Uint64() + for dataIdx := uint64(0); dataIdx < dataCount; dataIdx++ { + // Read the key + spv := DatabasePrivilegeValue{Privileges: make(map[Privilege]map[GrantedPrivilege]bool)} + spv.Key.Role = RoleID(reader.Uint64()) + spv.Key.Name = reader.String() + // Read the total number of privileges + privilegeCount := reader.Uint64() + for privilegeIdx := uint64(0); privilegeIdx < privilegeCount; privilegeIdx++ { + privilege := Privilege(reader.String()) + // Read the number of granted privileges + grantedCount := reader.Uint32() + grantedMap := make(map[GrantedPrivilege]bool) + for grantedIdx := uint32(0); grantedIdx < grantedCount; grantedIdx++ { + grantedPrivilege := GrantedPrivilege{} + grantedPrivilege.Privilege = privilege + grantedPrivilege.GrantedBy = RoleID(reader.Uint64()) + grantedMap[grantedPrivilege] = reader.Bool() + } + spv.Privileges[privilege] = grantedMap + } + sp.Data[spv.Key] = spv + } + default: + panic("unexpected version in DatabasePrivileges") + } +} diff --git a/server/auth/ownership.go b/server/auth/ownership.go index 6267640432..3aa2d1b359 100644 --- a/server/auth/ownership.go +++ b/server/auth/ownership.go @@ -37,6 +37,7 @@ func NewOwnership() *Ownership { // AddOwner adds the given role as an owner to the global database. func AddOwner(key OwnershipKey, role RoleID) { + key.normalize() ownerMap, ok := globalDatabase.ownership.Data[key] if !ok { ownerMap = make(map[RoleID]struct{}) @@ -47,14 +48,16 @@ func AddOwner(key OwnershipKey, role RoleID) { // GetOwners returns all owners matching the given key. func GetOwners(key OwnershipKey) []RoleID { + key.normalize() if ownerMap, ok := globalDatabase.ownership.Data[key]; ok { return utils.GetMapKeysSorted(ownerMap) } return nil } -// IsOwner returns whether the given owner has an entry for the key. +// IsOwner returns whether the given role is an owner for the key. func IsOwner(key OwnershipKey, role RoleID) bool { + key.normalize() if ownerMap, ok := globalDatabase.ownership.Data[key]; ok { _, ok = ownerMap[role] return ok @@ -62,8 +65,33 @@ func IsOwner(key OwnershipKey, role RoleID) bool { return false } +// HasOwnerAccess returns whether the given role has access to the ownership of an object, along with the ID of the true +// owner (which may be the same as the given role). +func HasOwnerAccess(key OwnershipKey, role RoleID) RoleID { + if IsSuperUser(role) { + owners := GetOwners(key) + if len(owners) == 0 { + // This may happen if the privilege file is deleted + return role + } + // Although there may be multiple owners, we'll only return the first one. + // Postgres already allows for non-determinism with multiple membership paths, so this is fine. + return owners[0] + } + if IsOwner(key, role) { + return role + } + for _, group := range GetAllGroupsWithMember(role, true) { + if returnedID := HasOwnerAccess(key, group); returnedID.IsValid() { + return returnedID + } + } + return 0 +} + // RemoveOwner removes the role as an owner from the global database. func RemoveOwner(key OwnershipKey, role RoleID) { + key.normalize() if ownerMap, ok := globalDatabase.ownership.Data[key]; ok { delete(ownerMap, role) if len(ownerMap) == 0 { @@ -72,6 +100,17 @@ func RemoveOwner(key OwnershipKey, role RoleID) { } } +// normalize accounts for and corrects any potential variation for specific object types. +func (key *OwnershipKey) normalize() { + if key.PrivilegeObject == PrivilegeObject_SCHEMA { + if len(key.Schema) == 0 { + key.Schema = key.Name + } else if len(key.Name) == 0 { + key.Name = key.Schema + } + } +} + // serialize writes the Ownership to the given writer. func (ownership *Ownership) serialize(writer *utils.Writer) { // Version 0 diff --git a/server/auth/role_membership.go b/server/auth/role_membership.go new file mode 100644 index 0000000000..e4fd627607 --- /dev/null +++ b/server/auth/role_membership.go @@ -0,0 +1,167 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import "github.com/dolthub/doltgresql/utils" + +// RoleMembership contains all roles that have been granted to other roles. +type RoleMembership struct { + Data map[RoleID]map[RoleID]RoleMembershipValue +} + +// RoleMembershipValue contains specific membership information between two roles. +type RoleMembershipValue struct { + Member RoleID + Group RoleID + WithAdminOption bool + GrantedBy RoleID +} + +// NewRoleMembership returns a new *RoleMembership. +func NewRoleMembership() *RoleMembership { + return &RoleMembership{ + Data: make(map[RoleID]map[RoleID]RoleMembershipValue), + } +} + +// AddMemberToGroup adds the member role to the group role. +func AddMemberToGroup(member RoleID, group RoleID, withAdminOption bool, grantedBy RoleID) { + // We'll perform a sanity check for circular membership. This should be done before this call is made, but since we + // make assumptions that circular relationships are forbidden (which could lead to infinite loops otherwise), we + // enforce it here too. + if groupID, _, _ := IsRoleAMember(group, member); (groupID.IsValid() || member == group) && !globalDatabase.rolesByID[group].IsSuperUser { + panic("missing validation to prevent circular role relationships") + } + groupMap, ok := globalDatabase.roleMembership.Data[member] + if !ok { + groupMap = make(map[RoleID]RoleMembershipValue) + globalDatabase.roleMembership.Data[member] = groupMap + } + groupMap[group] = RoleMembershipValue{ + Member: member, + Group: group, + WithAdminOption: withAdminOption, + GrantedBy: grantedBy, + } +} + +// IsRoleAMember returns whether the given role is a member of the group by returning the group's ID. Also returns +// whether the member was granted WITH ADMIN OPTION, allowing it to grant membership to the group to other roles. A +// member does not automatically have ADMIN OPTION on itself, therefore this check must be performed. +func IsRoleAMember(member RoleID, group RoleID) (groupID RoleID, inheritsPrivileges bool, hasWithAdminOption bool) { + // If the member and group are the same, then we only check for SUPERUSER status to allow WITH ADMIN OPTION + if member == group { + return group, true, globalDatabase.rolesByID[member].IsSuperUser + } + // Postgres does not allow for circular role membership, so we can recursively check without worry: + // https://www.postgresql.org/docs/15/catalog-pg-auth-members.html + if groupMap, ok := globalDatabase.roleMembership.Data[member]; ok { + for _, value := range groupMap { + if value.Group == group { + return group, globalDatabase.rolesByID[member].InheritPrivileges, value.WithAdminOption + } + // This recursively walks through memberships + if groupID, _, hasWithAdminOption = IsRoleAMember(value.Group, group); groupID.IsValid() { + return groupID, globalDatabase.rolesByID[member].InheritPrivileges, hasWithAdminOption + } + } + } + // A SUPERUSER has access to everything, and therefore functions as though it's a member of every group + if globalDatabase.rolesByID[member].IsSuperUser { + return group, true, true + } + return 0, false, false +} + +// GetAllGroupsWithMember returns every group that the role is a direct member of. This can also filter by groups that +// the member has privilege access on. +func GetAllGroupsWithMember(member RoleID, inheritsPrivilegesOnly bool) []RoleID { + memberRole, ok := globalDatabase.rolesByID[member] + if !ok || !memberRole.InheritPrivileges { + return nil + } + groupMap := globalDatabase.roleMembership.Data[member] + groups := make([]RoleID, 0, len(groupMap)) + for groupID := range groupMap { + groups = append(groups, groupID) + } + return groups +} + +// RemoveMemberFromGroup removes the member from the group. If `adminOptionOnly` is true, then only the WITH ADMIN +// OPTION portion is revoked. If `adminOptionOnly` is false, then the member is fully is removed. +func RemoveMemberFromGroup(member RoleID, group RoleID, adminOptionOnly bool) { + if groupMap, ok := globalDatabase.roleMembership.Data[member]; ok { + if adminOptionOnly { + value := groupMap[group] + value.WithAdminOption = false + groupMap[group] = value + } else { + delete(groupMap, group) + } + if len(groupMap) == 0 { + delete(globalDatabase.roleMembership.Data, member) + } + } +} + +// serialize writes the RoleMembership to the given writer. +func (membership *RoleMembership) serialize(writer *utils.Writer) { + // Version 0 + // Write the total number of members + writer.Uint64(uint64(len(membership.Data))) + for _, groupMap := range membership.Data { + // Write the number of groups + writer.Uint64(uint64(len(groupMap))) + for _, mapValue := range groupMap { + // Write the membership information + writer.Uint64(uint64(mapValue.Member)) + writer.Uint64(uint64(mapValue.Group)) + writer.Bool(mapValue.WithAdminOption) + writer.Uint64(uint64(mapValue.GrantedBy)) + } + } +} + +// deserialize reads the RoleMembership from the given reader. +func (membership *RoleMembership) deserialize(version uint32, reader *utils.Reader) { + membership.Data = make(map[RoleID]map[RoleID]RoleMembershipValue) + switch version { + case 0: + // Read the total number of members + memberCount := reader.Uint64() + for memberIdx := uint64(0); memberIdx < memberCount; memberIdx++ { + // Read the number of groups + groupCount := reader.Uint64() + groupMap := make(map[RoleID]RoleMembershipValue) + var member RoleID + for groupIdx := uint64(0); groupIdx < groupCount; groupIdx++ { + // Read the membership information + value := RoleMembershipValue{} + value.Member = RoleID(reader.Uint64()) + value.Group = RoleID(reader.Uint64()) + value.WithAdminOption = reader.Bool() + value.GrantedBy = RoleID(reader.Uint64()) + // Add the information to the map + groupMap[value.Group] = value + member = value.Member + } + // Add the group map to the data + membership.Data[member] = groupMap + } + default: + panic("unexpected version in RoleMembership") + } +} diff --git a/server/auth/schema_privileges.go b/server/auth/schema_privileges.go new file mode 100644 index 0000000000..40a7f501c2 --- /dev/null +++ b/server/auth/schema_privileges.go @@ -0,0 +1,218 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "github.com/dolthub/doltgresql/utils" +) + +// SchemaPrivileges contains the privileges given to a role on a schema. +type SchemaPrivileges struct { + Data map[SchemaPrivilegeKey]SchemaPrivilegeValue +} + +// SchemaPrivilegeKey points to a specific schema object. +type SchemaPrivilegeKey struct { + Role RoleID + Schema string +} + +// SchemaPrivilegeValue is the value associated with the SchemaPrivilegeKey. +type SchemaPrivilegeValue struct { + Key SchemaPrivilegeKey + Privileges map[Privilege]map[GrantedPrivilege]bool +} + +// NewSchemaPrivileges returns a new *SchemaPrivileges. +func NewSchemaPrivileges() *SchemaPrivileges { + return &SchemaPrivileges{make(map[SchemaPrivilegeKey]SchemaPrivilegeValue)} +} + +// AddSchemaPrivilege adds the given schema privilege to the global database. +func AddSchemaPrivilege(key SchemaPrivilegeKey, privilege GrantedPrivilege, withGrantOption bool) { + schemaPrivilegeValue, ok := globalDatabase.schemaPrivileges.Data[key] + if !ok { + schemaPrivilegeValue = SchemaPrivilegeValue{ + Key: key, + Privileges: make(map[Privilege]map[GrantedPrivilege]bool), + } + globalDatabase.schemaPrivileges.Data[key] = schemaPrivilegeValue + } + privilegeMap, ok := schemaPrivilegeValue.Privileges[privilege.Privilege] + if !ok { + privilegeMap = make(map[GrantedPrivilege]bool) + schemaPrivilegeValue.Privileges[privilege.Privilege] = privilegeMap + } + privilegeMap[privilege] = withGrantOption +} + +// HasSchemaPrivilege checks whether the user has the given privilege on the associated schema. +func HasSchemaPrivilege(key SchemaPrivilegeKey, privilege Privilege) bool { + if IsSuperUser(key.Role) || IsOwner(OwnershipKey{ + PrivilegeObject: PrivilegeObject_SCHEMA, + Schema: key.Schema, + }, key.Role) { + return true + } + if schemaPrivilegeValue, ok := globalDatabase.schemaPrivileges.Data[key]; ok { + if privilegeMap, ok := schemaPrivilegeValue.Privileges[privilege]; ok && len(privilegeMap) > 0 { + return true + } + } + for _, group := range GetAllGroupsWithMember(key.Role, true) { + if HasSchemaPrivilege(SchemaPrivilegeKey{ + Role: group, + Schema: key.Schema, + }, privilege) { + return true + } + } + return false +} + +// HasSchemaPrivilegeGrantOption checks whether the user has WITH GRANT OPTION for the given privilege on the associated +// schema. Returns the role that has WITH GRANT OPTION, or an invalid role if WITH GRANT OPTION is not available. +func HasSchemaPrivilegeGrantOption(key SchemaPrivilegeKey, privilege Privilege) RoleID { + ownershipKey := OwnershipKey{ + PrivilegeObject: PrivilegeObject_SCHEMA, + Schema: key.Schema, + } + if IsSuperUser(key.Role) { + owners := GetOwners(ownershipKey) + if len(owners) == 0 { + // This may happen if the privilege file is deleted + return key.Role + } + // Although there may be multiple owners, we'll only return the first one. + // Postgres already allows for non-determinism with multiple membership paths, so this is fine. + return owners[0] + } else if IsOwner(ownershipKey, key.Role) { + return key.Role + } + if schemaPrivilegeValue, ok := globalDatabase.schemaPrivileges.Data[key]; ok { + if privilegeMap, ok := schemaPrivilegeValue.Privileges[privilege]; ok { + for _, withGrantOption := range privilegeMap { + if withGrantOption { + return key.Role + } + } + } + } + for _, group := range GetAllGroupsWithMember(key.Role, true) { + if returnedID := HasSchemaPrivilegeGrantOption(SchemaPrivilegeKey{ + Role: group, + Schema: key.Schema, + }, privilege); returnedID.IsValid() { + return returnedID + } + } + return 0 +} + +// RemoveSchemaPrivilege removes the privilege from the global database. If `grantOptionOnly` is true, then only the WITH +// GRANT OPTION portion is revoked. If `grantOptionOnly` is false, then the full privilege is removed. If the GrantedBy +// field contains a valid RoleID, then only the privilege associated with that granter is removed. Otherwise, the +// privilege is completely removed for the grantee. +func RemoveSchemaPrivilege(key SchemaPrivilegeKey, privilege GrantedPrivilege, grantOptionOnly bool) { + if schemaPrivilegeValue, ok := globalDatabase.schemaPrivileges.Data[key]; ok { + if privilegeMap, ok := schemaPrivilegeValue.Privileges[privilege.Privilege]; ok { + if grantOptionOnly { + // This is provided when we only want to revoke the WITH GRANT OPTION, and not the privilege itself. + // If a role is provided in GRANTED BY, then we specifically delete the option associated with that role. + // If no role was given, then we'll remove WITH GRANT OPTION from all of the associated roles. + if privilege.GrantedBy.IsValid() { + if _, ok = privilegeMap[privilege]; ok { + privilegeMap[privilege] = false + } + } else { + for privilegeMapKey := range privilegeMap { + privilegeMap[privilegeMapKey] = false + } + } + } else { + // If a role is provided in GRANTED BY, then we specifically delete the privilege associated with that role. + // If no role was given, then we'll delete the privileges granted by all roles. + if privilege.GrantedBy.IsValid() { + delete(privilegeMap, privilege) + } else { + privilegeMap = nil + } + if len(privilegeMap) == 0 { + delete(schemaPrivilegeValue.Privileges, privilege.Privilege) + } + } + } + if len(schemaPrivilegeValue.Privileges) == 0 { + delete(globalDatabase.schemaPrivileges.Data, key) + } + } +} + +// serialize writes the SchemaPrivileges to the given writer. +func (sp *SchemaPrivileges) serialize(writer *utils.Writer) { + // Version 0 + // Write the total number of values + writer.Uint64(uint64(len(sp.Data))) + for _, value := range sp.Data { + // Write the key + writer.Uint64(uint64(value.Key.Role)) + writer.String(value.Key.Schema) + // Write the total number of privileges + writer.Uint64(uint64(len(value.Privileges))) + for privilege, privilegeMap := range value.Privileges { + writer.String(string(privilege)) + // Write the number of granted privileges + writer.Uint32(uint32(len(privilegeMap))) + for grantedPrivilege, withGrantOption := range privilegeMap { + writer.Uint64(uint64(grantedPrivilege.GrantedBy)) + writer.Bool(withGrantOption) + } + } + } +} + +// deserialize reads the SchemaPrivileges from the given reader. +func (sp *SchemaPrivileges) deserialize(version uint32, reader *utils.Reader) { + sp.Data = make(map[SchemaPrivilegeKey]SchemaPrivilegeValue) + switch version { + case 0: + // Read the total number of values + dataCount := reader.Uint64() + for dataIdx := uint64(0); dataIdx < dataCount; dataIdx++ { + // Read the key + spv := SchemaPrivilegeValue{Privileges: make(map[Privilege]map[GrantedPrivilege]bool)} + spv.Key.Role = RoleID(reader.Uint64()) + spv.Key.Schema = reader.String() + // Read the total number of privileges + privilegeCount := reader.Uint64() + for privilegeIdx := uint64(0); privilegeIdx < privilegeCount; privilegeIdx++ { + privilege := Privilege(reader.String()) + // Read the number of granted privileges + grantedCount := reader.Uint32() + grantedMap := make(map[GrantedPrivilege]bool) + for grantedIdx := uint32(0); grantedIdx < grantedCount; grantedIdx++ { + grantedPrivilege := GrantedPrivilege{} + grantedPrivilege.Privilege = privilege + grantedPrivilege.GrantedBy = RoleID(reader.Uint64()) + grantedMap[grantedPrivilege] = reader.Bool() + } + spv.Privileges[privilege] = grantedMap + } + sp.Data[spv.Key] = spv + } + default: + panic("unexpected version in SchemaPrivileges") + } +} diff --git a/server/auth/serialization.go b/server/auth/serialization.go index 2d86724fce..f00dc62e1f 100644 --- a/server/auth/serialization.go +++ b/server/auth/serialization.go @@ -42,8 +42,14 @@ func (db *Database) serialize() []byte { } // Write the ownership db.ownership.serialize(writer) + // Write the database privileges + db.databasePrivileges.serialize(writer) + // Write the schema privileges + db.schemaPrivileges.serialize(writer) // Write the table privileges db.tablePrivileges.serialize(writer) + // Write the role chain + db.roleMembership.serialize(writer) return writer.Data() } @@ -76,7 +82,13 @@ func (db *Database) deserializeV0(reader *utils.Reader) error { } // Read the ownership db.ownership.deserialize(0, reader) + // Read the database privileges + db.databasePrivileges.deserialize(0, reader) + // Read the schema privileges + db.schemaPrivileges.deserialize(0, reader) // Read the table privileges db.tablePrivileges.deserialize(0, reader) + // Read the role chain + db.roleMembership.deserialize(0, reader) return nil } diff --git a/server/auth/table_privileges.go b/server/auth/table_privileges.go index 63a4fd9378..224adb124f 100644 --- a/server/auth/table_privileges.go +++ b/server/auth/table_privileges.go @@ -62,27 +62,87 @@ func AddTablePrivilege(key TablePrivilegeKey, privilege GrantedPrivilege, withGr // HasTablePrivilege checks whether the user has the given privilege on the associated table. func HasTablePrivilege(key TablePrivilegeKey, privilege Privilege) bool { + if IsSuperUser(key.Role) || IsOwner(OwnershipKey{ + PrivilegeObject: PrivilegeObject_TABLE, + Schema: key.Table.Schema, + Name: key.Table.Name, + }, key.Role) { + return true + } + // If a table name was provided, then we also want to search for privileges provided to all tables in the schema + // space. Since those are saved with an empty table name, we can easily do another search by removing the table. + if len(key.Table.Name) > 0 { + if ok := HasTablePrivilege(TablePrivilegeKey{ + Role: key.Role, + Table: doltdb.TableName{Name: "", Schema: key.Table.Schema}, + }, privilege); ok { + return true + } + } if tablePrivilegeValue, ok := globalDatabase.tablePrivileges.Data[key]; ok { - if privilegeMap, ok := tablePrivilegeValue.Privileges[privilege]; ok { - return len(privilegeMap) > 0 + if privilegeMap, ok := tablePrivilegeValue.Privileges[privilege]; ok && len(privilegeMap) > 0 { + return true + } + } + for _, group := range GetAllGroupsWithMember(key.Role, true) { + if HasTablePrivilege(TablePrivilegeKey{ + Role: group, + Table: key.Table, + }, privilege) { + return true } } return false } // HasTablePrivilegeGrantOption checks whether the user has WITH GRANT OPTION for the given privilege on the associated -// table. -func HasTablePrivilegeGrantOption(key TablePrivilegeKey, privilege Privilege) bool { +// table. Returns the role that has WITH GRANT OPTION, or an invalid role if WITH GRANT OPTION is not available. +func HasTablePrivilegeGrantOption(key TablePrivilegeKey, privilege Privilege) RoleID { + ownershipKey := OwnershipKey{ + PrivilegeObject: PrivilegeObject_TABLE, + Schema: key.Table.Schema, + Name: key.Table.Name, + } + if IsSuperUser(key.Role) { + owners := GetOwners(ownershipKey) + if len(owners) == 0 { + // This may happen if the privilege file is deleted + return key.Role + } + // Although there may be multiple owners, we'll only return the first one. + // Postgres already allows for non-determinism with multiple membership paths, so this is fine. + return owners[0] + } else if IsOwner(ownershipKey, key.Role) { + return key.Role + } + // If a table name was provided, then we also want to search for privileges provided to all tables in the schema + // space. Since those are saved with an empty table name, we can easily do another search by removing the table. + if len(key.Table.Name) > 0 { + if returnedID := HasTablePrivilegeGrantOption(TablePrivilegeKey{ + Role: key.Role, + Table: doltdb.TableName{Name: "", Schema: key.Table.Schema}, + }, privilege); returnedID.IsValid() { + return returnedID + } + } if tablePrivilegeValue, ok := globalDatabase.tablePrivileges.Data[key]; ok { if privilegeMap, ok := tablePrivilegeValue.Privileges[privilege]; ok { for _, withGrantOption := range privilegeMap { if withGrantOption { - return true + return key.Role } } } } - return false + for _, group := range GetAllGroupsWithMember(key.Role, true) { + if returnedID := HasTablePrivilegeGrantOption(TablePrivilegeKey{ + Role: group, + Table: key.Table, + }, privilege); returnedID.IsValid() { + return returnedID + } + } + return 0 } // RemoveTablePrivilege removes the privilege from the global database. If `grantOptionOnly` is true, then only the WITH diff --git a/server/node/alter_role.go b/server/node/alter_role.go index cce0ec279b..318e109527 100644 --- a/server/node/alter_role.go +++ b/server/node/alter_role.go @@ -56,22 +56,35 @@ func (c *AlterRole) Resolved() bool { // RowIter implements the interface sql.ExecSourceRel. func (c *AlterRole) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { + var userRole auth.Role var role auth.Role - var err error auth.LockRead(func() { - if !auth.RoleExists(c.Name) { - err = fmt.Errorf(`role "%s" does not exist`, c.Name) - } else { - role = auth.GetRole(c.Name) - } + userRole = auth.GetRole(ctx.Client().User) + role = auth.GetRole(c.Name) }) - if err != nil { - return nil, err + if !userRole.IsValid() { + return nil, fmt.Errorf(`role "%s" does not exist`, userRole.Name) + } + if !role.IsValid() { + return nil, fmt.Errorf(`role "%s" does not exist`, c.Name) } + if role.IsSuperUser && !userRole.IsSuperUser { + // Only superusers can modify other superusers + // TODO: grab the actual error message + return nil, fmt.Errorf(`role "%s" does not have permission to alter role "%s"`, userRole.Name, role.Name) + } else if !userRole.IsSuperUser && !userRole.CanCreateRoles && role.ID() != userRole.ID() { + // A role may only modify itself if it doesn't have the ability to create roles + // TODO: allow non-role-creating roles to only modify their own password, and grab actual error message + return nil, fmt.Errorf(`role "%s" does not have permission to alter role "%s"`, userRole.Name, role.Name) + } for optionName, optionValue := range c.Options { switch optionName { case "BYPASSRLS": + if !userRole.IsSuperUser { + // TODO: grab the actual error message + return nil, fmt.Errorf(`role "%s" does not have permission to alter role "%s"`, userRole.Name, role.Name) + } role.CanBypassRowLevelSecurity = true case "CONNECTION_LIMIT": role.ConnectionLimit = optionValue.(int32) @@ -84,6 +97,10 @@ func (c *AlterRole) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { case "LOGIN": role.CanLogin = true case "NOBYPASSRLS": + if !userRole.IsSuperUser { + // TODO: grab the actual error message + return nil, fmt.Errorf(`role "%s" does not have permission to alter role "%s"`, userRole.Name, role.Name) + } role.CanBypassRowLevelSecurity = false case "NOCREATEDB": role.CanCreateDB = false @@ -94,8 +111,16 @@ func (c *AlterRole) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { case "NOLOGIN": role.CanLogin = false case "NOREPLICATION": + if !userRole.IsSuperUser { + // TODO: grab the actual error message + return nil, fmt.Errorf(`role "%s" does not have permission to alter role "%s"`, userRole.Name, role.Name) + } role.IsReplicationRole = false case "NOSUPERUSER": + if !userRole.IsSuperUser { + // TODO: grab the actual error message + return nil, fmt.Errorf(`role "%s" does not have permission to alter role "%s"`, userRole.Name, role.Name) + } role.IsSuperUser = false case "PASSWORD": password, _ := optionValue.(*string) @@ -109,8 +134,16 @@ func (c *AlterRole) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { } } case "REPLICATION": + if !userRole.IsSuperUser { + // TODO: grab the actual error message + return nil, fmt.Errorf(`role "%s" does not have permission to alter role "%s"`, userRole.Name, role.Name) + } role.IsReplicationRole = true case "SUPERUSER": + if !userRole.IsSuperUser { + // TODO: grab the actual error message + return nil, fmt.Errorf(`role "%s" does not have permission to alter role "%s"`, userRole.Name, role.Name) + } role.IsSuperUser = true case "VALID_UNTIL": timeString, _ := optionValue.(*string) @@ -128,6 +161,7 @@ func (c *AlterRole) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { return nil, fmt.Errorf(`unknown role option "%s"`, optionName) } } + var err error auth.LockWrite(func() { auth.SetRole(role) err = auth.PersistChanges() diff --git a/server/node/create_role.go b/server/node/create_role.go index 64b7662737..c66eaad453 100644 --- a/server/node/create_role.go +++ b/server/node/create_role.go @@ -68,10 +68,15 @@ func (c *CreateRole) Resolved() bool { // RowIter implements the interface sql.ExecSourceRel. func (c *CreateRole) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { + var userRole auth.Role var roleExists bool auth.LockRead(func() { roleExists = auth.RoleExists(c.Name) + userRole = auth.GetRole(ctx.Client().User) }) + if !userRole.IsValid() { + return nil, fmt.Errorf(`role "%s" does not exist`, ctx.Client().User) + } if roleExists { if c.IfNotExists { return sql.RowsToRowIter(), nil @@ -79,6 +84,10 @@ func (c *CreateRole) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { return nil, fmt.Errorf(`role "%s" already exists`, c.Name) } + if !userRole.IsSuperUser && (!userRole.CanCreateRoles || c.IsSuperUser) { + // TODO: grab the actual error message + return nil, fmt.Errorf(`role "%s" does not have permission to create the role`, userRole.Name) + } var role auth.Role auth.LockWrite(func() { role = auth.CreateDefaultRole(c.Name) diff --git a/server/node/drop_role.go b/server/node/drop_role.go index 3c9fa9bb8e..56e33775d9 100644 --- a/server/node/drop_role.go +++ b/server/node/drop_role.go @@ -52,13 +52,24 @@ func (c *DropRole) Resolved() bool { func (c *DropRole) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { // TODO: disallow dropping the role if it owns anything // First we'll loop over all of the names to check that they all exist + var userRole auth.Role + var roles []auth.Role var err error auth.LockRead(func() { + userRole = auth.GetRole(ctx.Client().User) for _, roleName := range c.Names { - if !auth.RoleExists(roleName) && !c.IfExists { + role := auth.GetRole(roleName) + if role.IsValid() { + roles = append(roles, role) + } else if !c.IfExists { err = fmt.Errorf(`role "%s" does not exist`, roleName) break } + if !userRole.IsSuperUser && (role.IsSuperUser || !userRole.CanCreateRoles) { + // TODO: grab the actual error message + err = fmt.Errorf(`role "%s" does not have permission to drop role "%s"`, userRole.Name, role.Name) + break + } } }) if err != nil { @@ -66,8 +77,8 @@ func (c *DropRole) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { } // Then we'll loop again, dropping all of the users auth.LockWrite(func() { - for _, roleName := range c.Names { - auth.DropRole(roleName) + for _, role := range roles { + auth.DropRole(role.Name) } err = auth.PersistChanges() }) diff --git a/server/node/grant.go b/server/node/grant.go index 47ae5480da..69a97b5a35 100644 --- a/server/node/grant.go +++ b/server/node/grant.go @@ -30,26 +30,40 @@ import ( // Grant handles all of the GRANT statements. type Grant struct { GrantTable *GrantTable + GrantSchema *GrantSchema + GrantDatabase *GrantDatabase + GrantRole *GrantRole ToRoles []string - WithGrantOption bool // Does not apply to the GRANT TO statement + WithGrantOption bool // This is "WITH ADMIN OPTION" for GrantRole only GrantedBy string } // GrantTable specifically handles the GRANT ... ON TABLE statement. type GrantTable struct { - Privileges []auth.Privilege - Tables []doltdb.TableName - AllTablesInSchemas []string + Privileges []auth.Privilege + Tables []doltdb.TableName } -var _ sql.ExecSourceRel = (*Grant)(nil) -var _ vitess.Injectable = (*Grant)(nil) +// GrantSchema specifically handles the GRANT ... ON SCHEMA statement. +type GrantSchema struct { + Privileges []auth.Privilege + Schemas []string +} -// CheckPrivileges implements the interface sql.ExecSourceRel. -func (g *Grant) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true +// GrantDatabase specifically handles the GRANT ... ON DATABASE statement. +type GrantDatabase struct { + Privileges []auth.Privilege + Databases []string } +// GrantRole specifically handles the GRANT TO statement. +type GrantRole struct { + Groups []string +} + +var _ sql.ExecSourceRel = (*Grant)(nil) +var _ vitess.Injectable = (*Grant)(nil) + // Children implements the interface sql.ExecSourceRel. func (g *Grant) Children() []sql.Node { return nil @@ -71,71 +85,20 @@ func (g *Grant) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { auth.LockWrite(func() { switch { case g.GrantTable != nil: - if len(g.GrantTable.AllTablesInSchemas) > 0 { - err = fmt.Errorf("granting privileges to all tables in the schema is not yet supported") + if err = g.grantTable(ctx); err != nil { return } - roles := make([]auth.Role, len(g.ToRoles)) - // First we'll verify that all of the roles exist - for i, roleName := range g.ToRoles { - roles[i] = auth.GetRole(roleName) - if !roles[i].IsValid() { - err = fmt.Errorf(`role "%s" does not exist`, roleName) - return - } - } - // Then we'll check that the role that is granting the privileges exists - userRole := auth.GetRole(ctx.Client().User) - if !userRole.IsValid() { - err = fmt.Errorf(`role "%s" does not exist`, ctx.Client().User) + case g.GrantSchema != nil: + if err = g.grantSchema(ctx); err != nil { return } - var grantedByID auth.RoleID - if len(g.GrantedBy) != 0 { - // TODO: check the role chain to see if this session's user can assume this role - grantedByRole := auth.GetRole(g.GrantedBy) - if !grantedByRole.IsValid() { - err = fmt.Errorf(`role "%s" does not exist`, g.GrantedBy) - return - } - grantedByID = grantedByRole.ID() - // TODO: check if owners may arbitrarily set the GRANTED BY - if !userRole.IsSuperUser { - err = errors.New("REVOKE currently only allows superusers to set GRANTED BY") - return - } - } else { - grantedByID = userRole.ID() + case g.GrantDatabase != nil: + if err = g.grantDatabase(ctx); err != nil { + return } - // Next we'll assign all of the privileges to each role - for _, role := range roles { - for _, table := range g.GrantTable.Tables { - var schemaName string - schemaName, err = core.GetSchemaName(ctx, nil, table.Schema) - if err != nil { - return - } - key := auth.TablePrivilegeKey{ - Role: role.ID(), - Table: doltdb.TableName{Name: table.Name, Schema: schemaName}, - } - isOwner := auth.IsOwner(auth.OwnershipKey{ - PrivilegeObject: auth.PrivilegeObject_TABLE, - Schema: schemaName, - Name: table.Name, - }, userRole.ID()) - for _, privilege := range g.GrantTable.Privileges { - if !userRole.IsSuperUser && !isOwner && !auth.HasTablePrivilegeGrantOption(key, privilege) { - // TODO: grab the actual error message - err = fmt.Errorf(`role "%s" does not have permission to grant this privilege`, userRole.Name) - return - } - auth.AddTablePrivilege(key, auth.GrantedPrivilege{ - Privilege: privilege, - GrantedBy: grantedByID, - }, g.WithGrantOption) - } - } + case g.GrantRole != nil: + if err = g.grantRole(ctx); err != nil { + return } default: err = fmt.Errorf("GRANT statement is not yet supported") @@ -176,3 +139,154 @@ func (g *Grant) WithResolvedChildren(children []any) (any, error) { } return g, nil } + +// common handles the initial logic for each GRANT statement. +func (g *Grant) common(ctx *sql.Context) ([]auth.Role, auth.Role, error) { + roles := make([]auth.Role, len(g.ToRoles)) + // First we'll verify that all of the roles exist + for i, roleName := range g.ToRoles { + roles[i] = auth.GetRole(roleName) + if !roles[i].IsValid() { + return nil, auth.Role{}, fmt.Errorf(`role "%s" does not exist`, roleName) + } + } + // Then we'll check that the role that is granting the privileges exists + userRole := auth.GetRole(ctx.Client().User) + if !userRole.IsValid() { + return nil, auth.Role{}, fmt.Errorf(`role "%s" does not exist`, ctx.Client().User) + } + if len(g.GrantedBy) != 0 { + grantedByRole := auth.GetRole(g.GrantedBy) + if !grantedByRole.IsValid() { + return nil, auth.Role{}, fmt.Errorf(`role "%s" does not exist`, g.GrantedBy) + } + if userRole.ID() != grantedByRole.ID() { + // TODO: grab the actual error message + return nil, auth.Role{}, errors.New("GRANTED BY may only be set to the calling user") + } + } + return roles, userRole, nil +} + +// grantTable handles *GrantTable from within RowIter. +func (g *Grant) grantTable(ctx *sql.Context) error { + roles, userRole, err := g.common(ctx) + if err != nil { + return err + } + for _, role := range roles { + for _, table := range g.GrantTable.Tables { + schemaName, err := core.GetSchemaName(ctx, nil, table.Schema) + if err != nil { + return err + } + key := auth.TablePrivilegeKey{ + Role: userRole.ID(), + Table: doltdb.TableName{Name: table.Name, Schema: schemaName}, + } + for _, privilege := range g.GrantTable.Privileges { + grantedBy := auth.HasTablePrivilegeGrantOption(key, privilege) + if !grantedBy.IsValid() { + // TODO: grab the actual error message + return fmt.Errorf(`role "%s" does not have permission to grant this privilege`, userRole.Name) + } + auth.AddTablePrivilege(auth.TablePrivilegeKey{ + Role: role.ID(), + Table: doltdb.TableName{Name: table.Name, Schema: schemaName}, + }, auth.GrantedPrivilege{ + Privilege: privilege, + GrantedBy: grantedBy, + }, g.WithGrantOption) + } + } + } + return nil +} + +// grantSchema handles *GrantSchema from within RowIter. +func (g *Grant) grantSchema(ctx *sql.Context) error { + roles, userRole, err := g.common(ctx) + if err != nil { + return err + } + for _, role := range roles { + for _, schema := range g.GrantSchema.Schemas { + key := auth.SchemaPrivilegeKey{ + Role: userRole.ID(), + Schema: schema, + } + for _, privilege := range g.GrantSchema.Privileges { + grantedBy := auth.HasSchemaPrivilegeGrantOption(key, privilege) + if !grantedBy.IsValid() { + // TODO: grab the actual error message + return fmt.Errorf(`role "%s" does not have permission to grant this privilege`, userRole.Name) + } + auth.AddSchemaPrivilege(auth.SchemaPrivilegeKey{ + Role: role.ID(), + Schema: schema, + }, auth.GrantedPrivilege{ + Privilege: privilege, + GrantedBy: grantedBy, + }, g.WithGrantOption) + } + } + } + return nil +} + +// grantDatabase handles *GrantDatabase from within RowIter. +func (g *Grant) grantDatabase(ctx *sql.Context) error { + roles, userRole, err := g.common(ctx) + if err != nil { + return err + } + for _, role := range roles { + for _, database := range g.GrantDatabase.Databases { + key := auth.DatabasePrivilegeKey{ + Role: userRole.ID(), + Name: database, + } + for _, privilege := range g.GrantDatabase.Privileges { + grantedBy := auth.HasDatabasePrivilegeGrantOption(key, privilege) + if !grantedBy.IsValid() { + // TODO: grab the actual error message + return fmt.Errorf(`role "%s" does not have permission to grant this privilege`, userRole.Name) + } + auth.AddDatabasePrivilege(auth.DatabasePrivilegeKey{ + Role: role.ID(), + Name: database, + }, auth.GrantedPrivilege{ + Privilege: privilege, + GrantedBy: grantedBy, + }, g.WithGrantOption) + } + } + } + return nil +} + +// grantRole handles *GrantRole from within RowIter. +func (g *Grant) grantRole(ctx *sql.Context) error { + members, userRole, err := g.common(ctx) + if err != nil { + return err + } + groups := make([]auth.Role, len(g.GrantRole.Groups)) + for i, groupName := range g.GrantRole.Groups { + groups[i] = auth.GetRole(groupName) + if !groups[i].IsValid() { + return fmt.Errorf(`role "%s" does not exist`, groupName) + } + } + for _, member := range members { + for _, group := range groups { + memberGroupID, _, withAdminOption := auth.IsRoleAMember(userRole.ID(), group.ID()) + if !memberGroupID.IsValid() || !withAdminOption { + // TODO: grab the actual error message + return fmt.Errorf(`role "%s" does not have permission to grant role "%s"`, userRole.Name, group.Name) + } + auth.AddMemberToGroup(member.ID(), group.ID(), g.WithGrantOption, memberGroupID) + } + } + return nil +} diff --git a/server/node/revoke.go b/server/node/revoke.go index 17f1a047eb..7d6b5e1b38 100644 --- a/server/node/revoke.go +++ b/server/node/revoke.go @@ -30,27 +30,41 @@ import ( // Revoke handles all of the REVOKE statements. type Revoke struct { RevokeTable *RevokeTable + RevokeSchema *RevokeSchema + RevokeDatabase *RevokeDatabase + RevokeRole *RevokeRole FromRoles []string GrantedBy string - GrantOptionFor bool + GrantOptionFor bool // This is "ADMIN OPTION FOR" for RevokeRole only Cascade bool // When false, represents RESTRICT } // RevokeTable specifically handles the REVOKE ... ON TABLE statement. type RevokeTable struct { - Privileges []auth.Privilege - Tables []doltdb.TableName - AllTablesInSchemas []string + Privileges []auth.Privilege + Tables []doltdb.TableName } -var _ sql.ExecSourceRel = (*Revoke)(nil) -var _ vitess.Injectable = (*Revoke)(nil) +// RevokeSchema specifically handles the REVOKE ... ON SCHEMA statement. +type RevokeSchema struct { + Privileges []auth.Privilege + Schemas []string +} -// CheckPrivileges implements the interface sql.ExecSourceRel. -func (r *Revoke) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true +// RevokeDatabase specifically handles the REVOKE ... ON DATABASE statement. +type RevokeDatabase struct { + Privileges []auth.Privilege + Databases []string } +// RevokeRole specifically handles the REVOKE FROM statement. +type RevokeRole struct { + Groups []string +} + +var _ sql.ExecSourceRel = (*Revoke)(nil) +var _ vitess.Injectable = (*Revoke)(nil) + // Children implements the interface sql.ExecSourceRel. func (r *Revoke) Children() []sql.Node { return nil @@ -76,64 +90,20 @@ func (r *Revoke) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) { auth.LockWrite(func() { switch { case r.RevokeTable != nil: - if len(r.RevokeTable.AllTablesInSchemas) > 0 { - err = fmt.Errorf("revoking privileges to all tables in the schema is not yet supported") + if err = r.revokeTable(ctx); err != nil { return } - roles := make([]auth.Role, len(r.FromRoles)) - // First we'll verify that all of the roles exist - for i, roleName := range r.FromRoles { - roles[i] = auth.GetRole(roleName) - if !roles[i].IsValid() { - err = fmt.Errorf(`role "%s" does not exist`, roleName) - return - } - } - // Then we'll check that the role that is revoking the privileges exists - userRole := auth.GetRole(ctx.Client().User) - if !userRole.IsValid() { - err = fmt.Errorf(`role "%s" does not exist`, ctx.Client().User) + case r.RevokeSchema != nil: + if err = r.revokeSchema(ctx); err != nil { return } - var grantedByID auth.RoleID - if len(r.GrantedBy) != 0 { - grantedByRole := auth.GetRole(r.GrantedBy) - if !grantedByRole.IsValid() { - err = fmt.Errorf(`role "%s" does not exist`, r.GrantedBy) - return - } - grantedByID = grantedByRole.ID() + case r.RevokeDatabase != nil: + if err = r.revokeDatabase(ctx); err != nil { + return } - // Next we'll remove the privileges - for _, role := range roles { - for _, table := range r.RevokeTable.Tables { - var schemaName string - schemaName, err = core.GetSchemaName(ctx, nil, table.Schema) - if err != nil { - return - } - key := auth.TablePrivilegeKey{ - Role: role.ID(), - Table: doltdb.TableName{Name: table.Name, Schema: schemaName}, - } - isOwner := auth.IsOwner(auth.OwnershipKey{ - PrivilegeObject: auth.PrivilegeObject_TABLE, - Schema: schemaName, - Name: table.Name, - }, userRole.ID()) - for _, privilege := range r.RevokeTable.Privileges { - // TODO: we don't have to exactly match the GRANTED BY ID, we can also check if it's in the access chain - if !userRole.IsSuperUser && !isOwner && userRole.ID() != grantedByID { - // TODO: grab the actual error message - err = fmt.Errorf(`role "%s" does not have permission to revoke this privilege`, userRole.Name) - return - } - auth.RemoveTablePrivilege(key, auth.GrantedPrivilege{ - Privilege: privilege, - GrantedBy: grantedByID, - }, r.GrantOptionFor) - } - } + case r.RevokeRole != nil: + if err = r.revokeRole(ctx); err != nil { + return } default: err = fmt.Errorf("REVOKE statement is not yet supported") @@ -174,3 +144,154 @@ func (r *Revoke) WithResolvedChildren(children []any) (any, error) { } return r, nil } + +// common handles the initial logic for each REVOKE statement. +func (r *Revoke) common(ctx *sql.Context) ([]auth.Role, auth.Role, auth.RoleID, error) { + roles := make([]auth.Role, len(r.FromRoles)) + // First we'll verify that all of the roles exist + for i, roleName := range r.FromRoles { + roles[i] = auth.GetRole(roleName) + if !roles[i].IsValid() { + return nil, auth.Role{}, 0, fmt.Errorf(`role "%s" does not exist`, roleName) + } + } + // Then we'll check that the role that is revoking the privileges exists + userRole := auth.GetRole(ctx.Client().User) + if !userRole.IsValid() { + return nil, auth.Role{}, 0, fmt.Errorf(`role "%s" does not exist`, ctx.Client().User) + } + var grantedByID auth.RoleID + if len(r.GrantedBy) != 0 { + grantedByRole := auth.GetRole(r.GrantedBy) + if !grantedByRole.IsValid() { + return nil, auth.Role{}, 0, fmt.Errorf(`role "%s" does not exist`, r.GrantedBy) + } + if groupID, _, _ := auth.IsRoleAMember(userRole.ID(), grantedByRole.ID()); !groupID.IsValid() { + // TODO: grab the actual error message + return nil, auth.Role{}, 0, fmt.Errorf(`role "%s" does not have permission to revoke this privilege`, userRole.Name) + } + } else { + grantedByID = userRole.ID() + } + return roles, userRole, grantedByID, nil +} + +// revokeTable handles *RevokeTable from within RowIter. +func (r *Revoke) revokeTable(ctx *sql.Context) error { + roles, userRole, grantedByID, err := r.common(ctx) + if err != nil { + return err + } + for _, role := range roles { + for _, table := range r.RevokeTable.Tables { + schemaName, err := core.GetSchemaName(ctx, nil, table.Schema) + if err != nil { + return err + } + key := auth.TablePrivilegeKey{ + Role: userRole.ID(), + Table: doltdb.TableName{Name: table.Name, Schema: schemaName}, + } + for _, privilege := range r.RevokeTable.Privileges { + if id := auth.HasTablePrivilegeGrantOption(key, privilege); !id.IsValid() { + // TODO: grab the actual error message + return fmt.Errorf(`role "%s" does not have permission to revoke this privilege`, userRole.Name) + } + auth.RemoveTablePrivilege(auth.TablePrivilegeKey{ + Role: role.ID(), + Table: doltdb.TableName{Name: table.Name, Schema: schemaName}, + }, auth.GrantedPrivilege{ + Privilege: privilege, + GrantedBy: grantedByID, + }, r.GrantOptionFor) + } + } + } + return nil +} + +// revokeSchema handles *RevokeSchema from within RowIter. +func (r *Revoke) revokeSchema(ctx *sql.Context) error { + roles, userRole, grantedByID, err := r.common(ctx) + if err != nil { + return err + } + for _, role := range roles { + for _, schema := range r.RevokeSchema.Schemas { + key := auth.SchemaPrivilegeKey{ + Role: userRole.ID(), + Schema: schema, + } + for _, privilege := range r.RevokeTable.Privileges { + if id := auth.HasSchemaPrivilegeGrantOption(key, privilege); !id.IsValid() { + // TODO: grab the actual error message + return fmt.Errorf(`role "%s" does not have permission to revoke this privilege`, userRole.Name) + } + auth.RemoveSchemaPrivilege(auth.SchemaPrivilegeKey{ + Role: role.ID(), + Schema: schema, + }, auth.GrantedPrivilege{ + Privilege: privilege, + GrantedBy: grantedByID, + }, r.GrantOptionFor) + } + } + } + return nil +} + +// revokeDatabase handles *RevokeDatabase from within RowIter. +func (r *Revoke) revokeDatabase(ctx *sql.Context) error { + roles, userRole, grantedByID, err := r.common(ctx) + if err != nil { + return err + } + for _, role := range roles { + for _, databases := range r.RevokeDatabase.Databases { + key := auth.DatabasePrivilegeKey{ + Role: userRole.ID(), + Name: databases, + } + for _, privilege := range r.RevokeDatabase.Privileges { + if id := auth.HasDatabasePrivilegeGrantOption(key, privilege); !id.IsValid() { + // TODO: grab the actual error message + return fmt.Errorf(`role "%s" does not have permission to revoke this privilege`, userRole.Name) + } + auth.RemoveDatabasePrivilege(auth.DatabasePrivilegeKey{ + Role: role.ID(), + Name: databases, + }, auth.GrantedPrivilege{ + Privilege: privilege, + GrantedBy: grantedByID, + }, r.GrantOptionFor) + } + } + } + return nil +} + +// revokeRole handles *RevokeRole from within RowIter. +func (r *Revoke) revokeRole(ctx *sql.Context) error { + members, userRole, _, err := r.common(ctx) + if err != nil { + return err + } + groups := make([]auth.Role, len(r.RevokeRole.Groups)) + for i, groupName := range r.RevokeRole.Groups { + groups[i] = auth.GetRole(groupName) + if !groups[i].IsValid() { + return fmt.Errorf(`role "%s" does not exist`, groupName) + } + } + for _, member := range members { + for _, group := range groups { + memberGroupID, _, withAdminOption := auth.IsRoleAMember(userRole.ID(), group.ID()) + if !memberGroupID.IsValid() || !withAdminOption { + // TODO: grab the actual error message + return fmt.Errorf(`role "%s" does not have permission to revoke role "%s"`, userRole.Name, group.Name) + } + auth.RemoveMemberFromGroup(member.ID(), group.ID(), r.GrantOptionFor) + } + } + return nil +} diff --git a/testing/go/auth_quick_test.go b/testing/go/auth_quick_test.go new file mode 100644 index 0000000000..a485f00b00 --- /dev/null +++ b/testing/go/auth_quick_test.go @@ -0,0 +1,365 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package _go + +import ( + "strings" + "testing" + + "github.com/dolthub/go-mysql-server/sql" +) + +// TestAuthQuick is modeled after the QuickPrivilegeTest in GMS, so please refer to the documentation there: +// https://github.com/dolthub/go-mysql-server/blob/main/enginetest/queries/priv_auth_queries.go +func TestAuthQuick(t *testing.T) { + // Statements that are run before every test (the state that all tests start with): + // CREATE USER tester PASSWORD 'password'; + // CREATE SCHEMA mysch; + // CREATE SCHEMA othersch; + // CREATE TABLE mysch.test (pk BIGINT PRIMARY KEY, v1 BIGINT); + // CREATE TABLE mysch.test2 (pk BIGINT PRIMARY KEY, v1 BIGINT); + // CREATE TABLE othersch.test (pk BIGINT PRIMARY KEY, v1 BIGINT); + // CREATE TABLE othersch.test2 (pk BIGINT PRIMARY KEY, v1 BIGINT); + // INSERT INTO mysch.test VALUES (0, 0), (1, 1); + // INSERT INTO mysch.test2 VALUES (0, 1), (1, 2); + // INSERT INTO othersch.test VALUES (1, 1), (2, 2); + // INSERT INTO othersch.test2 VALUES (1, 1), (2, 2); + type QuickPrivilegeTest struct { + Focus bool + Queries []string + Expected []sql.Row + ExpectedErr string + } + tests := []QuickPrivilegeTest{ + { + Queries: []string{ + "GRANT SELECT ON ALL TABLES IN SCHEMA mysch TO tester;", + "SELECT * FROM mysch.test;", + }, + Expected: []sql.Row{{0, 0}, {1, 1}}, + }, + { + Queries: []string{ + "GRANT SELECT ON ALL TABLES IN SCHEMA mysch TO tester;", + "SELECT * FROM mysch.test2;", + }, + Expected: []sql.Row{{0, 1}, {1, 2}}, + }, + { + Queries: []string{ + "GRANT SELECT ON mysch.test TO tester;", + "SELECT * FROM mysch.test;", + }, + Expected: []sql.Row{{0, 0}, {1, 1}}, + }, + { + Queries: []string{ + "GRANT SELECT ON mysch.test TO tester;", + "SELECT * FROM mysch.test2;", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "GRANT SELECT ON ALL TABLES IN SCHEMA othersch TO tester;", + "SELECT * FROM mysch.test;", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "GRANT SELECT ON othersch.test TO tester;", + "SELECT * FROM mysch.test;", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "GRANT SELECT ON othersch.test TO tester;", + "SELECT * FROM mysch.test;", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "CREATE SCHEMA newsch;", + }, + ExpectedErr: "permission denied for database", + }, + { + Queries: []string{ + "GRANT CREATE ON DATABASE postgres TO tester;", + "CREATE SCHEMA newsch;", + }, + }, + { // This isn't supported yet, but it is supposed to fail since tester is not an owner + Queries: []string{ + "GRANT CREATE ON DATABASE postgres TO tester;", + "CREATE SCHEMA newsch;", + "DROP SCHEMA newsch;", + }, + ExpectedErr: "not yet supported", + }, + { + Queries: []string{ + "CREATE TABLE mysch.new_table (pk BIGINT PRIMARY KEY);", + }, + ExpectedErr: "permission denied for schema", + }, + { + Queries: []string{ + "GRANT CREATE ON SCHEMA mysch TO tester;", + "CREATE TABLE mysch.new_table (pk BIGINT PRIMARY KEY);", + }, + }, + { + Queries: []string{ + "CREATE ROLE new_role;", + }, + ExpectedErr: "does not have permission", + }, + { + Queries: []string{ + "ALTER ROLE tester CREATEROLE;", + "CREATE ROLE new_role;", + }, + }, + { + Queries: []string{ + "CREATE USER new_user;", + }, + ExpectedErr: "does not have permission", + }, + { + Queries: []string{ + "ALTER ROLE tester SUPERUSER;", + "CREATE USER new_user;", + }, + }, + { + Queries: []string{ + "CREATE USER new_user;", + "DROP USER new_user;", + }, + ExpectedErr: "does not have permission", + }, + { + Queries: []string{ + "CREATE USER new_user;", + "ALTER ROLE tester CREATEROLE;", + "DROP USER new_user;", + }, + }, + { + Queries: []string{ + "CREATE USER new_user SUPERUSER;", + "ALTER ROLE tester CREATEROLE;", + "DROP USER new_user;", + }, + ExpectedErr: "does not have permission", + }, + { + Queries: []string{ + "CREATE USER new_user SUPERUSER;", + "ALTER ROLE tester SUPERUSER;", + "DROP USER new_user;", + }, + }, + { + Queries: []string{ + "DELETE FROM mysch.test WHERE pk >= 0;", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "GRANT DELETE ON ALL TABLES IN SCHEMA mysch TO tester;", + "DELETE FROM mysch.test WHERE pk >= 0;", + }, + }, + { + Queries: []string{ + "GRANT DELETE ON mysch.test TO tester;", + "DELETE FROM mysch.test WHERE pk >= 0;", + }, + }, + { + Queries: []string{ + "CREATE USER tester2;", + "GRANT DELETE ON ALL TABLES IN SCHEMA mysch TO tester2;", + "GRANT tester2 TO tester;", + "DELETE FROM mysch.test WHERE pk >= 0;", + }, + }, + { + Queries: []string{ + "SELECT * FROM mysch.test JOIN mysch.test2 ON test.pk = test2.pk;", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "GRANT SELECT ON mysch.test TO tester;", + "SELECT * FROM mysch.test JOIN mysch.test2 ON test.pk = test2.pk;", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "GRANT SELECT ON mysch.test2 TO tester;", + "SELECT * FROM mysch.test JOIN mysch.test2 ON test.pk = test2.pk;", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "GRANT SELECT ON mysch.test TO tester;", + "GRANT SELECT ON mysch.test2 TO tester;", + "SELECT * FROM mysch.test JOIN mysch.test2 ON test.pk = test2.pk;", + }, + Expected: []sql.Row{{0, 0, 0, 1}, {1, 1, 1, 2}}, + }, + { + Queries: []string{ + "CREATE USER tester2;", + "GRANT SELECT ON mysch.test2 TO tester2;", + "GRANT tester2 TO tester;", + "SELECT * FROM mysch.test JOIN mysch.test2 ON test.pk = test2.pk;", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "CREATE USER tester2;", + "GRANT SELECT ON mysch.test TO tester2;", + "GRANT SELECT ON mysch.test2 TO tester2;", + "GRANT tester2 TO tester;", + "SELECT * FROM mysch.test JOIN mysch.test2 ON test.pk = test2.pk;", + }, + Expected: []sql.Row{{0, 0, 0, 1}, {1, 1, 1, 2}}, + }, + { + Queries: []string{ + "CREATE TABLE mysch.new_table (pk BIGINT PRIMARY KEY);", + "DROP TABLE mysch.new_table;", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA mysch TO tester;", + "CREATE TABLE mysch.new_table (pk BIGINT PRIMARY KEY);", + "DROP TABLE mysch.new_table;", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "CREATE TABLE mysch.new_table (pk BIGINT PRIMARY KEY);", + "GRANT postgres TO tester;", + "DROP TABLE mysch.new_table;", + }, + }, + { + Queries: []string{ + "CREATE ROLE new_role;", + "DROP ROLE new_role;", + }, + ExpectedErr: "does not have permission", + }, + { + Queries: []string{ + "ALTER ROLE tester CREATEROLE;", + "CREATE ROLE new_role;", + "DROP ROLE new_role;", + }, + }, + { + Queries: []string{ + "INSERT INTO mysch.test VALUES (9, 9);", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "GRANT INSERT ON ALL TABLES IN SCHEMA mysch TO tester;", + "INSERT INTO mysch.test VALUES (9, 9);", + }, + }, + { + Queries: []string{ + "GRANT INSERT ON mysch.test TO tester;", + "INSERT INTO mysch.test VALUES (9, 9);", + }, + }, + { + Queries: []string{ + "UPDATE mysch.test SET v1 = 0;", + }, + ExpectedErr: "permission denied for table", + }, + { + Queries: []string{ + "GRANT UPDATE ON ALL TABLES IN SCHEMA mysch TO tester;", + "UPDATE mysch.test SET v1 = 0;", + }, + }, + { + Queries: []string{ + "GRANT UPDATE ON mysch.test TO tester;", + "UPDATE mysch.test SET v1 = 0;", + }, + }, + } + // Here we'll convert each quick test into a standard test + scriptTests := make([]ScriptTest, len(tests)) + for testIdx, test := range tests { + scriptTests[testIdx] = ScriptTest{ + Name: strings.Join(test.Queries, "\n > "), + Database: "", + SetUpScript: []string{ + "CREATE USER tester PASSWORD 'password';", + "CREATE SCHEMA mysch;", + "CREATE SCHEMA othersch;", + "CREATE TABLE mysch.test (pk BIGINT PRIMARY KEY, v1 BIGINT);", + "CREATE TABLE mysch.test2 (pk BIGINT PRIMARY KEY, v1 BIGINT);", + "CREATE TABLE othersch.test (pk BIGINT PRIMARY KEY, v1 BIGINT);", + "CREATE TABLE othersch.test2 (pk BIGINT PRIMARY KEY, v1 BIGINT);", + "INSERT INTO mysch.test VALUES (0, 0), (1, 1);", + "INSERT INTO mysch.test2 VALUES (0, 1), (1, 2);", + "INSERT INTO othersch.test VALUES (1, 1), (2, 2);", + "INSERT INTO othersch.test2 VALUES (1, 1), (2, 2);", + }, + Assertions: make([]ScriptTestAssertion, len(test.Queries)), + Focus: test.Focus, + } + for queryIdx := 0; queryIdx < len(test.Queries)-1; queryIdx++ { + scriptTests[testIdx].Assertions[queryIdx] = ScriptTestAssertion{ + Query: test.Queries[queryIdx], + SkipResultsCheck: true, + Username: "postgres", + Password: "password", + } + } + scriptTests[testIdx].Assertions[len(test.Queries)-1] = ScriptTestAssertion{ + Query: test.Queries[len(test.Queries)-1], + Expected: test.Expected, + ExpectedErr: test.ExpectedErr, + Username: "tester", + Password: "password", + } + } + RunScripts(t, scriptTests) +} diff --git a/testing/go/auth_test.go b/testing/go/auth_test.go index 38b28b1dff..2a781a2815 100644 --- a/testing/go/auth_test.go +++ b/testing/go/auth_test.go @@ -351,6 +351,8 @@ func TestAuthTests(t *testing.T) { SetUpScript: []string{ `CREATE USER user1 PASSWORD 'a';`, `CREATE USER user2 PASSWORD 'b';`, + `GRANT ALL PRIVILEGES ON SCHEMA public TO user1;`, + `GRANT ALL PRIVILEGES ON SCHEMA public TO user2;`, }, Assertions: []ScriptTestAssertion{ {