Skip to content

Commit

Permalink
[postgres] Add scanning integration test (#211)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored Feb 28, 2024
1 parent 2726419 commit edb6d7a
Showing 1 changed file with 159 additions and 18 deletions.
177 changes: 159 additions & 18 deletions sources/postgres/integration_test/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"encoding/json"
"fmt"
"log/slog"
"maps"
"math/rand/v2"
"strings"

"github.com/artie-labs/transfer/lib/cdc/util"
_ "github.com/jackc/pgx/v5/stdlib"

"github.com/artie-labs/reader/config"
Expand All @@ -16,7 +18,6 @@ import (
"github.com/artie-labs/reader/lib/logger"
"github.com/artie-labs/reader/lib/postgres"
"github.com/artie-labs/reader/sources/postgres/adapter"
"github.com/artie-labs/transfer/lib/cdc/util"
)

var pgConfig = config.PostgreSQL{
Expand All @@ -37,14 +38,19 @@ func main() {
if err != nil {
logger.Fatal("Types test failed", slog.Any("err", err))
}

err = testScan(db)
if err != nil {
logger.Fatal("Scan test failed", slog.Any("err", err))
}
}

func rawMessageTimestamp(message lib.RawMessage) int64 {
func getPayload(message lib.RawMessage) util.SchemaEventPayload {
payloadTyped, ok := message.GetPayload().(util.SchemaEventPayload)
if !ok {
panic("payload is not of type util.SchemaEventPayload")
}
return payloadTyped.Payload.Source.TsMs
return payloadTyped
}

func checkDifference(name, expected, actual string) bool {
Expand Down Expand Up @@ -76,10 +82,11 @@ func checkDifference(name, expected, actual string) bool {
return true
}

func readTable(db *sql.DB, tableName string) ([]lib.RawMessage, error) {
func readTable(db *sql.DB, tableName string, batchSize int) ([]lib.RawMessage, error) {
tableCfg := config.PostgreSQLTable{
Schema: "public",
Name: tableName,
Schema: "public",
Name: tableName,
BatchSize: uint(batchSize),
}

table := postgres.NewTable(tableCfg.Schema, tableCfg.Name)
Expand All @@ -103,7 +110,7 @@ func readTable(db *sql.DB, tableName string) ([]lib.RawMessage, error) {
return rows, nil
}

const createTableQuery = `
const testTypesCreateTableQuery = `
CREATE TABLE %s (
pk integer PRIMARY KEY NOT NULL,
-- All the types from https://www.postgresql.org/docs/current/datatype.html#DATATYPE-TABLE
Expand Down Expand Up @@ -155,7 +162,7 @@ CREATE TABLE %s (
)
`

const insertQuery = `
const testTypesInsertQuery = `
INSERT INTO %s VALUES (
-- pk
1,
Expand Down Expand Up @@ -250,10 +257,6 @@ INSERT INTO %s VALUES (
)
`

const expectedPartitionKey = `{
"pk": 1
}`

const expectedPayloadTemplate = `{
"schema": {
"type": "",
Expand Down Expand Up @@ -597,10 +600,11 @@ const expectedPayloadTemplate = `{
}
}`

// testTypes checks that PostgreSQL data types are handled correctly.
func testTypes(db *sql.DB) error {
tempTableName := fmt.Sprintf("artie_reader_%d", 10_000+rand.Int32N(5_000))
slog.Info("Creating temporary table...", slog.String("table", tempTableName))
_, err := db.Exec(fmt.Sprintf(createTableQuery, tempTableName))
_, err := db.Exec(fmt.Sprintf(testTypesCreateTableQuery, tempTableName))
if err != nil {
return fmt.Errorf("unable to create temporary table: %w", err)
}
Expand All @@ -612,12 +616,12 @@ func testTypes(db *sql.DB) error {
}()

slog.Info("Inserting data...")
_, err = db.Exec(fmt.Sprintf(insertQuery, tempTableName))
_, err = db.Exec(fmt.Sprintf(testTypesInsertQuery, tempTableName))
if err != nil {
return fmt.Errorf("unable to insert data: %w", err)
}

rows, err := readTable(db, tempTableName)
rows, err := readTable(db, tempTableName, 100)
if err != nil {
return err
}
Expand All @@ -627,7 +631,7 @@ func testTypes(db *sql.DB) error {
}
row := rows[0]

keyBytes, err := json.MarshalIndent(row.PartitionKey, "", "")
keyBytes, err := json.Marshal(row.PartitionKey)
if err != nil {
return fmt.Errorf("failed to marshal partition key: %w", err)
}
Expand All @@ -637,14 +641,151 @@ func testTypes(db *sql.DB) error {
return fmt.Errorf("failed to marshal payload")
}

if checkDifference("partition key", expectedPartitionKey, string(keyBytes)) {
if checkDifference("partition key", `{"pk":1}`, string(keyBytes)) {
return fmt.Errorf("partition key does not match")
}

expectedPayload := fmt.Sprintf(expectedPayloadTemplate, rawMessageTimestamp(row), tempTableName)
expectedPayload := fmt.Sprintf(expectedPayloadTemplate, getPayload(row).Payload.Source.TsMs, tempTableName)
if checkDifference("payload", expectedPayload, string(valueBytes)) {
return fmt.Errorf("payload does not match")
}

return nil
}

const testScanCreateTableQuery = `
CREATE TABLE %s (
c_int_pk integer NOT NULL,
c_boolean_pk boolean NOT NULL,
c_text_pk text NOT NULL,
c_text_value text,
PRIMARY KEY(c_int_pk, c_boolean_pk, c_text_pk)
)
`

const testScanInsertQuery = `
INSERT INTO %s VALUES
(46, false, 'dj', 'row 0'),
(73, false, 'dr', 'row 1'),
(35, false, 'dr', 'row 2'),
(4, false, 'jn', 'row 3'),
(60, true, 'rj', 'row 4'),
(89, true, 'dn', 'row 5'),
(62, false, 'nn', 'row 6'),
(5, false, 'rn', 'row 7'),
(87, false, 'nr', 'row 8'),
(86, false, 'rn', 'row 9'),
(7, true, 'rr', 'row 10'),
(94, false, 'dn', 'row 11'),
(27, false, 'jr', 'row 12'),
(45, true, 'nr', 'row 13'),
(41, true, 'nr', 'row 14'),
(57, false, 'nj', 'row 15'),
(13, true, 'rd', 'row 16'),
(88, true, 'rj', 'row 17'),
(54, true, 'rd', 'row 18'),
(29, false, 'nr', 'row 19'),
(91, false, 'nj', 'row 20'),
(26, false, 'dr', 'row 21'),
(15, false, 'jr', 'row 22'),
(29, false, 'rj', 'row 23'),
(88, false, 'rr', 'row 24')
`

// testScan checks that we're fetching all the data from PostgreSQL.
func testScan(db *sql.DB) error {
tempTableName := fmt.Sprintf("artie_reader_%d", 10_000+rand.Int32N(5_000))
slog.Info("Creating temporary table...", slog.String("table", tempTableName))
_, err := db.Exec(fmt.Sprintf(testScanCreateTableQuery, tempTableName))
if err != nil {
return fmt.Errorf("unable to create temporary table: %w", err)
}
defer func() {
slog.Info("Dropping temporary table...", slog.String("table", tempTableName))
if _, err := db.Exec(fmt.Sprintf("DROP TABLE %s", tempTableName)); err != nil {
slog.Error("Failed to drop table", slog.Any("err", err))
}
}()

slog.Info("Inserting data...")
_, err = db.Exec(fmt.Sprintf(testScanInsertQuery, tempTableName))
if err != nil {
return fmt.Errorf("unable to insert data: %w", err)
}

expectedPartitionKeys := []map[string]any{
{"c_int_pk": int64(4), "c_boolean_pk": false, "c_text_pk": "jn"},
{"c_int_pk": int64(5), "c_boolean_pk": false, "c_text_pk": "rn"},
{"c_int_pk": int64(7), "c_boolean_pk": true, "c_text_pk": "rr"},
{"c_int_pk": int64(13), "c_boolean_pk": true, "c_text_pk": "rd"},
{"c_int_pk": int64(15), "c_boolean_pk": false, "c_text_pk": "jr"},
{"c_int_pk": int64(26), "c_boolean_pk": false, "c_text_pk": "dr"},
{"c_int_pk": int64(27), "c_boolean_pk": false, "c_text_pk": "jr"},
{"c_int_pk": int64(29), "c_boolean_pk": false, "c_text_pk": "nr"},
{"c_int_pk": int64(29), "c_boolean_pk": false, "c_text_pk": "rj"},
{"c_int_pk": int64(35), "c_boolean_pk": false, "c_text_pk": "dr"},
{"c_int_pk": int64(41), "c_boolean_pk": true, "c_text_pk": "nr"},
{"c_int_pk": int64(45), "c_boolean_pk": true, "c_text_pk": "nr"},
{"c_int_pk": int64(46), "c_boolean_pk": false, "c_text_pk": "dj"},
{"c_int_pk": int64(54), "c_boolean_pk": true, "c_text_pk": "rd"},
{"c_int_pk": int64(57), "c_boolean_pk": false, "c_text_pk": "nj"},
{"c_int_pk": int64(60), "c_boolean_pk": true, "c_text_pk": "rj"},
{"c_int_pk": int64(62), "c_boolean_pk": false, "c_text_pk": "nn"},
{"c_int_pk": int64(73), "c_boolean_pk": false, "c_text_pk": "dr"},
{"c_int_pk": int64(86), "c_boolean_pk": false, "c_text_pk": "rn"},
{"c_int_pk": int64(87), "c_boolean_pk": false, "c_text_pk": "nr"},
{"c_int_pk": int64(88), "c_boolean_pk": false, "c_text_pk": "rr"},
{"c_int_pk": int64(88), "c_boolean_pk": true, "c_text_pk": "rj"},
{"c_int_pk": int64(89), "c_boolean_pk": true, "c_text_pk": "dn"},
{"c_int_pk": int64(91), "c_boolean_pk": false, "c_text_pk": "nj"},
{"c_int_pk": int64(94), "c_boolean_pk": false, "c_text_pk": "dn"},
}
expectedValues := []string{
"row 3",
"row 7",
"row 10",
"row 16",
"row 22",
"row 21",
"row 12",
"row 19",
"row 23",
"row 2",
"row 14",
"row 13",
"row 0",
"row 18",
"row 15",
"row 4",
"row 6",
"row 1",
"row 9",
"row 8",
"row 24",
"row 17",
"row 5",
"row 20",
"row 11",
}

for _, batchSize := range []int{1, 2, 5, 6, 24, 25, 26} {
rows, err := readTable(db, tempTableName, batchSize)
if err != nil {
return err
}
if len(rows) != len(expectedPartitionKeys) {
return fmt.Errorf("expected %d rows, got %d, batch size %d", len(expectedPartitionKeys), len(rows), batchSize)
}
for i, row := range rows {
if !maps.Equal(row.PartitionKey, expectedPartitionKeys[i]) {
return fmt.Errorf("partition keys are different for row %d, batch size %d, %T != %T", i, batchSize, row.PartitionKey, expectedPartitionKeys[i])
}
textValue := getPayload(row).Payload.After["c_text_value"]
if textValue != expectedValues[i] {
return fmt.Errorf("row values are different for row %d, batch size %d, %T != %T", i, batchSize, textValue, expectedPartitionKeys[i])
}
}
}

return nil
}

0 comments on commit edb6d7a

Please sign in to comment.