diff --git a/integration_tests/mysql/main.go b/integration_tests/mysql/main.go index 69ee85f2..873fd331 100644 --- a/integration_tests/mysql/main.go +++ b/integration_tests/mysql/main.go @@ -517,16 +517,16 @@ func testTypes(db *sql.DB, dbName string) error { row := rows[0] expectedPartitionKey := map[string]any{"pk": int64(1)} - if !maps.Equal(row.PartitionKey, expectedPartitionKey) { - return fmt.Errorf("partition key %v does not match %v", row.PartitionKey, expectedPartitionKey) + if !maps.Equal(row.PartitionKey(), expectedPartitionKey) { + return fmt.Errorf("partition key %v does not match %v", row.PartitionKey(), expectedPartitionKey) } - valueBytes, err := json.MarshalIndent(row.GetPayload(), "", "\t") + valueBytes, err := json.MarshalIndent(row.Event(), "", "\t") if err != nil { return fmt.Errorf("failed to marshal payload") } - expectedPayload := fmt.Sprintf(expectedPayloadTemplate, utils.GetPayload(row).Payload.Source.TsMs, tempTableName) + expectedPayload := fmt.Sprintf(expectedPayloadTemplate, utils.GetEvent(row).Payload.Source.TsMs, tempTableName) if utils.CheckDifference("payload", expectedPayload, string(valueBytes)) { return fmt.Errorf("payload does not match") } @@ -648,10 +648,10 @@ func testScan(db *sql.DB, dbName string) error { return fmt.Errorf("expected %d rows, got %d, batch size %d", len(expectedPartitionKeys), len(rows), batchSize) } for i, row := range rows { - if !maps.Equal(row.PartitionKey, expectedPartitionKeys[i]) { - return fmt.Errorf("partition keys are different for row %d, batch size %d, %v != %v", i, batchSize, row.PartitionKey, expectedPartitionKeys[i]) + if !maps.Equal(row.PartitionKey(), expectedPartitionKeys[i]) { + return fmt.Errorf("partition keys are different for row %d, batch size %d, %v != %v", i, batchSize, row.PartitionKey(), expectedPartitionKeys[i]) } - textValue := utils.GetPayload(row).Payload.After["c_text_value"] + textValue := utils.GetEvent(row).Payload.After["c_text_value"] if textValue != expectedValues[i] { return fmt.Errorf("row values are different for row %d, batch size %d, %v != %v", i, batchSize, textValue, expectedValues[i]) } diff --git a/integration_tests/postgres/main.go b/integration_tests/postgres/main.go index c2144c33..f51aa793 100644 --- a/integration_tests/postgres/main.go +++ b/integration_tests/postgres/main.go @@ -743,16 +743,16 @@ func testTypes(db *sql.DB) error { row := rows[0] expectedPartitionKey := map[string]any{"pk": int64(1)} - if !maps.Equal(row.PartitionKey, expectedPartitionKey) { - return fmt.Errorf("partition key %v does not match %v", row.PartitionKey, expectedPartitionKey) + if !maps.Equal(row.PartitionKey(), expectedPartitionKey) { + return fmt.Errorf("partition key %v does not match %v", row.PartitionKey(), expectedPartitionKey) } - valueBytes, err := json.MarshalIndent(row.GetPayload(), "", "\t") + valueBytes, err := json.MarshalIndent(row.Event(), "", "\t") if err != nil { return fmt.Errorf("failed to marshal payload") } - expectedPayload := fmt.Sprintf(expectedPayloadTemplate, utils.GetPayload(row).Payload.Source.TsMs, tempTableName) + expectedPayload := fmt.Sprintf(expectedPayloadTemplate, utils.GetEvent(row).Payload.Source.TsMs, tempTableName) if utils.CheckDifference("payload", expectedPayload, string(valueBytes)) { return fmt.Errorf("payload does not match") } @@ -874,10 +874,10 @@ func testScan(db *sql.DB) error { return fmt.Errorf("expected %d rows, got %d, batch size %d", len(expectedPartitionKeys), len(rows), batchSize) } for i, row := range rows { - if !maps.Equal(row.PartitionKey, expectedPartitionKeys[i]) { - return fmt.Errorf("partition keys are different for row %d, batch size %d, %v != %v", i, batchSize, row.PartitionKey, expectedPartitionKeys[i]) + if !maps.Equal(row.PartitionKey(), expectedPartitionKeys[i]) { + return fmt.Errorf("partition keys are different for row %d, batch size %d, %v != %v", i, batchSize, row.PartitionKey(), expectedPartitionKeys[i]) } - textValue := utils.GetPayload(row).Payload.After["c_text_value"] + textValue := utils.GetEvent(row).Payload.After["c_text_value"] if textValue != expectedValues[i] { return fmt.Errorf("row values are different for row %d, batch size %d, %v != %v", i, batchSize, textValue, expectedValues[i]) } diff --git a/integration_tests/utils/utils.go b/integration_tests/utils/utils.go index daed12c8..70d148f5 100644 --- a/integration_tests/utils/utils.go +++ b/integration_tests/utils/utils.go @@ -32,7 +32,7 @@ func ReadTable(db *sql.DB, dbzAdapter transformer.Adapter) ([]lib.RawMessage, er return nil, err } - rows := []lib.RawMessage{} + var rows []lib.RawMessage for dbzTransformer.HasNext() { batch, err := dbzTransformer.Next() if err != nil { @@ -43,12 +43,12 @@ func ReadTable(db *sql.DB, dbzAdapter transformer.Adapter) ([]lib.RawMessage, er return rows, nil } -func GetPayload(message lib.RawMessage) util.SchemaEventPayload { - payloadTyped, ok := message.GetPayload().(util.SchemaEventPayload) +func GetEvent(message lib.RawMessage) util.SchemaEventPayload { + event, ok := message.Event().(*util.SchemaEventPayload) if !ok { - panic("payload is not of type util.SchemaEventPayload") + panic("event is not of type *util.SchemaEventPayload") } - return payloadTyped + return *event } func CheckDifference(name, expected, actual string) bool { diff --git a/lib/debezium/transformer/transformer.go b/lib/debezium/transformer/transformer.go index 17f1391e..7c43e114 100644 --- a/lib/debezium/transformer/transformer.go +++ b/lib/debezium/transformer/transformer.go @@ -91,8 +91,9 @@ func (d *DebeziumTransformer) Next() ([]lib.RawMessage, error) { return nil, fmt.Errorf("failed to create Debezium payload: %w", err) } - result = append(result, lib.NewRawMessage(d.adapter.TopicSuffix(), d.partitionKey(row), payload)) + result = append(result, lib.NewRawMessage(d.adapter.TopicSuffix(), d.partitionKey(row), &payload)) } + return result, nil } diff --git a/lib/debezium/transformer/transformer_test.go b/lib/debezium/transformer/transformer_test.go index 2c7d63c1..bde59052 100644 --- a/lib/debezium/transformer/transformer_test.go +++ b/lib/debezium/transformer/transformer_test.go @@ -106,7 +106,7 @@ func TestDebeziumTransformer_Iteration(t *testing.T) { assert.Len(t, results, 1) rows := results[0] assert.Len(t, rows, 1) - payload, isOk := rows[0].GetPayload().(util.SchemaEventPayload) + payload, isOk := rows[0].Event().(*util.SchemaEventPayload) assert.True(t, isOk) assert.Equal(t, "converted-bar", payload.Payload.After["foo"]) } @@ -138,7 +138,7 @@ func TestDebeziumTransformer_Iteration(t *testing.T) { // First batch rows := results[0] assert.Len(t, rows, 1) - payload, isOk := rows[0].GetPayload().(util.SchemaEventPayload) + payload, isOk := rows[0].Event().(*util.SchemaEventPayload) assert.True(t, isOk) assert.Equal(t, "converted-bar", payload.Payload.After["foo"]) // Second batch @@ -146,7 +146,7 @@ func TestDebeziumTransformer_Iteration(t *testing.T) { // Third batch rows = results[2] assert.Len(t, rows, 1) - payload, isOk = rows[0].GetPayload().(util.SchemaEventPayload) + payload, isOk = rows[0].Event().(*util.SchemaEventPayload) assert.True(t, isOk) assert.Equal(t, "converted-grault", payload.Payload.After["corge"]) } @@ -205,9 +205,9 @@ func TestDebeziumTransformer_Next(t *testing.T) { rows := results[0] assert.Len(t, rows, 1) rawMessage := rows[0] - assert.Equal(t, Row{"foo": "bar", "qux": 12}, rawMessage.PartitionKey) - assert.Equal(t, "im-a-little-topic-suffix", rawMessage.TopicSuffix) - payload, isOk := rawMessage.GetPayload().(util.SchemaEventPayload) + assert.Equal(t, Row{"foo": "bar", "qux": 12}, rawMessage.PartitionKey()) + assert.Equal(t, "im-a-little-topic-suffix", rawMessage.TopicSuffix()) + payload, isOk := rawMessage.Event().(*util.SchemaEventPayload) assert.True(t, isOk) payload.Payload.Source.TsMs = 12345 // Modify source time since it'll be ~now expected := util.SchemaEventPayload( @@ -234,7 +234,7 @@ func TestDebeziumTransformer_Next(t *testing.T) { }, }, ) - assert.Equal(t, expected, payload) + assert.Equal(t, expected, *payload) } } diff --git a/lib/dynamo/message.go b/lib/dynamo/message.go index 0ed0e5a6..65f03a1b 100644 --- a/lib/dynamo/message.go +++ b/lib/dynamo/message.go @@ -84,8 +84,8 @@ func transformNewImage(data map[string]*dynamodb.AttributeValue) map[string]any return transformed } -func (m *Message) artieMessage() util.SchemaEventPayload { - return util.SchemaEventPayload{ +func (m *Message) artieMessage() *util.SchemaEventPayload { + return &util.SchemaEventPayload{ Payload: util.Payload{ After: m.rowData, Source: util.Source{ diff --git a/lib/kafkalib/message.go b/lib/kafkalib/message.go index fe1b49c8..09abaf11 100644 --- a/lib/kafkalib/message.go +++ b/lib/kafkalib/message.go @@ -9,18 +9,18 @@ import ( ) func newMessage(topicPrefix string, rawMessage lib.RawMessage) (kafka.Message, error) { - valueBytes, err := json.Marshal(rawMessage.GetPayload()) + valueBytes, err := json.Marshal(rawMessage.Event()) if err != nil { return kafka.Message{}, err } - keyBytes, err := json.Marshal(rawMessage.PartitionKey) + keyBytes, err := json.Marshal(rawMessage.PartitionKey()) if err != nil { return kafka.Message{}, err } return kafka.Message{ - Topic: fmt.Sprintf("%s.%s", topicPrefix, rawMessage.TopicSuffix), + Topic: fmt.Sprintf("%s.%s", topicPrefix, rawMessage.TopicSuffix()), Key: keyBytes, Value: valueBytes, }, nil diff --git a/lib/kafkalib/message_test.go b/lib/kafkalib/message_test.go index 6437de74..abbd0214 100644 --- a/lib/kafkalib/message_test.go +++ b/lib/kafkalib/message_test.go @@ -12,7 +12,7 @@ func TestNewMessage(t *testing.T) { rawMessage := lib.NewRawMessage( "topic-suffix", map[string]any{"key": "value"}, - util.SchemaEventPayload{ + &util.SchemaEventPayload{ Payload: util.Payload{ After: map[string]any{"a": "b"}, Source: util.Source{ diff --git a/lib/types.go b/lib/types.go index d891ec40..c23232e6 100644 --- a/lib/types.go +++ b/lib/types.go @@ -1,40 +1,31 @@ package lib import ( - "github.com/artie-labs/transfer/lib/cdc/mongo" - "github.com/artie-labs/transfer/lib/cdc/util" + "github.com/artie-labs/transfer/lib/cdc" ) type RawMessage struct { - TopicSuffix string - PartitionKey map[string]any - payload util.SchemaEventPayload - mongoPayload mongo.SchemaEventPayload - - mongo bool + topicSuffix string + partitionKey map[string]any + event cdc.Event } -func NewRawMessage(topicSuffix string, partitionKey map[string]any, payload util.SchemaEventPayload) RawMessage { +func NewRawMessage(topicSuffix string, partitionKey map[string]any, event cdc.Event) RawMessage { return RawMessage{ - TopicSuffix: topicSuffix, - PartitionKey: partitionKey, - payload: payload, + topicSuffix: topicSuffix, + partitionKey: partitionKey, + event: event, } } -func NewMongoMessage(topicSuffix string, partitionKey map[string]any, payload mongo.SchemaEventPayload) RawMessage { - return RawMessage{ - TopicSuffix: topicSuffix, - PartitionKey: partitionKey, - mongoPayload: payload, - mongo: true, - } +func (r RawMessage) TopicSuffix() string { + return r.topicSuffix } -func (r RawMessage) GetPayload() any { - if r.mongo { - return r.mongoPayload - } +func (r RawMessage) PartitionKey() map[string]any { + return r.partitionKey +} - return r.payload +func (r RawMessage) Event() cdc.Event { + return r.event } diff --git a/lib/writer/writer_test.go b/lib/writer/writer_test.go index a32f49d1..ef8c8320 100644 --- a/lib/writer/writer_test.go +++ b/lib/writer/writer_test.go @@ -68,23 +68,26 @@ func TestWriter_Write(t *testing.T) { destination := &mockDestination{} writer := New(destination, false) iter := iterator.ForSlice([][]lib.RawMessage{ - {{TopicSuffix: "a"}}, + {lib.NewRawMessage("a", nil, nil)}, {}, - {{TopicSuffix: "b"}, {TopicSuffix: "c"}}, + { + lib.NewRawMessage("b", nil, nil), + lib.NewRawMessage("c", nil, nil), + }, }) count, err := writer.Write(context.Background(), iter) assert.NoError(t, err) assert.Equal(t, 3, count) assert.Len(t, destination.messages, 3) - assert.Equal(t, destination.messages[0].TopicSuffix, "a") - assert.Equal(t, destination.messages[1].TopicSuffix, "b") - assert.Equal(t, destination.messages[2].TopicSuffix, "c") + assert.Equal(t, destination.messages[0].TopicSuffix(), "a") + assert.Equal(t, destination.messages[1].TopicSuffix(), "b") + assert.Equal(t, destination.messages[2].TopicSuffix(), "c") } { // Destination error destination := &mockDestination{emitError: true} writer := New(destination, false) - iter := iterator.Once([]lib.RawMessage{{TopicSuffix: "a"}}) + iter := iterator.Once([]lib.RawMessage{lib.NewRawMessage("a", nil, nil)}) _, err := writer.Write(context.Background(), iter) assert.ErrorContains(t, err, "failed to write messages: test write-raw-messages error") assert.Empty(t, destination.messages) diff --git a/sources/mongo/message.go b/sources/mongo/message.go index 0f9ac275..36576067 100644 --- a/sources/mongo/message.go +++ b/sources/mongo/message.go @@ -18,7 +18,7 @@ type mgoMessage struct { } func (m *mgoMessage) toRawMessage(collection config.Collection, database string) (lib.RawMessage, error) { - evt := mongo.SchemaEventPayload{ + evt := &mongo.SchemaEventPayload{ Schema: debezium.Schema{}, Payload: mongo.Payload{ After: &m.jsonExtendedString, @@ -35,7 +35,7 @@ func (m *mgoMessage) toRawMessage(collection config.Collection, database string) "payload": m.pkMap, } - return lib.NewMongoMessage(collection.TopicSuffix(database), pkMap, evt), nil + return lib.NewRawMessage(collection.TopicSuffix(database), pkMap, evt), nil } func parseMessage(result bson.M) (*mgoMessage, error) { diff --git a/sources/mongo/message_test.go b/sources/mongo/message_test.go index dc4856f4..cf563e15 100644 --- a/sources/mongo/message_test.go +++ b/sources/mongo/message_test.go @@ -25,7 +25,7 @@ func TestParseMessagePartitionKey(t *testing.T) { rawMsg, err := msg.toRawMessage(config.Collection{}, "database") assert.NoError(t, err) - rawMsgBytes, err := json.Marshal(rawMsg.PartitionKey) + rawMsgBytes, err := json.Marshal(rawMsg.PartitionKey()) assert.NoError(t, err) var dbz transferMongo.Debezium @@ -67,14 +67,14 @@ func TestParseMessage(t *testing.T) { rawMsg, err := msg.toRawMessage(config.Collection{}, "database") assert.NoError(t, err) - rawPkBytes, err := json.Marshal(rawMsg.PartitionKey) + rawPkBytes, err := json.Marshal(rawMsg.PartitionKey()) assert.NoError(t, err) var dbz transferMongo.Debezium pkMap, err := dbz.GetPrimaryKey(rawPkBytes, &kafkalib.TopicConfig{CDCKeyFormat: kafkalib.JSONKeyFmt}) assert.NoError(t, err) - rawMsgBytes, err := json.Marshal(rawMsg.GetPayload()) + rawMsgBytes, err := json.Marshal(rawMsg.Event()) assert.NoError(t, err) kvMap, err := dbz.GetEventFromBytes(typing.Settings{}, rawMsgBytes) assert.NoError(t, err) diff --git a/sources/postgres/adapter/transformer_test.go b/sources/postgres/adapter/transformer_test.go index 5fcbbf94..f31ca40a 100644 --- a/sources/postgres/adapter/transformer_test.go +++ b/sources/postgres/adapter/transformer_test.go @@ -2,9 +2,9 @@ package adapter import ( "fmt" + "github.com/artie-labs/transfer/lib/cdc/util" "testing" - "github.com/artie-labs/transfer/lib/cdc/util" "github.com/stretchr/testify/assert" "github.com/artie-labs/reader/lib/debezium/converters" @@ -78,21 +78,21 @@ func TestDebeziumTransformer(t *testing.T) { msgs1 := results[0] assert.Len(t, msgs1, 2) - assert.Equal(t, "schema.table", msgs1[0].TopicSuffix) - assert.Equal(t, map[string]any{"a": "1"}, msgs1[0].PartitionKey) - assert.Equal(t, map[string]any{"a": "1", "b": "11"}, msgs1[0].GetPayload().(util.SchemaEventPayload).Payload.After) - assert.Equal(t, "schema.table", msgs1[1].TopicSuffix) - assert.Equal(t, map[string]any{"a": "2"}, msgs1[1].PartitionKey) - assert.Equal(t, map[string]any{"a": "2", "b": "12"}, msgs1[1].GetPayload().(util.SchemaEventPayload).Payload.After) + assert.Equal(t, "schema.table", msgs1[0].TopicSuffix()) + assert.Equal(t, map[string]any{"a": "1"}, msgs1[0].PartitionKey()) + assert.Equal(t, map[string]any{"a": "1", "b": "11"}, msgs1[0].Event().(*util.SchemaEventPayload).Payload.After) + assert.Equal(t, "schema.table", msgs1[1].TopicSuffix()) + assert.Equal(t, map[string]any{"a": "2"}, msgs1[1].PartitionKey()) + assert.Equal(t, map[string]any{"a": "2", "b": "12"}, msgs1[1].Event().(*util.SchemaEventPayload).Payload.After) msgs2 := results[1] assert.Len(t, msgs2, 2) - assert.Equal(t, "schema.table", msgs2[0].TopicSuffix) - assert.Equal(t, map[string]any{"a": "3"}, msgs2[0].PartitionKey) - assert.Equal(t, map[string]any{"a": "3", "b": "13"}, msgs2[0].GetPayload().(util.SchemaEventPayload).Payload.After) - assert.Equal(t, "schema.table", msgs2[1].TopicSuffix) - assert.Equal(t, map[string]any{"a": "4"}, msgs2[1].PartitionKey) - assert.Equal(t, map[string]any{"a": "4", "b": "14"}, msgs2[1].GetPayload().(util.SchemaEventPayload).Payload.After) + assert.Equal(t, "schema.table", msgs2[0].TopicSuffix()) + assert.Equal(t, map[string]any{"a": "3"}, msgs2[0].PartitionKey()) + assert.Equal(t, map[string]any{"a": "3", "b": "13"}, msgs2[0].Event().(*util.SchemaEventPayload).Payload.After) + assert.Equal(t, "schema.table", msgs2[1].TopicSuffix()) + assert.Equal(t, map[string]any{"a": "4"}, msgs2[1].PartitionKey()) + assert.Equal(t, map[string]any{"a": "4", "b": "14"}, msgs2[1].Event().(*util.SchemaEventPayload).Payload.After) } } @@ -127,7 +127,7 @@ func TestDebeziumTransformer_NilOptionalSchema(t *testing.T) { assert.Len(t, results, 1) rows := results[0] assert.Len(t, rows, 1) - payload := rows[0].GetPayload().(util.SchemaEventPayload) + payload := rows[0].Event().(*util.SchemaEventPayload) assert.Equal(t, "r", payload.Payload.Operation) assert.Equal(t, rowData, payload.Payload.After)