Skip to content

Commit

Permalink
Add unit test for auto table creation
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Lord <[email protected]>
  • Loading branch information
mattlord committed Sep 30, 2024
1 parent 6fc7fdc commit e89bbf7
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 8 deletions.
32 changes: 32 additions & 0 deletions go/vt/vtctl/workflow/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"os"
"reflect"
"regexp"
"slices"
"strings"
Expand All @@ -38,6 +39,7 @@ import (
"vitess.io/vitess/go/vt/mysqlctl/tmutils"
"vitess.io/vitess/go/vt/topo"
"vitess.io/vitess/go/vt/topo/memorytopo"
"vitess.io/vitess/go/vt/topo/topoproto"
"vitess.io/vitess/go/vt/topotools"
"vitess.io/vitess/go/vt/vtenv"
"vitess.io/vitess/go/vt/vterrors"
Expand Down Expand Up @@ -269,6 +271,7 @@ type testTMClient struct {
vrQueries map[int][]*queryResult
createVReplicationWorkflowRequests map[uint32]*createVReplicationWorkflowRequestResponse
readVReplicationWorkflowRequests map[uint32]*tabletmanagerdatapb.ReadVReplicationWorkflowRequest
applySchemaRequests map[uint32]*applySchemaRequestResponse
primaryPositions map[uint32]string
vdiffRequests map[uint32]*vdiffRequestResponse
refreshStateErrors map[uint32]error
Expand All @@ -291,6 +294,7 @@ func newTestTMClient(env *testEnv) *testTMClient {
vrQueries: make(map[int][]*queryResult),
createVReplicationWorkflowRequests: make(map[uint32]*createVReplicationWorkflowRequestResponse),
readVReplicationWorkflowRequests: make(map[uint32]*tabletmanagerdatapb.ReadVReplicationWorkflowRequest),
applySchemaRequests: make(map[uint32]*applySchemaRequestResponse),
readVReplicationWorkflowsResponses: make(map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse),
primaryPositions: make(map[uint32]string),
vdiffRequests: make(map[uint32]*vdiffRequestResponse),
Expand Down Expand Up @@ -467,8 +471,30 @@ func (tmc *testTMClient) ExecuteFetchAsAllPrivs(ctx context.Context, tablet *top
return nil, nil
}

func (tmc *testTMClient) expectApplySchemaRequest(tabletID uint32, req *applySchemaRequestResponse) {
tmc.mu.Lock()
defer tmc.mu.Unlock()

if tmc.applySchemaRequests == nil {
tmc.applySchemaRequests = make(map[uint32]*applySchemaRequestResponse)
}

tmc.applySchemaRequests[tabletID] = req
}

// Note: ONLY breaks up change.SQL into individual statements and executes it. Does NOT fully implement ApplySchema.
func (tmc *testTMClient) ApplySchema(ctx context.Context, tablet *topodatapb.Tablet, change *tmutils.SchemaChange) (*tabletmanagerdatapb.SchemaChangeResult, error) {
tmc.mu.Lock()
defer tmc.mu.Unlock()

if expect, ok := tmc.applySchemaRequests[tablet.Alias.Uid]; ok {
if !reflect.DeepEqual(change, expect.change) {
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected ApplySchema request on tablet %s: got %+v, want %+v",
topoproto.TabletAliasString(tablet.Alias), change, expect.change)
}
return expect.res, expect.err
}

stmts := strings.Split(change.SQL, ";")

for _, stmt := range stmts {
Expand Down Expand Up @@ -497,6 +523,12 @@ type createVReplicationWorkflowRequestResponse struct {
err error
}

type applySchemaRequestResponse struct {
change *tmutils.SchemaChange
res *tabletmanagerdatapb.SchemaChangeResult
err error
}

func (tmc *testTMClient) expectVDiffRequest(tablet *topodatapb.Tablet, vrr *vdiffRequestResponse) {
tmc.mu.Lock()
defer tmc.mu.Unlock()
Expand Down
11 changes: 8 additions & 3 deletions go/vt/vtctl/workflow/traffic_switcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -1527,12 +1527,14 @@ func (ts *trafficSwitcher) getTargetSequenceMetadata(ctx context.Context) (map[s
if err != nil {
return nil, err
}
stmt, err := sqlparser.ParseAndBind(sqlCreateSequenceTable, sqltypes.StringBindVariable(sqlescape.EscapeID(tableName)))
escapedTableName, err := sqlescape.EnsureEscaped(tableName)
if err != nil {
return nil, err
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid table name %s: %v",
tableName, err)
}
stmt := sqlparser.BuildParsedQuery(sqlCreateSequenceTable, escapedTableName)
_, err = ts.ws.tmc.ApplySchema(ctx, primary.Tablet, &tmutils.SchemaChange{
SQL: stmt,
SQL: stmt.Query,
Force: false,
AllowReplication: true,
SQLMode: vreplication.SQLMode,
Expand All @@ -1543,6 +1545,9 @@ func (ts *trafficSwitcher) getTargetSequenceMetadata(ctx context.Context) (map[s
tableName, globalKeyspace)
}
if bt := globalVSchema.Tables[sequenceMetadata.backingTableName]; bt == nil {
if globalVSchema.Tables == nil {
globalVSchema.Tables = make(map[string]*vschemapb.Table)
}
globalVSchema.Tables[tableName] = &vschemapb.Table{
Type: vindexes.TypeSequence,
}
Expand Down
91 changes: 86 additions & 5 deletions go/vt/vtctl/workflow/traffic_switcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/vt/mysqlctl/tmutils"
"vitess.io/vitess/go/vt/proto/vschema"
vtctldatapb "vitess.io/vitess/go/vt/proto/vtctldata"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/topo"
"vitess.io/vitess/go/vt/vtgate/vindexes"
"vitess.io/vitess/go/vt/vttablet/tabletmanager/vreplication"

tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata"
)
Expand Down Expand Up @@ -74,6 +78,7 @@ func TestGetTargetSequenceMetadata(t *testing.T) {
cell := "cell1"
workflow := "wf1"
table := "`t1`"
tableDDL := "create table t1 (id int not null auto_increment primary key, c1 varchar(10))"
unescapedTable := "t1"
sourceKeyspace := &testKeyspace{
KeyspaceName: "source-ks",
Expand All @@ -91,12 +96,25 @@ func TestGetTargetSequenceMetadata(t *testing.T) {
env := newTestEnv(t, ctx, cell, sourceKeyspace, targetKeyspace)
defer env.close()

env.tmc.schema = map[string]*tabletmanagerdatapb.SchemaDefinition{
unescapedTable: {
TableDefinitions: []*tabletmanagerdatapb.TableDefinition{
{
Name: unescapedTable,
Schema: tableDDL,
},
},
},
}

type testCase struct {
name string
sourceVSchema *vschema.Keyspace
targetVSchema *vschema.Keyspace
want map[string]*sequenceMetadata
err string
name string
sourceVSchema *vschema.Keyspace
targetVSchema *vschema.Keyspace
options *vtctldatapb.WorkflowOptions
want map[string]*sequenceMetadata
expectSourceApplySchemaRequest *applySchemaRequestResponse
err string
}
tests := []testCase{
{
Expand Down Expand Up @@ -152,6 +170,65 @@ func TestGetTargetSequenceMetadata(t *testing.T) {
},
},
},
{
name: "auto_increment replaced with sequence",
sourceVSchema: &vschema.Keyspace{
Vindexes: vindexes,
Tables: map[string]*vschema.Table{}, // Table will be created
},
options: &vtctldatapb.WorkflowOptions{
StripShardedAutoIncrement: vtctldatapb.ShardedAutoIncrementHandling_REPLACE,
GlobalKeyspace: sourceKeyspace.KeyspaceName,
},
expectSourceApplySchemaRequest: &applySchemaRequestResponse{
change: &tmutils.SchemaChange{
SQL: sqlparser.BuildParsedQuery(sqlCreateSequenceTable, fmt.Sprintf("`%s_seq`", unescapedTable)).Query,
Force: false,
AllowReplication: true,
SQLMode: vreplication.SQLMode,
DisableForeignKeyChecks: true,
},
res: &tabletmanagerdatapb.SchemaChangeResult{},
},
targetVSchema: &vschema.Keyspace{
Vindexes: vindexes,
Tables: map[string]*vschema.Table{
table: {
ColumnVindexes: []*vschema.ColumnVindex{
{
Name: "xxhash",
Column: "`my-col`",
},
},
AutoIncrement: &vschema.AutoIncrement{
Column: "my-col",
Sequence: fmt.Sprintf("%s_seq", unescapedTable),
},
},
},
},
want: map[string]*sequenceMetadata{
fmt.Sprintf("%s_seq", unescapedTable): {
backingTableName: fmt.Sprintf("%s_seq", unescapedTable),
backingTableKeyspace: "source-ks",
backingTableDBName: "vt_source-ks",
usingTableName: unescapedTable,
usingTableDBName: "vt_targetks",
usingTableDefinition: &vschema.Table{
ColumnVindexes: []*vschema.ColumnVindex{
{
Column: "my-col",
Name: "xxhash",
},
},
AutoIncrement: &vschema.AutoIncrement{
Column: "my-col",
Sequence: fmt.Sprintf("%s_seq", unescapedTable),
},
},
},
},
},
{
name: "sequences with backticks",
sourceVSchema: &vschema.Keyspace{
Expand Down Expand Up @@ -336,6 +413,9 @@ func TestGetTargetSequenceMetadata(t *testing.T) {
Tablet: tablet,
},
}
if tc.expectSourceApplySchemaRequest != nil {
env.tmc.expectApplySchemaRequest(tablet.Alias.Uid, tc.expectSourceApplySchemaRequest)
}
}
for i, shard := range targetKeyspace.ShardNames {
tablet := env.tablets[targetKeyspace.KeyspaceName][startingTargetTabletUID+(i*tabletUIDStep)]
Expand All @@ -354,6 +434,7 @@ func TestGetTargetSequenceMetadata(t *testing.T) {
targetKeyspace: targetKeyspace.KeyspaceName,
sources: sources,
targets: targets,
options: tc.options,
}
got, err := ts.getTargetSequenceMetadata(ctx)
if tc.err != "" {
Expand Down

0 comments on commit e89bbf7

Please sign in to comment.