diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index 5e7a5a64334..1cd58b31cc3 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -2780,6 +2780,76 @@ func TestExecutorPrepareExecute(t *testing.T) { require.Error(t, err) } +// TestExecutorRejectTwoPC test all the unsupported cases for multi-shard atomic commit. +func TestExecutorRejectTwoPC(t *testing.T) { + executor, sbc1, sbc2, _, ctx := createExecutorEnv(t) + tcases := []struct { + sqls []string + testRes []*sqltypes.Result + + expErr string + }{ + { + sqls: []string{ + `set time_zone = "+08:00"`, + `insert into user_extra(user_id) values (1)`, + `insert into user_extra(user_id) values (2)`, + `insert into user_extra(user_id) values (3)`, + }, + expErr: "VT12001: unsupported: atomic distributed transaction commit with system settings", + }, { + sqls: []string{ + `update t1 set unq_col = 1 where id = 1`, + `update t1 set unq_col = 1 where id = 3`, + }, + testRes: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|unq_col|unchanged", "int64|int64|int64"), + "1|2|0"), + }, + expErr: "VT12001: unsupported: atomic distributed transaction commit with consistent lookup vindex", + }, { + sqls: []string{ + `savepoint x`, + `insert into user_extra(user_id) values (1)`, + `insert into user_extra(user_id) values (3)`, + }, + testRes: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|unq_col|unchanged", "int64|int64|int64"), + "1|2|0"), + }, + expErr: "VT12001: unsupported: atomic distributed transaction commit with savepoint", + }, + } + + for _, tcase := range tcases { + t.Run(fmt.Sprintf("%v", tcase.sqls), func(t *testing.T) { + sbc1.SetResults(tcase.testRes) + sbc2.SetResults(tcase.testRes) + + // create a new session + session := NewSafeSession(&vtgatepb.Session{ + TargetString: KsTestSharded, + TransactionMode: vtgatepb.TransactionMode_TWOPC, + EnableSystemSettings: true, + }) + + // start transaction + _, err := executor.Execute(ctx, nil, "TestExecutorRejectTwoPC", session, "begin", nil) + require.NoError(t, err) + + // execute queries + for _, sql := range tcase.sqls { + _, err = executor.Execute(ctx, nil, "TestExecutorRejectTwoPC", session, sql, nil) + require.NoError(t, err) + } + + // commit 2pc + _, err = executor.Execute(ctx, nil, "TestExecutorRejectTwoPC", session, "commit", nil) + require.ErrorContains(t, err, tcase.expErr) + }) + } +} + func TestExecutorTruncateErrors(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) diff --git a/go/vt/vtgate/legacy_scatter_conn_test.go b/go/vt/vtgate/legacy_scatter_conn_test.go index 8fefce1dd66..4512fc0724e 100644 --- a/go/vt/vtgate/legacy_scatter_conn_test.go +++ b/go/vt/vtgate/legacy_scatter_conn_test.go @@ -620,6 +620,6 @@ func newTestScatterConn(ctx context.Context, hc discovery.HealthCheck, serv srvt // in '-cells_to_watch' command line parameter, which is // empty by default. So it's unused in this test, set to nil. gw := NewTabletGateway(ctx, hc, serv, cell) - tc := NewTxConn(gw, vtgatepb.TransactionMode_TWOPC) + tc := NewTxConn(gw, vtgatepb.TransactionMode_MULTI) return NewScatterConn("", tc, gw) } diff --git a/go/vt/vtgate/tx_conn.go b/go/vt/vtgate/tx_conn.go index 05c47e64b6c..372c3fc6164 100644 --- a/go/vt/vtgate/tx_conn.go +++ b/go/vt/vtgate/tx_conn.go @@ -187,16 +187,16 @@ func (txc *TxConn) commitNormal(ctx context.Context, session *SafeSession) error // commit2PC will not used the pinned tablets - to make sure we use the current source, we need to use the gateway's queryservice func (txc *TxConn) commit2PC(ctx context.Context, session *SafeSession) (err error) { - if len(session.PreSessions) != 0 || len(session.PostSessions) != 0 { - _ = txc.Rollback(ctx, session) - return vterrors.New(vtrpcpb.Code_FAILED_PRECONDITION, "pre or post actions not allowed for 2PC commits") - } - // If the number of participants is one or less, then it's a normal commit. if len(session.ShardSessions) <= 1 { return txc.commitNormal(ctx, session) } + if err := txc.checkValidCondition(session); err != nil { + _ = txc.Rollback(ctx, session) + return err + } + mmShard := session.ShardSessions[0] rmShards := session.ShardSessions[1:] dtid := dtids.New(mmShard) @@ -276,6 +276,19 @@ func (txc *TxConn) commit2PC(ctx context.Context, session *SafeSession) (err err return nil } +func (txc *TxConn) checkValidCondition(session *SafeSession) error { + if len(session.PreSessions) != 0 || len(session.PostSessions) != 0 { + return vterrors.VT12001("atomic distributed transaction commit with consistent lookup vindex") + } + if len(session.GetSavepoints()) != 0 { + return vterrors.VT12001("atomic distributed transaction commit with savepoint") + } + if session.GetInReservedConn() { + return vterrors.VT12001("atomic distributed transaction commit with system settings") + } + return nil +} + func (txc *TxConn) errActionAndLogWarn(ctx context.Context, session *SafeSession, txPhase commitPhase, dtid string, mmShard *vtgatepb.Session_ShardSession, rmShards []*vtgatepb.Session_ShardSession) { switch txPhase { case Commit2pcCreateTransaction: