Skip to content

Commit

Permalink
general improvements on CheckSchema
Browse files Browse the repository at this point in the history
  • Loading branch information
kataras committed Sep 18, 2023
1 parent 15ebbdd commit 7a982af
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 24 deletions.
138 changes: 128 additions & 10 deletions db_information.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ func (db *DB) CheckSchema(ctx context.Context) error {
return fmt.Errorf("expected %d tables, got %d", len(tableNames), len(tables))
}

// var fixQueries []string

for _, table := range tables {
tableName := table.Name

Expand All @@ -229,19 +231,25 @@ func (db *DB) CheckSchema(ctx context.Context) error {
}

for _, col := range table.Columns {
column := td.GetColumnByName(col.Name)
column := td.GetColumnByName(col.Name) // get code column.

if column == nil {
return fmt.Errorf("column %q in table %q not found in schema", col.Name, tableName)
}

if column.Unique { // modify it, so checks are correct.
column.UniqueIndex = fmt.Sprintf("%s_%s_key", tableName, column.Name)
column.Unique = false
}

if expected, got := col.FieldTagString(false), column.FieldTagString(false); expected != got {
// if column.Unique { // modify it, so checks are correct.
// column.UniqueIndex = fmt.Sprintf("%s_%s_key", tableName, column.Name)
// column.Unique = false
// }

if expected, got := strings.ToLower(col.FieldTagString(false)), strings.ToLower(column.FieldTagString(false)); expected != got {
// if strings.Contains(expected, "nullable") && !strings.Contains(got, "nullable") {
// // database has nullable, but code doesn't.
// fixQuery := fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s SET NOT NULL;`, tableName, col.Name)
// fixQueries = append(fixQueries, fixQuery)
// } else {
return fmt.Errorf("column %q in table %q has wrong field tag: db:\n%s\nvs code:\n%s", col.Name, tableName, expected, got)
// }
}

if column.Description == "" {
Expand All @@ -250,6 +258,21 @@ func (db *DB) CheckSchema(ctx context.Context) error {
}
}

// Maybe a next feature but we must be very careful, skip it for now and ofc move it to a different developer-driven method:
// if len(fixQueries) > 0 {
// return db.InTransaction(ctx, func(db *DB) error {
// for _, fixQuery := range fixQueries {
// // fmt.Println(fixQuery)
// _, err = db.Exec(ctx, fixQuery)
// if err != nil {
// return err
// }
// }

// return nil
// })
// }

return nil // return nil if no mismatch is found
}

Expand Down Expand Up @@ -424,6 +447,11 @@ func (db *DB) ListColumns(ctx context.Context, tableNames ...string) ([]*desc.Co
return nil, err
}

uniqueIndexes, err := db.ListUniqueIndexes(ctx, tableNames...)
if err != nil {
return nil, err
}

columns := make([]*desc.Column, 0, len(basicInfos))

for _, basicInfo := range basicInfos {
Expand All @@ -436,6 +464,18 @@ func (db *DB) ListColumns(ctx context.Context, tableNames ...string) ([]*desc.Co
}
}

uniqueIndexLoop:
for _, uniqueIndex := range uniqueIndexes {
if uniqueIndex.TableName == column.TableName {
for _, columnName := range uniqueIndex.Columns {
if columnName == column.Name {
column.UniqueIndex = uniqueIndex.IndexName
break uniqueIndexLoop
}
}
}
}

// No need to put index types on these type of columns, postgres manages these.
if column.PrimaryKey || column.Unique || column.UniqueIndex != "" {
column.Index = desc.InvalidIndex
Expand Down Expand Up @@ -494,19 +534,19 @@ FROM
WHERE
schemaname = $1 AND
( CARDINALITY($2::varchar[]) = 0 OR tablename = ANY($2::varchar[]) ) AND
indexdef NOT LIKE '%UNIQUE%'
indexdef NOT LIKE '%UNIQUE%' -- don't collect unique indexes here, they are (or should be) collected in the first part of the query OR by the ListUniqueIndexes.
ORDER BY table_name, column_name;`

/*
table_name column_name constraint_name constraint_type constraint_definition index_type
blog_posts blog_posts_blog_id_fkey i CREATE INDEX blog_posts_blog_id_fkey ON public.blog_posts USING btree (blog_id)
blog_posts blog_posts_blog_id_fkey i CREATE INDEX blog_posts_blog_id_fkey ON public.blog_posts USING btree (blog_id)
blog_posts blog_id blog_posts_blog_id_fkey f FOREIGN KEY (blog_id) REFERENCES blogs(id) ON DELETE CASCADE DEFERRABLE
blog_posts id blog_posts_pkey p PRIMARY KEY (id) btree
blog_posts read_time_minutes blog_posts_read_time_minutes_check c CHECK ((read_time_minutes > 0))
blog_posts source_url uk_blog_post u UNIQUE (title, source_url) btree
blog_posts title uk_blog_post u UNIQUE (title, source_url) btree
blogs id blogs_pkey p PRIMARY KEY (id) btree
customers customers_name_idx i CREATE INDEX customers_name_idx ON public.customers USING btree (name)
customers customers_name_idx i CREATE INDEX customers_name_idx ON public.customers USING btree (name)
customers cognito_user_id customer_unique_idx u UNIQUE (cognito_user_id, email) btree
customers email customer_unique_idx u UNIQUE (cognito_user_id, email) btree
customers id customers_pkey p PRIMARY KEY (id) btree
Expand All @@ -519,6 +559,8 @@ ORDER BY table_name, column_name;`
}
defer rows.Close() // close the rows instance when done

// constraintNames := make(map[string]struct{})

cs := make([]*desc.Constraint, 0) // create an empty slice to store the constraint definitions
for rows.Next() { // loop over the rows returned by the query
var (
Expand All @@ -537,9 +579,85 @@ ORDER BY table_name, column_name;`
return nil, err
}

// if _, exists := constraintNames[c.ConstraintName]; !exists {
// constraintNames[c.ConstraintName] = struct{}{}

c.Build(constraintDefinition)
cs = append(cs, &c)
// }
}

if err = rows.Err(); err != nil {
return nil, err
}

return cs, nil
}

// ListUniqueIndexes returns a list of unique indexes in the database schema by querying the pg_index table and
// filtering the results to only include unique indexes.
func (db *DB) ListUniqueIndexes(ctx context.Context, tableNames ...string) ([]*desc.UniqueIndex, error) {
if tableNames == nil {
tableNames = make([]string, 0)
}

query := `SELECT
-- n.nspname AS schema_name,
t.relname AS table_name,
i.relname AS index_name,
array_agg(a.attname ORDER BY a.attnum) AS index_columns
FROM pg_index p
JOIN pg_class t ON t.oid = p.indrelid -- the table
JOIN pg_class i ON i.oid = p.indexrelid -- the index
JOIN pg_namespace n ON n.oid = t.relnamespace -- the schema
JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(p.indkey) -- the columns
WHERE n.nspname = $1
AND ( CARDINALITY($2::varchar[]) = 0 OR t.relname = ANY($2::varchar[]) )
AND p.indisunique -- only unique indexes
AND NOT p.indisprimary -- not primary keys
AND NOT EXISTS ( -- not created by a constraint
SELECT 1 FROM pg_constraint c
WHERE c.conindid = p.indexrelid
)
GROUP BY n.nspname, t.relname, i.relname;`
/*
public customer_allergies customer_allergy {customer_id,allergy_id}
public customer_cheat_foods customer_cheat_food {customer_id,food_id}
public customer_devices customer_devices_unique {customer_id,type}
*/

// Execute the query using db.Query and pass in the search path as a parameter
rows, err := db.Query(ctx, query, db.searchPath, tableNames)
if err != nil {
return nil, err // return nil and the error if the query fails
}
defer rows.Close() // close the rows instance when done

cs := make([]*desc.UniqueIndex, 0) // create an empty slice to store the unique index definitions

for rows.Next() { // loop over the rows returned by the query
var (
tableName string
indexName string
columns []string
)

if err = rows.Scan(
&tableName,
&indexName,
&columns,
); err != nil {
return nil, err
}

c := desc.UniqueIndex{
TableName: tableName,
IndexName: indexName,
Columns: columns,
}

cs = append(cs, &c)

}

if err = rows.Err(); err != nil {
Expand Down
15 changes: 13 additions & 2 deletions desc/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type (
TypeArgument string // an optional argument for the data type, e.g. 255 when Type is "varchar"
PrimaryKey bool // a flag that indicates if the column is a primary key
Identity bool // a flag that indicates if the column is an identity column, e.g. INT GENERATED ALWAYS AS IDENTITY
// Required bool // a flag that indicates if the column is required (not null)
// Required bool // a flag that indicates if the column is required (not null, let's just use the !Nullable)
Default string // an optional default value or sql function for the column
CheckConstraint string // an optional check constraint for the column
Unique bool // a flag that indicates if the column has a unique constraint (postgres automatically adds an index for that single one)
Expand Down Expand Up @@ -144,7 +144,18 @@ func (c *Column) FieldTagString(strict bool) string {
// writeTagProp(b, ",default=%s", nullLiteral)
writeTagProp(b, ",nullable", true)
} else {
writeTagProp(b, ",default=%s", c.Default)
defaultValue := c.Default
if !strict {
// E.g. {}::integer[], we need to cut the ::integer[] part as it's so strict.
// Cut {}::integer[] the :: part.
if names, ok := dataTypeText[c.Type]; ok {
for _, name := range names {
defaultValue = strings.TrimSuffix(defaultValue, "::"+name)
}
}
}

writeTagProp(b, ",default=%s", defaultValue)
}

writeTagProp(b, ",unique", c.Unique)
Expand Down
29 changes: 23 additions & 6 deletions desc/constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,13 @@ func (c *Constraint) BuildColumn(column *Column) error {
case PrimaryKeyConstraintType:
column.PrimaryKey = true
case UniqueConstraintType:
if len(c.Unique.Columns) == 0 {
// simple unique to itself.
column.Unique = true
} else {
column.UniqueIndex = c.ConstraintName
}
// if len(c.Unique.Columns) == 0 {
// // simple unique to itself.
// column.Unique = true
// } else {
// column.UniqueIndex = c.ConstraintName
// }
column.Unique = true
case CheckConstraintType:
column.CheckConstraint = c.Check.Expression
case ForeignKeyConstraintType:
Expand Down Expand Up @@ -183,6 +184,22 @@ func parseUniqueConstraint(constraintDefinition string) *UniqueConstraint {
}
}

var uniqueIndexConstraintRegexp = regexp.MustCompile(`CREATE UNIQUE INDEX (?P<name>\w+) ON (?P<schema>\w+)\.(?P<table>\w+) USING (?P<method>\w+) \((?P<columns>.*)\)`)

// parseUniqueIndexConstraint extracts the column names from a postgres
// unique index definition of the form:
//
//	CREATE UNIQUE INDEX name ON schema.table USING method (col1, col2)
//
// It returns the column names in declaration order, or nil when the
// definition does not match the expected form (previously a non-matching
// input produced a misleading []string{""} because strings.Split of an
// empty string yields one empty element).
func parseUniqueIndexConstraint(constraintDefinition string) []string {
	// Find the submatches in the sql string.
	matches := uniqueIndexConstraintRegexp.FindStringSubmatch(constraintDefinition)
	if matches == nil {
		return nil // not a unique index definition.
	}

	// Only the "columns" capture group is needed; resolve its position once
	// instead of building a full name->match map.
	idx := uniqueIndexConstraintRegexp.SubexpIndex("columns")
	if idx < 0 || idx >= len(matches) {
		return nil // defensive: group missing from the pattern.
	}

	// Columns are rendered by postgres as a comma-space separated list.
	return strings.Split(matches[idx], ", ")
}

// CheckConstraint is a type that represents a check constraint.
type CheckConstraint struct {
Expression string
Expand Down
16 changes: 12 additions & 4 deletions desc/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,12 @@ func ConvertRowsToStruct(td *Table, rows pgx.Rows, valuePtr interface{}) error {
return err // return an error if finding scan targets failed
}

if td.Strict {
for i, t := range scanTargets {
if t == nil {
return fmt.Errorf("struct doesn't have corresponding row field: %s", rows.FieldDescriptions()[i].Name) // return an error if the struct doesn't have a field for a column
for i, t := range scanTargets {
if t == nil {
if td.Strict {
return fmt.Errorf("struct doesn't have corresponding row field: %s (strict check)", rows.FieldDescriptions()[i].Name) // return an error if the struct doesn't have a field for a column
} else {
scanTargets[i] = &nullScanner{}
}
}
}
Expand Down Expand Up @@ -132,6 +134,12 @@ func findScanTargets(dstElemValue reflect.Value, td *Table, fieldDescs []pgconn.
return scanTargets, nil // return the scan targets and nil error
}

// nullScanner is a no-op scan destination. It is assigned to row fields that
// have no corresponding struct field when the table is not in strict mode,
// so the driver's value for that column is silently discarded.
type nullScanner struct{}

// Scan implements the scanner contract by ignoring src and reporting success.
func (ns *nullScanner) Scan(src interface{}) error {
	return nil
}

type passwordTextScanner struct {
tableName string
passwordHandler *PasswordHandler
Expand Down
7 changes: 6 additions & 1 deletion desc/struct_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,12 @@ func convertStructFieldToColumnDefinion(tableName string, field reflect.StructFi
}
}

if c.PrimaryKey && !c.Nullable && c.Type == UUID && c.Default == "" {
if c.PrimaryKey &&
!c.Nullable &&
c.Type == UUID &&
c.Default == "" &&
c.ReferenceColumnName == "" /* Note that we don't set default value if referencing to other table (or the same) */ {

c.Default = genRandomUUIDPGCryptoFunction1
// c.AutoGenerated = true
}
Expand Down
9 changes: 9 additions & 0 deletions desc/unique_index.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package desc

// UniqueIndex is a struct that represents a unique index.
// See DB.ListUniqueIndexes method for more.
type UniqueIndex struct {
	TableName string   // the table the index is defined on.
	IndexName string   // the index name as reported by the database.
	Columns   []string // the column names covered by the index, in index order.
}
7 changes: 6 additions & 1 deletion schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ type Schema struct {
//
// If set to empty then triggers will not be registered automatically.
SetTimestampTriggerName string

// Strict reports whether the schema should be strict on the database side.
// It's enabled by default.
Strict bool
}

// NewSchema creates and returns a new Schema with an initialized struct cache.
Expand All @@ -36,6 +40,7 @@ func NewSchema() *Schema {
UpdatedAtColumnName: "updated_at",
// set the default name for the trigger that sets the "updated_at" column.
SetTimestampTriggerName: "set_timestamp",
Strict: true,
}
}

Expand Down Expand Up @@ -101,7 +106,7 @@ func (s *Schema) MustRegister(tableName string, emptyStructValue any, opts ...Ta
if err != nil { // if there is an error
panic(err) // panic with the error
}
td.SetStrict(true)
td.SetStrict(s.Strict)

return s // return the table definition
}
Expand Down

0 comments on commit 7a982af

Please sign in to comment.