Skip to content

Commit

Permalink
Fix information_schema.columns for databases with schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
tbantle22 committed Jul 17, 2024
1 parent 3d70c26 commit 8d99c2a
Showing 1 changed file with 47 additions and 37 deletions.
84 changes: 47 additions & 37 deletions sql/information_schema/columns_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,17 @@ func (c *ColumnsTable) AllColumns(ctx *sql.Context) (sql.Schema, error) {

var allColumns sql.Schema

for _, db := range c.catalog.AllDatabases(ctx) {
err := sql.DBTableIter(ctx, db, func(t sql.Table) (cont bool, err error) {
databases, err := allDatabases(ctx, c.catalog, false)
if err != nil {
return nil, err
}

for _, db := range databases {
err := sql.DBTableIter(ctx, db.database, func(t sql.Table) (cont bool, err error) {
tableSch := t.Schema()
for i := range tableSch {
newCol := tableSch[i].Copy()
newCol.DatabaseSource = db.Name()
newCol.DatabaseSource = db.database.Name()
allColumns = append(allColumns, newCol)
}
return true, nil
Expand Down Expand Up @@ -205,7 +210,12 @@ func columnsRowIter(ctx *sql.Context, catalog sql.Catalog, allColsWithDefaultVal
}
globalPrivSetMap = getCurrentPrivSetMapForColumn(privSet.ToSlice(), globalPrivSetMap)

for _, db := range catalog.AllDatabases(ctx) {
databases, err := allDatabases(ctx, catalog, false)
if err != nil {
return nil, err
}

for _, db := range databases {
rs, err := getRowsFromDatabase(ctx, db, privSet, globalPrivSetMap, allColsWithDefaultValue)
if err != nil {
return nil, err
Expand All @@ -224,7 +234,7 @@ func columnsRowIter(ctx *sql.Context, catalog sql.Catalog, allColsWithDefaultVal
// getRowFromColumn returns a single row for given column. The arguments passed are used to define all row values.
// These include the current ordinal position, so this column will get the next position number, sql.Column object,
// database name, table name, column key and column privileges information through privileges set for the table.
func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, dbName, tblName, columnKey string, privSetTbl sql.PrivilegeSetTable, privSetMap map[string]struct{}) sql.Row {
func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, catName, schName, tblName, columnKey string, privSetTbl sql.PrivilegeSetTable, privSetMap map[string]struct{}) sql.Row {
var (
ordinalPos = uint32(curOrdPos + 1)
nullable = "NO"
Expand Down Expand Up @@ -279,8 +289,8 @@ func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, dbName,
privileges := strings.Join(curColPrivStr, ",")

return sql.Row{
"def", // table_catalog
dbName, // table_schema
catName, // table_catalog
schName, // table_schema
tblName, // table_name
col.Name, // column_name
ordinalPos, // ordinal_position
Expand All @@ -305,7 +315,7 @@ func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, dbName,
}

// getRowsFromTable returns array of rows for all accessible columns of the given table.
func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb sql.PrivilegeSetDatabase, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) {
func getRowsFromTable(ctx *sql.Context, db dbWithNames, t sql.Table, privSetDb sql.PrivilegeSetDatabase, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) {
var rows []sql.Row

privSetTbl := privSetDb.Table(t.Name())
Expand All @@ -317,7 +327,7 @@ func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb
}

tblName := t.Name()
for i, col := range schemaForTable(t, db, allColsWithDefaultValue) {
for i, col := range schemaForTable(t, db.database, allColsWithDefaultValue) {
var columnKey string
// Check column PK here first because there are PKs from table implementations that don't implement sql.IndexedTable
if col.PrimaryKey {
Expand All @@ -331,7 +341,7 @@ func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb
}
}

r := getRowFromColumn(ctx, i, col, db.Name(), tblName, columnKey, privSetTbl, curPrivSetMap)
r := getRowFromColumn(ctx, i, col, db.catalogName, db.schemaName, tblName, columnKey, privSetTbl, curPrivSetMap)
if r != nil {
rows = append(rows, r)
}
Expand All @@ -341,58 +351,58 @@ func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb
}

// getRowsFromViews returns array or rows for columns for all views for given database.
func getRowsFromViews(ctx *sql.Context, db sql.Database) ([]sql.Row, error) {
func getRowsFromViews(ctx *sql.Context, db dbWithNames) ([]sql.Row, error) {
var rows []sql.Row
// TODO: View Definition is lacking information to properly fill out these table
// TODO: Should somehow get reference to table(s) view is referencing
// TODO: Each column that view references should also show up as unique entries as well
views, err := viewsInDatabase(ctx, db)
views, err := viewsInDatabase(ctx, db.database)
if err != nil {
return nil, err
}

for _, view := range views {
rows = append(rows, sql.Row{
"def", // table_catalog
db.Name(), // table_schema
view.Name, // table_name
"", // column_name
uint32(0), // ordinal_position
nil, // column_default
"", // is_nullable
nil, // data_type
nil, // character_maximum_length
nil, // character_octet_length
nil, // numeric_precision
nil, // numeric_scale
nil, // datetime_precision
"", // character_set_name
"", // collation_name
"", // column_type
"", // column_key
"", // extra
"select", // privileges
"", // column_comment
"", // generation_expression
nil, // srs_id
db.catalogName, // table_catalog
db.schemaName, // table_schema
view.Name, // table_name
"", // column_name
uint32(0), // ordinal_position
nil, // column_default
"", // is_nullable
nil, // data_type
nil, // character_maximum_length
nil, // character_octet_length
nil, // numeric_precision
nil, // numeric_scale
nil, // datetime_precision
"", // character_set_name
"", // collation_name
"", // column_type
"", // column_key
"", // extra
"select", // privileges
"", // column_comment
"", // generation_expression
nil, // srs_id
})
}

return rows, nil
}

// getRowsFromDatabase returns array of rows for all accessible columns of accessible table of the given database.
func getRowsFromDatabase(ctx *sql.Context, db sql.Database, privSet sql.PrivilegeSet, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) {
func getRowsFromDatabase(ctx *sql.Context, db dbWithNames, privSet sql.PrivilegeSet, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) {
var rows []sql.Row
dbName := db.Name()
dbName := db.database.Name()

privSetDb := privSet.Database(dbName)
curPrivSetMap := getCurrentPrivSetMapForColumn(privSetDb.ToSlice(), privSetMap)
if dbName == sql.InformationSchemaDatabaseName {
curPrivSetMap["select"] = struct{}{}
}

err := sql.DBTableIter(ctx, db, func(t sql.Table) (cont bool, err error) {
err := sql.DBTableIter(ctx, db.database, func(t sql.Table) (cont bool, err error) {
rs, err := getRowsFromTable(ctx, db, t, privSetDb, curPrivSetMap, allColsWithDefaultValue)
if err != nil {
return false, err
Expand Down

0 comments on commit 8d99c2a

Please sign in to comment.