diff --git a/go/vt/vttablet/tabletmanager/vdiff/table_differ.go b/go/vt/vttablet/tabletmanager/vdiff/table_differ.go index 28b36f74ee8..b839191317a 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/table_differ.go +++ b/go/vt/vttablet/tabletmanager/vdiff/table_differ.go @@ -20,10 +20,12 @@ import ( "context" "encoding/json" "fmt" + "slices" "strings" "sync" "time" + "golang.org/x/exp/maps" "google.golang.org/protobuf/encoding/prototext" "vitess.io/vitess/go/mysql/collations" @@ -194,7 +196,7 @@ func (td *tableDiffer) stopTargetVReplicationStreams(ctx context.Context, dbClie return fmt.Errorf("stream %d has not started on tablet %v", id, td.wd.ct.vde.thisTablet.Alias) } - sourceBytes, err := row["source"].ToBytes() + sourceBytes, err := row[source].ToBytes() if err != nil { return err } @@ -520,8 +522,8 @@ func (td *tableDiffer) diff(ctx context.Context, coreOpts *tabletmanagerdatapb.V } dr.TableName = td.table.Name - sourceExecutor := newPrimitiveExecutor(ctx, td.sourcePrimitive, "source") - targetExecutor := newPrimitiveExecutor(ctx, td.targetPrimitive, "target") + sourceExecutor := newPrimitiveExecutor(ctx, td.sourcePrimitive, source) + targetExecutor := newPrimitiveExecutor(ctx, td.targetPrimitive, target) var sourceRow, lastProcessedRow, targetRow []sqltypes.Value advanceSource := true advanceTarget := true @@ -723,7 +725,7 @@ func (td *tableDiffer) updateTableProgress(dbClient binlogplayer.DBClient, dr *D return err } - if lastRow == nil { + if len(lastRow) == 0 { query, err = sqlparser.ParseAndBind(sqlUpdateTableNoProgress, sqltypes.Int64BindVariable(dr.ProcessedRows), sqltypes.StringBindVariable(string(rpt)), @@ -736,17 +738,19 @@ func (td *tableDiffer) updateTableProgress(dbClient binlogplayer.DBClient, dr *D } else { var lastSourcePK, lastTargetPK []byte lastPK := make(map[string]string, 2) - if lastRow != nil { - lastSourcePK, err = td.lastSourcePKFromRow(lastRow) - if err != nil { - return err - } - lastPK["source"] = string(lastSourcePK) - lastTargetPK, err = td.lastTargetPKFromRow(lastRow) + lastTargetPK, err = td.lastPKFromRow(lastRow, td.tablePlan.pkCols) + if err != nil { + return err + } + lastPK[target] = string(lastTargetPK) + if slices.Equal(td.tablePlan.sourcePkCols, td.tablePlan.pkCols) { + lastPK[source] = string(lastTargetPK) + } else { + lastSourcePK, err = td.lastPKFromRow(lastRow, td.tablePlan.sourcePkCols) if err != nil { return err } - lastPK["target"] = string(lastTargetPK) + lastPK[source] = string(lastSourcePK) } if td.wd.opts.CoreOptions.MaxDiffSeconds > 0 { // Update the in-memory lastPK as well so that we can restart the table @@ -761,12 +765,10 @@ func (td *tableDiffer) updateTableProgress(dbClient binlogplayer.DBClient, dr *D } td.lastTargetPK = lastTargetPKPB } - //log.Errorf("DEBUG: updateTableProgress lastPK map: %v", lastPK) lastPKJS, err := json.Marshal(lastPK) if err != nil { return err } - //log.Errorf("DEBUG: updateTableProgress lastPK JSON: %v", lastPKJS) query, err = sqlparser.ParseAndBind(sqlUpdateTableProgress, sqltypes.Int64BindVariable(dr.ProcessedRows), sqltypes.StringBindVariable(string(lastPKJS)), @@ -845,31 +847,11 @@ func updateTableMismatch(dbClient binlogplayer.DBClient, vdiffID int64, table st return nil } -func (td *tableDiffer) lastTargetPKFromRow(row []sqltypes.Value) ([]byte, error) { - pkColCnt := len(td.tablePlan.pkCols) - pkFields := make([]*querypb.Field, pkColCnt) - pkVals := make([]sqltypes.Value, pkColCnt) - for i, colIndex := range td.tablePlan.pkCols { - pkFields[i] = td.tablePlan.table.Fields[colIndex] - pkVals[i] = row[colIndex] - } - buf, err := prototext.Marshal(&querypb.QueryResult{ - Fields: pkFields, - Rows: []*querypb.Row{sqltypes.RowToProto3(pkVals)}, - }) - return buf, err -} - -func (td *tableDiffer) lastSourcePKFromRow(row []sqltypes.Value) ([]byte, error) { - if len(td.tablePlan.sourcePkCols) == 0 { - // If there are no PKs on the source then we use - // the same PK[E] columns as the target. - td.tablePlan.sourcePkCols = td.tablePlan.pkCols - } - pkColCnt := len(td.tablePlan.sourcePkCols) +func (td *tableDiffer) lastPKFromRow(row []sqltypes.Value, pkCols []int) ([]byte, error) { + pkColCnt := len(pkCols) pkFields := make([]*querypb.Field, pkColCnt) pkVals := make([]sqltypes.Value, pkColCnt) - for i, colIndex := range td.tablePlan.sourcePkCols { + for i, colIndex := range pkCols { pkFields[i] = td.tablePlan.table.Fields[colIndex] pkVals[i] = row[colIndex] } @@ -926,6 +908,49 @@ func (td *tableDiffer) adjustForSourceTimeZone(targetSelectExprs sqlparser.Selec return targetSelectExprs } +// getSourcePKCols populates the sourcePkCols field in the tablePlan. +// We need this information in order to save the lastpk value for the +// source as the PK columns may differ between the source and target. +func (td *tableDiffer) getSourcePKCols() error { + ctx, cancel := context.WithTimeout(td.wd.ct.vde.ctx, topo.RemoteOperationTimeout*3) + defer cancel() + // We use the first sourceShard as all of them should have the same schema. + sourceShardName := maps.Keys(td.wd.ct.sources)[0] + sourceTS, err := td.wd.getSourceTopoServer() + if err != nil { + return vterrors.Wrap(err, "failed to get source topo server") + } + sourceShard, err := sourceTS.GetShard(ctx, td.wd.ct.sourceKeyspace, sourceShardName) + if err != nil { + return err + } + if sourceShard.PrimaryAlias == nil { + return fmt.Errorf("source shard %s has no primary", sourceShardName) + } + sourceTablet, err := sourceTS.GetTablet(ctx, sourceShard.PrimaryAlias) + if err != nil { + return fmt.Errorf("failed to get source shard %s primary", sourceShardName) + } + sourceSchema, err := td.wd.ct.tmc.GetSchema(ctx, sourceTablet.Tablet, &tabletmanagerdatapb.GetSchemaRequest{ + Tables: []string{td.table.Name}, + }) + if err != nil { + return err + } + sourceTable := sourceSchema.TableDefinitions[0] + sourcePKColumns := make(map[string]struct{}, len(sourceTable.PrimaryKeyColumns)) + td.tablePlan.sourcePkCols = make([]int, 0, len(sourceTable.PrimaryKeyColumns)) + for _, pkc := range sourceTable.PrimaryKeyColumns { + sourcePKColumns[pkc] = struct{}{} + } + for i, pkc := range td.table.PrimaryKeyColumns { + if _, ok := sourcePKColumns[pkc]; ok { + td.tablePlan.sourcePkCols = append(td.tablePlan.sourcePkCols, i) + } + } + return nil +} + func getColumnNameForSelectExpr(selectExpression sqlparser.SelectExpr) (string, error) { aliasedExpr := selectExpression.(*sqlparser.AliasedExpr) expr := aliasedExpr.Expr diff --git a/go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go b/go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go index 9e0ee6cbc41..34fd3cfad19 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go +++ b/go/vt/vttablet/tabletmanager/vdiff/workflow_differ.go @@ -26,7 +26,6 @@ import ( "strings" "time" - "golang.org/x/exp/maps" "google.golang.org/protobuf/encoding/prototext" "vitess.io/vitess/go/mysql/collations" @@ -49,6 +48,11 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) +const ( + source = "source" + target = "target" +) + // workflowDiffer has metadata and state for the vdiff of a single workflow on this tablet // only one vdiff can be running for a workflow at any time. type workflowDiffer struct { @@ -379,8 +383,8 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl if err != nil { return err } - td.lastSourcePK = lastPK["source"] - td.lastTargetPK = lastPK["target"] + td.lastSourcePK = lastPK[source] + td.lastTargetPK = lastPK[target] wd.tableDiffers[table.Name] = td if _, err := td.buildTablePlan(dbClient, wd.ct.vde.dbName, wd.collationEnv); err != nil { return err @@ -388,42 +392,9 @@ func (wd *workflowDiffer) buildPlan(dbClient binlogplayer.DBClient, filter *binl // We get the PK columns from the source schema as well, as they can // differ and determine the proper lastPK to use when saving progress. - // We use the first sourceShard as all of them should have the same schema. - sourceShardName := maps.Keys(wd.ct.sources)[0] - sourceTS, err := wd.getSourceTopoServer() - if err != nil { - return vterrors.Wrap(err, "failed to get source topo server") - } - sourceShard, err := sourceTS.GetShard(wd.ct.vde.ctx, wd.ct.sourceKeyspace, sourceShardName) - if err != nil { - return err - } - if sourceShard.PrimaryAlias == nil { - return fmt.Errorf("source shard %s has no primary", sourceShardName) - } - sourceTablet, err := sourceTS.GetTablet(wd.ct.vde.ctx, sourceShard.PrimaryAlias) - if err != nil { - return fmt.Errorf("failed to get source shard %s primary", sourceShardName) - } - sourceSchema, err := wd.ct.tmc.GetSchema(wd.ct.vde.ctx, sourceTablet.Tablet, &tabletmanagerdatapb.GetSchemaRequest{ - Tables: []string{table.Name}, - }) - if err != nil { + if err := td.getSourcePKCols(); err != nil { return err } - //log.Errorf("DEBUG: sourceTable.PrimaryKeyColumns: %v", sourceSchema.TableDefinitions[0].PrimaryKeyColumns) - sourcePKColumns := make(map[string]struct{}, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns)) - td.tablePlan.sourcePkCols = make([]int, 0, len(sourceSchema.TableDefinitions[0].PrimaryKeyColumns)) - for _, pkc := range sourceSchema.TableDefinitions[0].PrimaryKeyColumns { - sourcePKColumns[pkc] = struct{}{} - } - //log.Errorf("DEBUG: sourcePKColumns: %v", sourcePKColumns) - for i, pkc := range table.PrimaryKeyColumns { - if _, ok := sourcePKColumns[pkc]; ok { - td.tablePlan.sourcePkCols = append(td.tablePlan.sourcePkCols, i) - } - } - //log.Errorf("DEBUG: td.tablePlan.sourcePkCols: %v", td.tablePlan.sourcePkCols) } if len(wd.tableDiffers) == 0 { return fmt.Errorf("no tables found to diff, %s:%s, on tablet %v", @@ -456,14 +427,12 @@ func (wd *workflowDiffer) getTableLastPK(dbClient binlogplayer.DBClient, tableNa if err := json.Unmarshal(lastpk, &lastPK); err != nil { return nil, vterrors.Wrapf(err, "failed to unmarshal lastpk JSON for table %s", tableName) } - //log.Errorf("DEBUG: getTabletLastPK lastPKBytes: %v", lastPK) for k, v := range lastPK { lastPKResults[k] = &querypb.QueryResult{} if err := prototext.Unmarshal([]byte(v), lastPKResults[k]); err != nil { return nil, vterrors.Wrapf(err, "failed to unmarshal lastpk QueryResult for table %s", tableName) } } - //log.Errorf("DEBUG: getTabletLastPK lastPKRResults: %v", lastPKResults) return lastPKResults, nil } } @@ -549,5 +518,7 @@ func (wd *workflowDiffer) getSourceTopoServer() (*topo.Server, error) { if wd.ct.externalCluster == "" { return wd.ct.ts, nil } - return wd.ct.ts.OpenExternalVitessClusterServer(wd.ct.vde.ctx, wd.ct.externalCluster) + ctx, cancel := context.WithTimeout(wd.ct.vde.ctx, topo.RemoteOperationTimeout) + defer cancel() + return wd.ct.ts.OpenExternalVitessClusterServer(ctx, wd.ct.externalCluster) }