From 37e0707e889f0dc3ad4200024c7f595020cb35e9 Mon Sep 17 00:00:00 2001 From: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> Date: Mon, 2 Sep 2024 09:48:56 +0300 Subject: [PATCH] schemadiff: reject non-deterministic function in new column's default value (#16684) Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- go/vt/schemadiff/errors.go | 10 ++++ go/vt/schemadiff/table.go | 37 +++++++++++++++ go/vt/schemadiff/table_test.go | 84 ++++++++++++++++++++++++++++++---- 3 files changed, 121 insertions(+), 10 deletions(-) diff --git a/go/vt/schemadiff/errors.go b/go/vt/schemadiff/errors.go index a941c406be0..c938e736206 100644 --- a/go/vt/schemadiff/errors.go +++ b/go/vt/schemadiff/errors.go @@ -487,3 +487,13 @@ type PartitionSpecNonExclusiveError struct { func (e *PartitionSpecNonExclusiveError) Error() string { return fmt.Sprintf("ALTER TABLE on %s, may only have a single partition spec change, and other changes are not allowed. Found spec: %s; and change: %s", sqlescape.EscapeID(e.Table), sqlparser.CanonicalString(e.PartitionSpec), e.ConflictingStatement) } + +type NonDeterministicDefaultError struct { + Table string + Column string + Function string +} + +func (e *NonDeterministicDefaultError) Error() string { + return fmt.Sprintf("column %s.%s default value uses non-deterministic function: %s", sqlescape.EscapeID(e.Table), sqlescape.EscapeID(e.Column), e.Function) +} diff --git a/go/vt/schemadiff/table.go b/go/vt/schemadiff/table.go index c326b2763b3..d73259523b5 100644 --- a/go/vt/schemadiff/table.go +++ b/go/vt/schemadiff/table.go @@ -1731,6 +1731,33 @@ func evaluateColumnReordering(t1SharedColumns, t2SharedColumns []*sqlparser.Colu return minimalColumnReordering } +// This function looks for a non-deterministic function call in the given expression. +// If recurses into all function arguments. +// The known non-deterministic function we handle are: +// - UUID() +// - RAND() +// - SYSDATE() +func findNoNondeterministicFunction(expr sqlparser.Expr) (foundFunction string) { + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + switch node := node.(type) { + case *sqlparser.CurTimeFuncExpr: + switch node.Name.Lowered() { + case "sysdate": + foundFunction = node.Name.String() + return false, nil + } + case *sqlparser.FuncExpr: + switch node.Name.Lowered() { + case "uuid", "rand": + foundFunction = node.Name.String() + return false, nil + } + } + return true, nil + }, expr) + return foundFunction +} + // Diff compares this table statement with another table statement, and sees what it takes to // change this table to look like the other table. // It returns an AlterTable statement if changes are found, or nil if not. @@ -1852,6 +1879,16 @@ func (c *CreateTableEntity) diffColumns(alterTable *sqlparser.AlterTable, addColumn := &sqlparser.AddColumns{ Columns: []*sqlparser.ColumnDefinition{t2Col}, } + // See whether this ADD COLUMN has a non-deterministic default value + if t2Col.Type.Options.Default != nil && !t2Col.Type.Options.DefaultLiteral { + if function := findNoNondeterministicFunction(t2Col.Type.Options.Default); function != "" { + return &NonDeterministicDefaultError{ + Table: c.Name(), + Column: t2Col.Name.String(), + Function: function, + } + } + } if t2ColIndex < expectAppendIndex { // This column is added somewhere in between existing columns, not appended at end of column list if t2ColIndex == 0 { diff --git a/go/vt/schemadiff/table_test.go b/go/vt/schemadiff/table_test.go index 389e55f447c..c511343b1d6 100644 --- a/go/vt/schemadiff/table_test.go +++ b/go/vt/schemadiff/table_test.go @@ -28,16 +28,17 @@ import ( func TestCreateTableDiff(t *testing.T) { tt := []struct { - name string - from string - to string - fromName string - toName string - diff string - diffs []string - cdiff string - cdiffs []string - errorMsg string + name string + from string + to string + fromName string + toName string + diff string + diffs []string + cdiff string + cdiffs []string + errorMsg string + // hints: autoinc int rotation int fulltext int @@ -47,6 +48,7 @@ func TestCreateTableDiff(t *testing.T) { algorithm int enumreorder int subsequent int + // textdiffs []string atomicdiffs []string }{ @@ -449,6 +451,68 @@ func TestCreateTableDiff(t *testing.T) { "+ `y` int,", }, }, + { + name: "added column with non deterministic expression, uuid, reject", + from: "create table t1 (id int primary key, a int)", + to: "create table t2 (id int primary key, a int, v varchar(36) not null default (uuid()))", + errorMsg: (&NonDeterministicDefaultError{Table: "t1", Column: "v", Function: "uuid"}).Error(), + }, + { + name: "added column with non deterministic expression, UUID, reject", + from: "create table t1 (id int primary key, a int)", + to: "create table t2 (id int primary key, a int, v varchar(36) not null default (UUID()))", + errorMsg: (&NonDeterministicDefaultError{Table: "t1", Column: "v", Function: "UUID"}).Error(), + }, + { + name: "added column with non deterministic expression, uuid, spacing, reject", + from: "create table t1 (id int primary key, a int)", + to: "create table t2 (id int primary key, a int, v varchar(36) not null default (uuid ()))", + errorMsg: (&NonDeterministicDefaultError{Table: "t1", Column: "v", Function: "uuid"}).Error(), + }, + { + name: "added column with non deterministic expression, uuid, inner, reject", + from: "create table t1 (id int primary key, a int)", + to: "create table t2 (id int primary key, a int, v varchar(36) not null default (left(uuid(),10)))", + errorMsg: (&NonDeterministicDefaultError{Table: "t1", Column: "v", Function: "uuid"}).Error(), + }, + { + name: "added column with non deterministic expression, rand, reject", + from: "create table t1 (id int primary key, a int)", + to: "create table t2 (id int primary key, a int, v varchar(36) not null default (2.0 + rand()))", + errorMsg: (&NonDeterministicDefaultError{Table: "t1", Column: "v", Function: "rand"}).Error(), + }, + { + name: "added column with non deterministic expression, sysdate, reject", + from: "create table t1 (id int primary key, a int)", + to: "create table t2 (id int primary key, a int, v varchar(36) not null default (sysdate()))", + errorMsg: (&NonDeterministicDefaultError{Table: "t1", Column: "v", Function: "sysdate"}).Error(), + }, + { + name: "added column with non deterministic expression, sysdate, reject", + from: "create table t1 (id int primary key, a int)", + to: "create table t2 (id int primary key, a int, v varchar(36) not null default (to_days(sysdate())))", + errorMsg: (&NonDeterministicDefaultError{Table: "t1", Column: "v", Function: "sysdate"}).Error(), + }, + { + name: "added column with deterministic expression, now, reject does not apply", + from: "create table t1 (id int primary key, a int)", + to: "create table t2 (id int primary key, a int, v varchar(36) not null default (now()))", + diff: "alter table t1 add column v varchar(36) not null default (now())", + cdiff: "ALTER TABLE `t1` ADD COLUMN `v` varchar(36) NOT NULL DEFAULT (now())", + textdiffs: []string{ + "+ `v` varchar(36) NOT NULL DEFAULT (now()),", + }, + }, + { + name: "added column with deterministic expression, curdate, reject does not apply", + from: "create table t1 (id int primary key, a int)", + to: "create table t2 (id int primary key, a int, v varchar(36) not null default (to_days(curdate())))", + diff: "alter table t1 add column v varchar(36) not null default (to_days(curdate()))", + cdiff: "ALTER TABLE `t1` ADD COLUMN `v` varchar(36) NOT NULL DEFAULT (to_days(curdate()))", + textdiffs: []string{ + "+ `v` varchar(36) NOT NULL DEFAULT (to_days(curdate())),", + }, + }, // enum { name: "expand enum",