diff --git a/go/vt/vttablet/tabletmanager/rpc_vreplication.go b/go/vt/vttablet/tabletmanager/rpc_vreplication.go index a274da98fdf..c8c334d896e 100644 --- a/go/vt/vttablet/tabletmanager/rpc_vreplication.go +++ b/go/vt/vttablet/tabletmanager/rpc_vreplication.go @@ -59,6 +59,8 @@ const ( // Update field values for multiple workflows. The final format specifier is // used to optionally add any additional predicates to the query. sqlUpdateVReplicationWorkflows = "update /*vt+ ALLOW_UNSAFE_VREPLICATION_WRITE */ %s.vreplication set%s where db_name = '%s'%s" + // Check if workflow is still copying. + sqlGetVReplicationCopyStatus = "select distinct vrepl_id from %s.copy_state where vrepl_id = %d" ) var ( @@ -379,6 +381,18 @@ func (tm *TabletManager) ReadVReplicationWorkflow(ctx context.Context, req *tabl return resp, nil } +func isStreamCopying(tm *TabletManager, id int64) (bool, error) { + query := fmt.Sprintf(sqlGetVReplicationCopyStatus, sidecar.GetIdentifier(), id) + res, err := tm.VREngine.Exec(query) + if err != nil { + return false, err + } + if res != nil && len(res.Rows) > 0 { + return true, nil + } + return false, nil +} + // UpdateVReplicationWorkflow updates the sidecar databases's vreplication // record(s) for this tablet's vreplication workflow stream(s). If there // are no streams for the given workflow on the tablet then a nil result @@ -457,6 +471,17 @@ func (tm *TabletManager) UpdateVReplicationWorkflow(ctx context.Context, req *ta if !textutil.ValueIsSimulatedNull(req.State) { state = binlogdatapb.VReplicationWorkflowState_name[int32(req.State)] } + if state == binlogdatapb.VReplicationWorkflowState_Running.String() { + // `Workflow Start` sets the new state to Running. However, if stream is still copying tables, we should set + // the state as Copying. + isCopying, err := isStreamCopying(tm, id) + if err != nil { + return nil, err + } + if isCopying { + state = binlogdatapb.VReplicationWorkflowState_Copying.String() + } + } bindVars = map[string]*querypb.BindVariable{ "st": sqltypes.StringBindVariable(state), "sc": sqltypes.StringBindVariable(string(source)), diff --git a/go/vt/vttablet/tabletmanager/rpc_vreplication_test.go b/go/vt/vttablet/tabletmanager/rpc_vreplication_test.go index 789319a2a53..f1211c87c8f 100644 --- a/go/vt/vttablet/tabletmanager/rpc_vreplication_test.go +++ b/go/vt/vttablet/tabletmanager/rpc_vreplication_test.go @@ -390,7 +390,8 @@ func TestMoveTables(t *testing.T) { for _, ftc := range targetShards { addInvariants(ftc.vrdbClient, vreplID, sourceTabletUID, position, wf, tenv.cells[0]) - + getCopyStateQuery := fmt.Sprintf(sqlGetVReplicationCopyStatus, sidecar.GetIdentifier(), vreplID) + ftc.vrdbClient.AddInvariant(getCopyStateQuery, &sqltypes.Result{}) tenv.tmc.setVReplicationExecResults(ftc.tablet, getCopyState, &sqltypes.Result{}) ftc.vrdbClient.ExpectRequest(fmt.Sprintf(readAllWorkflows, tenv.dbName, ""), &sqltypes.Result{}, nil) insert := fmt.Sprintf(`%s values ('%s', 'keyspace:\"%s\" shard:\"%s\" filter:{rules:{match:\"t1\" filter:\"select * from t1 where in_keyrange(id, \'%s.hash\', \'%s\')\"}}', '', 0, 0, '%s', 'primary,replica,rdonly', now(), 0, 'Stopped', '%s', %d, 0, 0, '{}')`, @@ -574,6 +575,7 @@ func TestUpdateVReplicationWorkflow(t *testing.T) { ), fmt.Sprintf("%d|%s|%s|%s|Running|", vreplID, blsStr, cells[0], tabletTypes[0]), ) + idQuery, err := sqlparser.ParseAndBind("select id from _vt.vreplication where id = %a", sqltypes.Int64BindVariable(int64(vreplID))) require.NoError(t, err) @@ -585,10 +587,19 @@ func TestUpdateVReplicationWorkflow(t *testing.T) { fmt.Sprintf("%d", vreplID), ) + getCopyStateQuery := fmt.Sprintf(sqlGetVReplicationCopyStatus, sidecar.GetIdentifier(), int64(vreplID)) + copyStatusFields := sqltypes.MakeTestFields( + "id", + "int64", + ) + notCopying := sqltypes.MakeTestResult(copyStatusFields) + copying := sqltypes.MakeTestResult(copyStatusFields, "1") + tests := []struct { - name string - request *tabletmanagerdatapb.UpdateVReplicationWorkflowRequest - query string + name string + request *tabletmanagerdatapb.UpdateVReplicationWorkflowRequest + query string + isCopying bool }{ { name: "update cells", @@ -668,6 +679,19 @@ func TestUpdateVReplicationWorkflow(t *testing.T) { query: fmt.Sprintf(`update _vt.vreplication set state = '%s', source = 'keyspace:\"%s\" shard:\"%s\" filter:{rules:{match:\"corder\" filter:\"select * from corder\"} rules:{match:\"customer\" filter:\"select * from customer\"}}', cell = '%s', tablet_types = '%s' where id in (%d)`, binlogdatapb.VReplicationWorkflowState_Stopped.String(), keyspace, shard, cells[0], tabletTypes[0], vreplID), }, + { + name: "update to running while copying", + request: &tabletmanagerdatapb.UpdateVReplicationWorkflowRequest{ + Workflow: workflow, + State: binlogdatapb.VReplicationWorkflowState_Running, + Cells: textutil.SimulatedNullStringSlice, + TabletTypes: []topodatapb.TabletType{topodatapb.TabletType(textutil.SimulatedNullInt)}, + OnDdl: binlogdatapb.OnDDLAction(textutil.SimulatedNullInt), + }, + isCopying: true, + query: fmt.Sprintf(`update _vt.vreplication set state = 'Copying', source = 'keyspace:\"%s\" shard:\"%s\" filter:{rules:{match:\"corder\" filter:\"select * from corder\"} rules:{match:\"customer\" filter:\"select * from customer\"}}', cell = '%s', tablet_types = '%s' where id in (%d)`, + keyspace, shard, cells[0], tabletTypes[0], vreplID), + }, } for _, tt := range tests { @@ -686,12 +710,21 @@ func TestUpdateVReplicationWorkflow(t *testing.T) { // These are the same for each RPC call. tenv.tmc.tablets[tabletUID].vrdbClient.ExpectRequest(fmt.Sprintf("use %s", sidecar.GetIdentifier()), &sqltypes.Result{}, nil) tenv.tmc.tablets[tabletUID].vrdbClient.ExpectRequest(selectQuery, selectRes, nil) - tenv.tmc.tablets[tabletUID].vrdbClient.ExpectRequest(fmt.Sprintf("use %s", sidecar.GetIdentifier()), &sqltypes.Result{}, nil) - tenv.tmc.tablets[tabletUID].vrdbClient.ExpectRequest(idQuery, idRes, nil) + if tt.request.State == binlogdatapb.VReplicationWorkflowState_Running || + tt.request.State == binlogdatapb.VReplicationWorkflowState(textutil.SimulatedNullInt) { + tenv.tmc.tablets[tabletUID].vrdbClient.ExpectRequest(fmt.Sprintf("use %s", sidecar.GetIdentifier()), &sqltypes.Result{}, nil) + if tt.isCopying { + tenv.tmc.tablets[tabletUID].vrdbClient.ExpectRequest(getCopyStateQuery, copying, nil) + } else { + tenv.tmc.tablets[tabletUID].vrdbClient.ExpectRequest(getCopyStateQuery, notCopying, nil) + } + } // This is our expected query, which will also short circuit // the test with an error as at this point we've tested what // we wanted to test. + tenv.tmc.tablets[tabletUID].vrdbClient.ExpectRequest(fmt.Sprintf("use %s", sidecar.GetIdentifier()), &sqltypes.Result{}, nil) + tenv.tmc.tablets[tabletUID].vrdbClient.ExpectRequest(idQuery, idRes, nil) tenv.tmc.tablets[tabletUID].vrdbClient.ExpectRequest(tt.query, &sqltypes.Result{RowsAffected: 1}, errShortCircuit) _, err = tenv.tmc.tablets[tabletUID].tm.UpdateVReplicationWorkflow(ctx, tt.request) tenv.tmc.tablets[tabletUID].vrdbClient.Wait()