From bde6be1bfbfc6df7468413fc1cf73ab3f1831230 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Mon, 16 Dec 2024 12:17:44 +0100 Subject: [PATCH] feat: fetch last_insert_id when asked for it Signed-off-by: Andres Taylor --- go/vt/vttablet/endtoend/misc_test.go | 10 +++++ go/vt/vttablet/tabletserver/query_executor.go | 39 ++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/go/vt/vttablet/endtoend/misc_test.go b/go/vt/vttablet/endtoend/misc_test.go index e6755e403e2..8d5630a5dd8 100644 --- a/go/vt/vttablet/endtoend/misc_test.go +++ b/go/vt/vttablet/endtoend/misc_test.go @@ -618,6 +618,16 @@ func TestLastInsertId(t *testing.T) { assert.Truef(t, qr.Rows[0][0].Equal(wantCol), "Execute: \n%#v, want \n%#v", qr.Rows[0][0], wantCol) } +func TestSelectLastInsertId(t *testing.T) { + client := framework.NewClient() + rs, err := client.ExecuteWithOptions("select 1 from dual where last_insert_id(42) = 42", nil, &querypb.ExecuteOptions{ + IncludedFields: querypb.ExecuteOptions_ALL, + FetchLastInsertId: true, + }) + require.NoError(t, err) + assert.EqualValues(t, 42, rs.InsertID) +} + func TestAppDebugRequest(t *testing.T) { client := framework.NewClient() diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index 519b60b79d6..1acbe2beed4 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -1121,7 +1121,17 @@ func (qre *QueryExecutor) execDBConn(conn *connpool.Conn, sql string, wantfields } defer qre.tsv.statelessql.Remove(qd) - return conn.Exec(ctx, sql, int(qre.tsv.qe.maxResultSize.Load()), wantfields) + exec, err := conn.Exec(ctx, sql, int(qre.tsv.qe.maxResultSize.Load()), wantfields) + if err != nil { + return nil, err + } + if qre.options.FetchLastInsertId { + err := qre.fetchLastInsertID(ctx, conn, exec) + if err != nil { + return nil, err + } + } + return exec, nil } func (qre *QueryExecutor) execStatefulConn(conn *StatefulConnection, sql string, wantfields bool) (*sqltypes.Result, error) { @@ -1137,7 +1147,32 @@ func (qre *QueryExecutor) execStatefulConn(conn *StatefulConnection, sql string, } defer qre.tsv.statefulql.Remove(qd) - return conn.Exec(ctx, sql, int(qre.tsv.qe.maxResultSize.Load()), wantfields) + exec, err := conn.Exec(ctx, sql, int(qre.tsv.qe.maxResultSize.Load()), wantfields) + if err != nil { + return nil, err + } + if qre.options.FetchLastInsertId { + err = qre.fetchLastInsertID(ctx, conn.UnderlyingDBConn().Conn, exec) + if err != nil { + return nil, err + } + } + return exec, nil +} + +func (qre *QueryExecutor) fetchLastInsertID(ctx context.Context, conn *connpool.Conn, exec *sqltypes.Result) error { + result, err := conn.Exec(ctx, "select last_insert_id()", 1, false) + if err != nil { + return err + } + + cell := result.Rows[0][0] + insertID, err := cell.ToCastUint64() + if err != nil { + return err + } + exec.InsertID = insertID + return nil } func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction bool, sql string, callback func(*sqltypes.Result) error) error {