Skip to content

Commit

Permalink
[postgres] Move Table loading into adapter (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored Mar 6, 2024
1 parent e4540a5 commit 182e8c7
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 27 deletions.
9 changes: 4 additions & 5 deletions integration_tests/postgres/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/artie-labs/reader/lib"
"github.com/artie-labs/reader/lib/debezium"
"github.com/artie-labs/reader/lib/logger"
"github.com/artie-labs/reader/lib/postgres"
"github.com/artie-labs/reader/lib/rdbms"
"github.com/artie-labs/reader/sources/postgres/adapter"
)
Expand Down Expand Up @@ -64,16 +63,16 @@ func readTable(db *sql.DB, tableName string, batchSize int) ([]lib.RawMessage, e
BatchSize: uint(batchSize),
}

table, err := postgres.LoadTable(db, tableCfg.Schema, tableCfg.Name)
dbzAdapter, err := adapter.NewPostgresAdapter(db, tableCfg)
if err != nil {
return nil, fmt.Errorf("unable to load table metadata: %w", err)
return nil, err
}

scanner, err := postgres.NewScanner(db, table, tableCfg.ToScannerConfig(1))
scanner, err := dbzAdapter.NewIterator()
if err != nil {
return nil, fmt.Errorf("failed to build scanner: %w", err)
}
dbzTransformer := debezium.NewDebeziumTransformer(adapter.NewPostgresAdapter(*table), &scanner)
dbzTransformer := debezium.NewDebeziumTransformer(dbzAdapter, &scanner)
rows := []lib.RawMessage{}
for dbzTransformer.HasNext() {
batch, err := dbzTransformer.Next()
Expand Down
2 changes: 1 addition & 1 deletion lib/postgres/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func convertToStringForQuery(value any, dataType schema.DataType) (string, error
}
}

func NewScanner(db *sql.DB, table *Table, cfg scan.ScannerConfig) (scan.Scanner, error) {
func NewScanner(db *sql.DB, table Table, cfg scan.ScannerConfig) (scan.Scanner, error) {
primaryKeyBounds, err := table.GetPrimaryKeysBounds(db)
if err != nil {
return scan.Scanner{}, err
Expand Down
28 changes: 25 additions & 3 deletions sources/postgres/adapter/adapter.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,38 @@
package adapter

import (
"database/sql"
"fmt"
"log/slog"
"strings"

"github.com/artie-labs/transfer/lib/debezium"

"github.com/artie-labs/reader/config"
"github.com/artie-labs/reader/lib/postgres"
"github.com/artie-labs/reader/lib/rdbms/scan"
)

const defaultErrorRetries = 10

type postgresAdapter struct {
table postgres.Table
db *sql.DB
table postgres.Table
scannerCfg scan.ScannerConfig
}

func NewPostgresAdapter(table postgres.Table) postgresAdapter {
return postgresAdapter{table: table}
func NewPostgresAdapter(db *sql.DB, tableCfg config.PostgreSQLTable) (postgresAdapter, error) {
slog.Info("Loading metadata for table")
table, err := postgres.LoadTable(db, tableCfg.Schema, tableCfg.Name)
if err != nil {
return postgresAdapter{}, fmt.Errorf("failed to load metadata for table %s.%s: %w", tableCfg.Schema, tableCfg.Name, err)
}

return postgresAdapter{
db: db,
table: *table,
scannerCfg: tableCfg.ToScannerConfig(defaultErrorRetries),
}, nil
}

func (p postgresAdapter) TableName() string {
Expand All @@ -33,6 +51,10 @@ func (p postgresAdapter) Fields() []debezium.Field {
return fields
}

func (p postgresAdapter) NewIterator() (scan.Scanner, error) {
return postgres.NewScanner(p.db, p.table, p.scannerCfg)
}

// PartitionKey returns a map of primary keys and their values for a given row.
func (p postgresAdapter) PartitionKey(row map[string]any) map[string]any {
result := make(map[string]any)
Expand Down
8 changes: 4 additions & 4 deletions sources/postgres/adapter/adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestPostgresAdapter_TableName(t *testing.T) {
Schema: "schema",
Name: "table1",
}
assert.Equal(t, "table1", NewPostgresAdapter(table).TableName())
assert.Equal(t, "table1", postgresAdapter{table: table}.TableName())
}

func TestPostgresAdapter_TopicSuffix(t *testing.T) {
Expand All @@ -42,7 +42,7 @@ func TestPostgresAdapter_TopicSuffix(t *testing.T) {
}

for _, tc := range tcs {
adapter := NewPostgresAdapter(tc.table)
adapter := postgresAdapter{table: tc.table}
assert.Equal(t, tc.expectedTopicName, adapter.TopicSuffix())
}
}
Expand All @@ -57,7 +57,7 @@ func TestPostgresAdapter_Fields(t *testing.T) {
{Name: "col3", Type: schema.Array},
},
}
adapter := NewPostgresAdapter(table)
adapter := postgresAdapter{table: table}

expected := []debezium.Field{
{Type: "string", FieldName: "col1"},
Expand Down Expand Up @@ -101,7 +101,7 @@ func TestPostgresAdapter_PartitionKey(t *testing.T) {
Name: "tbl1",
PrimaryKeys: tc.keys,
}
adapter := NewPostgresAdapter(table)
adapter := postgresAdapter{table: table}
assert.Equal(t, tc.expected, adapter.PartitionKey(tc.row), tc.name)
}
}
8 changes: 4 additions & 4 deletions sources/postgres/adapter/transformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestDebeziumTransformer(t *testing.T) {
// test zero batches
{
builder := debezium.NewDebeziumTransformer(
NewPostgresAdapter(table),
postgresAdapter{table: table},
&MockRowIterator{batches: [][]map[string]any{}},
)
assert.False(t, builder.HasNext())
Expand All @@ -60,7 +60,7 @@ func TestDebeziumTransformer(t *testing.T) {
// test an iterator that returns an error
{
builder := debezium.NewDebeziumTransformer(
NewPostgresAdapter(table),
postgresAdapter{table: table},
&ErrorRowIterator{},
)

Expand All @@ -72,7 +72,7 @@ func TestDebeziumTransformer(t *testing.T) {
// test two batches each with two rows
{
builder := debezium.NewDebeziumTransformer(
NewPostgresAdapter(table),
postgresAdapter{table: table},
&MockRowIterator{
batches: [][]map[string]any{
{{"a": "1", "b": "11"}, {"a": "2", "b": "12"}},
Expand Down Expand Up @@ -123,7 +123,7 @@ func TestDebeziumTransformer_NilOptionalSchema(t *testing.T) {
}

builder := debezium.NewDebeziumTransformer(
NewPostgresAdapter(table),
postgresAdapter{table: table},
&MockRowIterator{batches: [][]map[string]any{{rowData}}},
)

Expand Down
16 changes: 6 additions & 10 deletions sources/postgres/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@ import (
"github.com/artie-labs/reader/config"
"github.com/artie-labs/reader/lib/debezium"
"github.com/artie-labs/reader/lib/kafkalib"
"github.com/artie-labs/reader/lib/postgres"
"github.com/artie-labs/reader/lib/rdbms"
"github.com/artie-labs/reader/sources/postgres/adapter"
)

const defaultErrorRetries = 10

type Source struct {
cfg config.PostgreSQL
db *sql.DB
Expand All @@ -46,28 +43,27 @@ func (s *Source) Run(ctx context.Context, writer kafkalib.BatchWriter) error {
logger := slog.With(slog.String("schema", tableCfg.Schema), slog.String("table", tableCfg.Name))
snapshotStartTime := time.Now()

logger.Info("Loading metadata for table")
table, err := postgres.LoadTable(s.db, tableCfg.Schema, tableCfg.Name)
dbzAdapter, err := adapter.NewPostgresAdapter(s.db, *tableCfg)
if err != nil {
return fmt.Errorf("failed to load metadata for table %s.%s: %w", table.Schema, table.Name, err)
return fmt.Errorf("failed to create PostgreSQL adapter: %w", err)
}

scanner, err := postgres.NewScanner(s.db, table, tableCfg.ToScannerConfig(defaultErrorRetries))
scanner, err := dbzAdapter.NewIterator()
if err != nil {
if errors.Is(err, rdbms.ErrNoPkValuesForEmptyTable) {
logger.Info("Table does not contain any rows, skipping...")
continue
} else {
return fmt.Errorf("failed to build scanner for table %s: %w", table.Name, err)
return fmt.Errorf("failed to build scanner for table %s: %w", tableCfg.Name, err)
}
}

logger.Info("Scanning table", slog.Any("batchSize", tableCfg.GetBatchSize()))

dbzTransformer := debezium.NewDebeziumTransformer(adapter.NewPostgresAdapter(*table), &scanner)
dbzTransformer := debezium.NewDebeziumTransformer(dbzAdapter, &scanner)
count, err := writer.WriteIterator(ctx, dbzTransformer)
if err != nil {
return fmt.Errorf("failed to snapshot for table %s: %w", table.Name, err)
return fmt.Errorf("failed to snapshot for table %s: %w", tableCfg.Name, err)
}

logger.Info("Finished snapshotting",
Expand Down

0 comments on commit 182e8c7

Please sign in to comment.