Skip to content

Commit

Permalink
Merge pull request #741 from dolthub/fulghum/copy-from-stdin-tx
Browse files Browse the repository at this point in the history
Support automatic transaction management with `COPY FROM STDIN`
  • Loading branch information
fulghum authored Sep 25, 2024
2 parents 9f1246e + 24c7bce commit 9c38280
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 4 deletions.
12 changes: 11 additions & 1 deletion server/connection_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,18 @@ type ConvertedQuery struct {
// this statement is processed, the server accepts COPY DATA messages from the client with chunks of data to load
// into a table.
type copyFromStdinState struct {
// copyFromStdinNode stores the original CopyFrom statement that initiated the CopyData message sequence. This
// node is used to look at what parameters were specified, such as which table to load data into, file format,
// delimiters, etc.
copyFromStdinNode *node.CopyFrom
dataLoader dataloader.DataLoader
// dataLoader is the implementation of DataLoader that is used to load each individual CopyData chunk into the
// target table.
dataLoader dataloader.DataLoader
// copyErr stores any error that was returned while processing a CopyData message and loading a chunk of data
// to the target table. The server needs to keep track of any errors that were encountered while processing chunks
// so that it can avoid sending a CommandComplete message if an error was encountered after the client already
// sent a CopyDone message to the server.
copyErr error
}

type PortalData struct {
Expand Down
55 changes: 53 additions & 2 deletions server/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"strings"
"sync/atomic"

"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/mysql"
Expand Down Expand Up @@ -616,15 +617,29 @@ func makeCommandComplete(tag string, rows int32) *pgproto3.CommandComplete {
// messages are expected, and the server should tell the client that it is ready for the next query, and |err| contains
// any error that occurred while processing the COPY DATA message.
func (h *ConnectionHandler) handleCopyData(message *pgproto3.CopyData) (stop bool, endOfMessages bool, err error) {
helper, messages, err := h.handleCopyDataHelper(message)
if err != nil {
h.copyFromStdinState.copyErr = err
}
return helper, messages, err
}

// handleCopyDataHelper is a helper function that should only be invoked by handleCopyData. handleCopyData wraps this
// function so that it can capture any returned error message and store it in the saved state.
func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (stop bool, endOfMessages bool, err error) {
if h.copyFromStdinState == nil {
return false, true, fmt.Errorf("COPY DATA message received without a COPY FROM STDIN operation in progress")
}

// Grab a sql.Context
// Grab a sql.Context and ensure the session has a transaction started, otherwise the copied data
// won't get committed correctly.
sqlCtx, err := h.doltgresHandler.NewContext(context.Background(), h.mysqlConn, "")
if err != nil {
return false, false, err
}
if err = startTransaction(sqlCtx); err != nil {
return false, false, err
}

dataLoader := h.copyFromStdinState.dataLoader
if dataLoader == nil {
Expand Down Expand Up @@ -686,6 +701,14 @@ func (h *ConnectionHandler) handleCopyDone(_ *pgproto3.CopyDone) (stop bool, end
fmt.Errorf("COPY DONE message received without a COPY FROM STDIN operation in progress")
}

// If there was a previous error returned from processing a CopyData message, then don't return an error here
// and don't send endOfMessage=true, since the CopyData error already sent endOfMessage=true. If we do send
// endOfMessage=true here, then the client gets confused about the unexpected/extra Idle message since the
// server has already reported it was idle in the last message after the returned error.
if h.copyFromStdinState.copyErr != nil {
return false, false, nil
}

dataLoader := h.copyFromStdinState.dataLoader
if dataLoader == nil {
return false, true,
Expand All @@ -702,6 +725,17 @@ func (h *ConnectionHandler) handleCopyDone(_ *pgproto3.CopyDone) (stop bool, end
return false, false, err
}

// If we aren't in an explicit/user managed transaction, we need to commit the transaction
if !sqlCtx.GetIgnoreAutoCommit() {
txSession, ok := sqlCtx.Session.(sql.TransactionSession)
if !ok {
return false, false, fmt.Errorf("session does not implement sql.TransactionSession")
}
if err = txSession.CommitTransaction(sqlCtx, txSession.GetTransaction()); err != nil {
return false, false, err
}
}

h.copyFromStdinState = nil
// We send back endOfMessage=true, since the COPY DONE message ends the COPY DATA flow and the server is ready
// to accept the next query now.
Expand All @@ -710,7 +744,7 @@ func (h *ConnectionHandler) handleCopyDone(_ *pgproto3.CopyDone) (stop bool, end
})
}

// handleCopyDone handles a COPY FAIL message by aborting the in-progress COPY DATA operation. The |stop| response
// handleCopyFail handles a COPY FAIL message by aborting the in-progress COPY DATA operation. The |stop| response
// parameter is true if the connection handler should shut down the connection, |endOfMessages| is true if no more
// COPY DATA messages are expected, and the server should tell the client that it is ready for the next query, and
// |err| contains any error that occurred while processing the COPY DATA message.
Expand All @@ -732,6 +766,23 @@ func (h *ConnectionHandler) handleCopyFail(_ *pgproto3.CopyFail) (stop bool, end
return false, true, nil
}

// startTransaction checks to see if the current session has a transaction started yet or not, and if not,
// creates a read/write transaction for the session to use. This is necessary for handling commands that alter
// data without going through the GMS engine.
func startTransaction(ctx *sql.Context) error {
doltSession, ok := ctx.Session.(*dsess.DoltSession)
if !ok {
return fmt.Errorf("unexpected session type: %T", ctx.Session)
}
if doltSession.GetTransaction() == nil {
if _, err := doltSession.StartTransaction(ctx, sql.ReadWrite); err != nil {
return err
}
}

return nil
}

func (h *ConnectionHandler) deallocatePreparedStatement(name string, preparedStatements map[string]PreparedStatementData, query ConvertedQuery, conn net.Conn) error {
_, ok := preparedStatements[name]
if !ok {
Expand Down
34 changes: 33 additions & 1 deletion testing/bats/dataloading.bats
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ teardown() {
[[ "$output" =~ "3 | 03 | 97302 | Guyane" ]] || false
}

# Tests that we can load tabular data dump files that do not explicitly manage the session's transaction.
@test 'dataloading: tabular import, no explicit tx management' {
# Import the data dump and assert the expected output
run query_server -f $BATS_TEST_DIRNAME/dataloading/tab-load-with-no-tx-control.sql
[ "$status" -eq 0 ]
[[ "$output" =~ "COPY 3" ]] || false
[[ ! "$output" =~ "ERROR" ]] || false

# Check the inserted rows
run query_server -c "SELECT * FROM test_info ORDER BY id;"
[ "$status" -eq 0 ]
[[ "$output" =~ "4 | string for 4 | 1" ]] || false
[[ "$output" =~ "5 | string for 5 | 0" ]] || false
[[ "$output" =~ "6 | string for 6 | 0" ]] || false
}

# Tests loading in data via different CSV data files.
@test 'dataloading: csv import' {
# Import the data dump and assert the expected output
Expand Down Expand Up @@ -157,4 +173,20 @@ teardown() {
run query_server -c "SELECT count(*) from tbl1;"
[ "$status" -eq 0 ]
[[ "$output" =~ "100" ]] || false
}
}

# Tests that we can load CSV data dump files that do not explicitly manage the session's transaction.
@test 'dataloading: csv import, no explicit tx management' {
# Import the data dump and assert the expected output
run query_server -f $BATS_TEST_DIRNAME/dataloading/csv-load-with-no-tx-control.sql
[ "$status" -eq 0 ]
[[ "$output" =~ "COPY 3" ]] || false
[[ ! "$output" =~ "ERROR" ]] || false

# Check the inserted rows
run query_server -c "SELECT * FROM test_info ORDER BY id;"
[ "$status" -eq 0 ]
[[ "$output" =~ "4 | string for 4 | 1" ]] || false
[[ "$output" =~ "5 | string for 5 | 0" ]] || false
[[ "$output" =~ "6 | string for 6 | 0" ]] || false
}
11 changes: 11 additions & 0 deletions testing/bats/dataloading/csv-load-with-no-tx-control.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
CREATE TABLE test (pk int primary key);
INSERT INTO test VALUES (0), (1);

CREATE TABLE test_info (id int, info varchar(255), test_pk int, primary key(id), foreign key (test_pk) references test(pk));

COPY test_info FROM STDIN (FORMAT CSV, HEADER TRUE);
id,info,test_pk
4,string for 4,1
5,string for 5,0
6,string for 6,0
\.
11 changes: 11 additions & 0 deletions testing/bats/dataloading/tab-load-with-no-tx-control.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
CREATE TABLE test (pk int primary key);
INSERT INTO test VALUES (0), (1);

CREATE TABLE test_info (id int, info varchar(255), test_pk int, primary key(id), foreign key (test_pk) references test(pk));

COPY test_info FROM STDIN WITH (HEADER);
id info test_pk
4 string for 4 1
5 string for 5 0
6 string for 6 0
\.
8 changes: 8 additions & 0 deletions testing/go/regression/tests/triggers.sql
Original file line number Diff line number Diff line change
Expand Up @@ -362,20 +362,27 @@ CREATE TRIGGER insert_when BEFORE INSERT ON main_table
FOR EACH STATEMENT WHEN (true) EXECUTE PROCEDURE trigger_func('insert_when');
CREATE TRIGGER delete_when AFTER DELETE ON main_table
FOR EACH STATEMENT WHEN (true) EXECUTE PROCEDURE trigger_func('delete_when');

SELECT trigger_name, event_manipulation, event_object_schema, event_object_table,
action_order, action_condition, action_orientation, action_timing,
action_reference_old_table, action_reference_new_table
FROM information_schema.triggers
WHERE event_object_table IN ('main_table')
ORDER BY trigger_name COLLATE "C", 2;

INSERT INTO main_table (a) VALUES (123), (456);

COPY main_table FROM stdin;
123 999
456 999
\.

DELETE FROM main_table WHERE a IN (123, 456);

UPDATE main_table SET a = 50, b = 60;

SELECT * FROM main_table ORDER BY a, b;

SELECT pg_get_triggerdef(oid, true) FROM pg_trigger WHERE tgrelid = 'main_table'::regclass AND tgname = 'modified_a';
SELECT pg_get_triggerdef(oid, false) FROM pg_trigger WHERE tgrelid = 'main_table'::regclass AND tgname = 'modified_a';
SELECT pg_get_triggerdef(oid, true) FROM pg_trigger WHERE tgrelid = 'main_table'::regclass AND tgname = 'modified_any';
Expand Down Expand Up @@ -420,6 +427,7 @@ FOR EACH STATEMENT EXECUTE PROCEDURE trigger_func('after_upd_b_stmt');
SELECT pg_get_triggerdef(oid) FROM pg_trigger WHERE tgrelid = 'main_table'::regclass AND tgname = 'after_upd_a_b_row_trig';

UPDATE main_table SET a = 50;

UPDATE main_table SET b = 10;

--
Expand Down

0 comments on commit 9c38280

Please sign in to comment.