From edb6d7aa3c2cfc92e2073724129da1b16b6498d2 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Tue, 27 Feb 2024 18:40:34 -0800 Subject: [PATCH] [postgres] Add scanning integration test (#211) --- sources/postgres/integration_test/main.go | 177 +++++++++++++++++++--- 1 file changed, 159 insertions(+), 18 deletions(-) diff --git a/sources/postgres/integration_test/main.go b/sources/postgres/integration_test/main.go index 2acaf228..561c2463 100644 --- a/sources/postgres/integration_test/main.go +++ b/sources/postgres/integration_test/main.go @@ -5,9 +5,11 @@ import ( "encoding/json" "fmt" "log/slog" + "maps" "math/rand/v2" "strings" + "github.com/artie-labs/transfer/lib/cdc/util" _ "github.com/jackc/pgx/v5/stdlib" "github.com/artie-labs/reader/config" @@ -16,7 +18,6 @@ import ( "github.com/artie-labs/reader/lib/logger" "github.com/artie-labs/reader/lib/postgres" "github.com/artie-labs/reader/sources/postgres/adapter" - "github.com/artie-labs/transfer/lib/cdc/util" ) var pgConfig = config.PostgreSQL{ @@ -37,14 +38,19 @@ func main() { if err != nil { logger.Fatal("Types test failed", slog.Any("err", err)) } + + err = testScan(db) + if err != nil { + logger.Fatal("Scan test failed", slog.Any("err", err)) + } } -func rawMessageTimestamp(message lib.RawMessage) int64 { +func getPayload(message lib.RawMessage) util.SchemaEventPayload { payloadTyped, ok := message.GetPayload().(util.SchemaEventPayload) if !ok { panic("payload is not of type util.SchemaEventPayload") } - return payloadTyped.Payload.Source.TsMs + return payloadTyped } func checkDifference(name, expected, actual string) bool { @@ -76,10 +82,11 @@ func checkDifference(name, expected, actual string) bool { return true } -func readTable(db *sql.DB, tableName string) ([]lib.RawMessage, error) { +func readTable(db *sql.DB, tableName string, batchSize int) ([]lib.RawMessage, error) { tableCfg := config.PostgreSQLTable{ - Schema: "public", - Name: tableName, + Schema: "public", + Name: tableName, + BatchSize: uint(batchSize), } table := postgres.NewTable(tableCfg.Schema, tableCfg.Name) @@ -103,7 +110,7 @@ func readTable(db *sql.DB, tableName string) ([]lib.RawMessage, error) { return rows, nil } -const createTableQuery = ` +const testTypesCreateTableQuery = ` CREATE TABLE %s ( pk integer PRIMARY KEY NOT NULL, -- All the types from https://www.postgresql.org/docs/current/datatype.html#DATATYPE-TABLE @@ -155,7 +162,7 @@ CREATE TABLE %s ( ) ` -const insertQuery = ` +const testTypesInsertQuery = ` INSERT INTO %s VALUES ( -- pk 1, @@ -250,10 +257,6 @@ INSERT INTO %s VALUES ( ) ` -const expectedPartitionKey = `{ -"pk": 1 -}` - const expectedPayloadTemplate = `{ "schema": { "type": "", @@ -597,10 +600,11 @@ const expectedPayloadTemplate = `{ } }` +// testTypes checks that PostgreSQL data types are handled correctly. func testTypes(db *sql.DB) error { tempTableName := fmt.Sprintf("artie_reader_%d", 10_000+rand.Int32N(5_000)) slog.Info("Creating temporary table...", slog.String("table", tempTableName)) - _, err := db.Exec(fmt.Sprintf(createTableQuery, tempTableName)) + _, err := db.Exec(fmt.Sprintf(testTypesCreateTableQuery, tempTableName)) if err != nil { return fmt.Errorf("unable to create temporary table: %w", err) } @@ -612,12 +616,12 @@ func testTypes(db *sql.DB) error { }() slog.Info("Inserting data...") - _, err = db.Exec(fmt.Sprintf(insertQuery, tempTableName)) + _, err = db.Exec(fmt.Sprintf(testTypesInsertQuery, tempTableName)) if err != nil { return fmt.Errorf("unable to insert data: %w", err) } - rows, err := readTable(db, tempTableName) + rows, err := readTable(db, tempTableName, 100) if err != nil { return err } @@ -627,7 +631,7 @@ func testTypes(db *sql.DB) error { } row := rows[0] - keyBytes, err := json.MarshalIndent(row.PartitionKey, "", "") + keyBytes, err := json.Marshal(row.PartitionKey) if err != nil { return fmt.Errorf("failed to marshal partition key: %w", err) } @@ -637,14 +641,151 @@ func testTypes(db *sql.DB) error { return fmt.Errorf("failed to marshal payload") } - if checkDifference("partition key", expectedPartitionKey, string(keyBytes)) { + if checkDifference("partition key", `{"pk":1}`, string(keyBytes)) { return fmt.Errorf("partition key does not match") } - expectedPayload := fmt.Sprintf(expectedPayloadTemplate, rawMessageTimestamp(row), tempTableName) + expectedPayload := fmt.Sprintf(expectedPayloadTemplate, getPayload(row).Payload.Source.TsMs, tempTableName) if checkDifference("payload", expectedPayload, string(valueBytes)) { return fmt.Errorf("payload does not match") } return nil } + +const testScanCreateTableQuery = ` +CREATE TABLE %s ( + c_int_pk integer NOT NULL, + c_boolean_pk boolean NOT NULL, + c_text_pk text NOT NULL, + c_text_value text, + PRIMARY KEY(c_int_pk, c_boolean_pk, c_text_pk) +) +` + +const testScanInsertQuery = ` +INSERT INTO %s VALUES +(46, false, 'dj', 'row 0'), +(73, false, 'dr', 'row 1'), +(35, false, 'dr', 'row 2'), +(4, false, 'jn', 'row 3'), +(60, true, 'rj', 'row 4'), +(89, true, 'dn', 'row 5'), +(62, false, 'nn', 'row 6'), +(5, false, 'rn', 'row 7'), +(87, false, 'nr', 'row 8'), +(86, false, 'rn', 'row 9'), +(7, true, 'rr', 'row 10'), +(94, false, 'dn', 'row 11'), +(27, false, 'jr', 'row 12'), +(45, true, 'nr', 'row 13'), +(41, true, 'nr', 'row 14'), +(57, false, 'nj', 'row 15'), +(13, true, 'rd', 'row 16'), +(88, true, 'rj', 'row 17'), +(54, true, 'rd', 'row 18'), +(29, false, 'nr', 'row 19'), +(91, false, 'nj', 'row 20'), +(26, false, 'dr', 'row 21'), +(15, false, 'jr', 'row 22'), +(29, false, 'rj', 'row 23'), +(88, false, 'rr', 'row 24') +` + +// testScan checks that we're fetching all the data from PostgreSQL. +func testScan(db *sql.DB) error { + tempTableName := fmt.Sprintf("artie_reader_%d", 10_000+rand.Int32N(5_000)) + slog.Info("Creating temporary table...", slog.String("table", tempTableName)) + _, err := db.Exec(fmt.Sprintf(testScanCreateTableQuery, tempTableName)) + if err != nil { + return fmt.Errorf("unable to create temporary table: %w", err) + } + defer func() { + slog.Info("Dropping temporary table...", slog.String("table", tempTableName)) + if _, err := db.Exec(fmt.Sprintf("DROP TABLE %s", tempTableName)); err != nil { + slog.Error("Failed to drop table", slog.Any("err", err)) + } + }() + + slog.Info("Inserting data...") + _, err = db.Exec(fmt.Sprintf(testScanInsertQuery, tempTableName)) + if err != nil { + return fmt.Errorf("unable to insert data: %w", err) + } + + expectedPartitionKeys := []map[string]any{ + {"c_int_pk": int64(4), "c_boolean_pk": false, "c_text_pk": "jn"}, + {"c_int_pk": int64(5), "c_boolean_pk": false, "c_text_pk": "rn"}, + {"c_int_pk": int64(7), "c_boolean_pk": true, "c_text_pk": "rr"}, + {"c_int_pk": int64(13), "c_boolean_pk": true, "c_text_pk": "rd"}, + {"c_int_pk": int64(15), "c_boolean_pk": false, "c_text_pk": "jr"}, + {"c_int_pk": int64(26), "c_boolean_pk": false, "c_text_pk": "dr"}, + {"c_int_pk": int64(27), "c_boolean_pk": false, "c_text_pk": "jr"}, + {"c_int_pk": int64(29), "c_boolean_pk": false, "c_text_pk": "nr"}, + {"c_int_pk": int64(29), "c_boolean_pk": false, "c_text_pk": "rj"}, + {"c_int_pk": int64(35), "c_boolean_pk": false, "c_text_pk": "dr"}, + {"c_int_pk": int64(41), "c_boolean_pk": true, "c_text_pk": "nr"}, + {"c_int_pk": int64(45), "c_boolean_pk": true, "c_text_pk": "nr"}, + {"c_int_pk": int64(46), "c_boolean_pk": false, "c_text_pk": "dj"}, + {"c_int_pk": int64(54), "c_boolean_pk": true, "c_text_pk": "rd"}, + {"c_int_pk": int64(57), "c_boolean_pk": false, "c_text_pk": "nj"}, + {"c_int_pk": int64(60), "c_boolean_pk": true, "c_text_pk": "rj"}, + {"c_int_pk": int64(62), "c_boolean_pk": false, "c_text_pk": "nn"}, + {"c_int_pk": int64(73), "c_boolean_pk": false, "c_text_pk": "dr"}, + {"c_int_pk": int64(86), "c_boolean_pk": false, "c_text_pk": "rn"}, + {"c_int_pk": int64(87), "c_boolean_pk": false, "c_text_pk": "nr"}, + {"c_int_pk": int64(88), "c_boolean_pk": false, "c_text_pk": "rr"}, + {"c_int_pk": int64(88), "c_boolean_pk": true, "c_text_pk": "rj"}, + {"c_int_pk": int64(89), "c_boolean_pk": true, "c_text_pk": "dn"}, + {"c_int_pk": int64(91), "c_boolean_pk": false, "c_text_pk": "nj"}, + {"c_int_pk": int64(94), "c_boolean_pk": false, "c_text_pk": "dn"}, + } + expectedValues := []string{ + "row 3", + "row 7", + "row 10", + "row 16", + "row 22", + "row 21", + "row 12", + "row 19", + "row 23", + "row 2", + "row 14", + "row 13", + "row 0", + "row 18", + "row 15", + "row 4", + "row 6", + "row 1", + "row 9", + "row 8", + "row 24", + "row 17", + "row 5", + "row 20", + "row 11", + } + + for _, batchSize := range []int{1, 2, 5, 6, 24, 25, 26} { + rows, err := readTable(db, tempTableName, batchSize) + if err != nil { + return err + } + if len(rows) != len(expectedPartitionKeys) { + return fmt.Errorf("expected %d rows, got %d, batch size %d", len(expectedPartitionKeys), len(rows), batchSize) + } + for i, row := range rows { + if !maps.Equal(row.PartitionKey, expectedPartitionKeys[i]) { + return fmt.Errorf("partition keys are different for row %d, batch size %d, %T != %T", i, batchSize, row.PartitionKey, expectedPartitionKeys[i]) + } + textValue := getPayload(row).Payload.After["c_text_value"] + if textValue != expectedValues[i] { + return fmt.Errorf("row values are different for row %d, batch size %d, %T != %T", i, batchSize, textValue, expectedPartitionKeys[i]) + } + } + } + + return nil +}