Skip to content

Commit

Permalink
Return cdc.Event instead of any (#331)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Apr 3, 2024
1 parent d43233c commit c0a39c8
Show file tree
Hide file tree
Showing 13 changed files with 77 additions and 82 deletions.
14 changes: 7 additions & 7 deletions integration_tests/mysql/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,16 +517,16 @@ func testTypes(db *sql.DB, dbName string) error {
row := rows[0]

expectedPartitionKey := map[string]any{"pk": int64(1)}
if !maps.Equal(row.PartitionKey, expectedPartitionKey) {
return fmt.Errorf("partition key %v does not match %v", row.PartitionKey, expectedPartitionKey)
if !maps.Equal(row.PartitionKey(), expectedPartitionKey) {
return fmt.Errorf("partition key %v does not match %v", row.PartitionKey(), expectedPartitionKey)
}

valueBytes, err := json.MarshalIndent(row.GetPayload(), "", "\t")
valueBytes, err := json.MarshalIndent(row.Event(), "", "\t")
if err != nil {
return fmt.Errorf("failed to marshal payload")
}

expectedPayload := fmt.Sprintf(expectedPayloadTemplate, utils.GetPayload(row).Payload.Source.TsMs, tempTableName)
expectedPayload := fmt.Sprintf(expectedPayloadTemplate, utils.GetEvent(row).Payload.Source.TsMs, tempTableName)
if utils.CheckDifference("payload", expectedPayload, string(valueBytes)) {
return fmt.Errorf("payload does not match")
}
Expand Down Expand Up @@ -648,10 +648,10 @@ func testScan(db *sql.DB, dbName string) error {
return fmt.Errorf("expected %d rows, got %d, batch size %d", len(expectedPartitionKeys), len(rows), batchSize)
}
for i, row := range rows {
if !maps.Equal(row.PartitionKey, expectedPartitionKeys[i]) {
return fmt.Errorf("partition keys are different for row %d, batch size %d, %v != %v", i, batchSize, row.PartitionKey, expectedPartitionKeys[i])
if !maps.Equal(row.PartitionKey(), expectedPartitionKeys[i]) {
return fmt.Errorf("partition keys are different for row %d, batch size %d, %v != %v", i, batchSize, row.PartitionKey(), expectedPartitionKeys[i])
}
textValue := utils.GetPayload(row).Payload.After["c_text_value"]
textValue := utils.GetEvent(row).Payload.After["c_text_value"]
if textValue != expectedValues[i] {
return fmt.Errorf("row values are different for row %d, batch size %d, %v != %v", i, batchSize, textValue, expectedValues[i])
}
Expand Down
14 changes: 7 additions & 7 deletions integration_tests/postgres/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -743,16 +743,16 @@ func testTypes(db *sql.DB) error {
row := rows[0]

expectedPartitionKey := map[string]any{"pk": int64(1)}
if !maps.Equal(row.PartitionKey, expectedPartitionKey) {
return fmt.Errorf("partition key %v does not match %v", row.PartitionKey, expectedPartitionKey)
if !maps.Equal(row.PartitionKey(), expectedPartitionKey) {
return fmt.Errorf("partition key %v does not match %v", row.PartitionKey(), expectedPartitionKey)
}

valueBytes, err := json.MarshalIndent(row.GetPayload(), "", "\t")
valueBytes, err := json.MarshalIndent(row.Event(), "", "\t")
if err != nil {
return fmt.Errorf("failed to marshal payload")
}

expectedPayload := fmt.Sprintf(expectedPayloadTemplate, utils.GetPayload(row).Payload.Source.TsMs, tempTableName)
expectedPayload := fmt.Sprintf(expectedPayloadTemplate, utils.GetEvent(row).Payload.Source.TsMs, tempTableName)
if utils.CheckDifference("payload", expectedPayload, string(valueBytes)) {
return fmt.Errorf("payload does not match")
}
Expand Down Expand Up @@ -874,10 +874,10 @@ func testScan(db *sql.DB) error {
return fmt.Errorf("expected %d rows, got %d, batch size %d", len(expectedPartitionKeys), len(rows), batchSize)
}
for i, row := range rows {
if !maps.Equal(row.PartitionKey, expectedPartitionKeys[i]) {
return fmt.Errorf("partition keys are different for row %d, batch size %d, %v != %v", i, batchSize, row.PartitionKey, expectedPartitionKeys[i])
if !maps.Equal(row.PartitionKey(), expectedPartitionKeys[i]) {
return fmt.Errorf("partition keys are different for row %d, batch size %d, %v != %v", i, batchSize, row.PartitionKey(), expectedPartitionKeys[i])
}
textValue := utils.GetPayload(row).Payload.After["c_text_value"]
textValue := utils.GetEvent(row).Payload.After["c_text_value"]
if textValue != expectedValues[i] {
return fmt.Errorf("row values are different for row %d, batch size %d, %v != %v", i, batchSize, textValue, expectedValues[i])
}
Expand Down
10 changes: 5 additions & 5 deletions integration_tests/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func ReadTable(db *sql.DB, dbzAdapter transformer.Adapter) ([]lib.RawMessage, er
return nil, err
}

rows := []lib.RawMessage{}
var rows []lib.RawMessage
for dbzTransformer.HasNext() {
batch, err := dbzTransformer.Next()
if err != nil {
Expand All @@ -43,12 +43,12 @@ func ReadTable(db *sql.DB, dbzAdapter transformer.Adapter) ([]lib.RawMessage, er
return rows, nil
}

func GetPayload(message lib.RawMessage) util.SchemaEventPayload {
payloadTyped, ok := message.GetPayload().(util.SchemaEventPayload)
func GetEvent(message lib.RawMessage) util.SchemaEventPayload {
event, ok := message.Event().(*util.SchemaEventPayload)
if !ok {
panic("payload is not of type util.SchemaEventPayload")
panic("event is not of type *util.SchemaEventPayload")
}
return payloadTyped
return *event
}

func CheckDifference(name, expected, actual string) bool {
Expand Down
3 changes: 2 additions & 1 deletion lib/debezium/transformer/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ func (d *DebeziumTransformer) Next() ([]lib.RawMessage, error) {
return nil, fmt.Errorf("failed to create Debezium payload: %w", err)
}

result = append(result, lib.NewRawMessage(d.adapter.TopicSuffix(), d.partitionKey(row), payload))
result = append(result, lib.NewRawMessage(d.adapter.TopicSuffix(), d.partitionKey(row), &payload))
}

return result, nil
}

Expand Down
14 changes: 7 additions & 7 deletions lib/debezium/transformer/transformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestDebeziumTransformer_Iteration(t *testing.T) {
assert.Len(t, results, 1)
rows := results[0]
assert.Len(t, rows, 1)
payload, isOk := rows[0].GetPayload().(util.SchemaEventPayload)
payload, isOk := rows[0].Event().(*util.SchemaEventPayload)
assert.True(t, isOk)
assert.Equal(t, "converted-bar", payload.Payload.After["foo"])
}
Expand Down Expand Up @@ -138,15 +138,15 @@ func TestDebeziumTransformer_Iteration(t *testing.T) {
// First batch
rows := results[0]
assert.Len(t, rows, 1)
payload, isOk := rows[0].GetPayload().(util.SchemaEventPayload)
payload, isOk := rows[0].Event().(*util.SchemaEventPayload)
assert.True(t, isOk)
assert.Equal(t, "converted-bar", payload.Payload.After["foo"])
// Second batch
assert.Empty(t, results[1], 0)
// Third batch
rows = results[2]
assert.Len(t, rows, 1)
payload, isOk = rows[0].GetPayload().(util.SchemaEventPayload)
payload, isOk = rows[0].Event().(*util.SchemaEventPayload)
assert.True(t, isOk)
assert.Equal(t, "converted-grault", payload.Payload.After["corge"])
}
Expand Down Expand Up @@ -205,9 +205,9 @@ func TestDebeziumTransformer_Next(t *testing.T) {
rows := results[0]
assert.Len(t, rows, 1)
rawMessage := rows[0]
assert.Equal(t, Row{"foo": "bar", "qux": 12}, rawMessage.PartitionKey)
assert.Equal(t, "im-a-little-topic-suffix", rawMessage.TopicSuffix)
payload, isOk := rawMessage.GetPayload().(util.SchemaEventPayload)
assert.Equal(t, Row{"foo": "bar", "qux": 12}, rawMessage.PartitionKey())
assert.Equal(t, "im-a-little-topic-suffix", rawMessage.TopicSuffix())
payload, isOk := rawMessage.Event().(*util.SchemaEventPayload)
assert.True(t, isOk)
payload.Payload.Source.TsMs = 12345 // Modify source time since it'll be ~now
expected := util.SchemaEventPayload(
Expand All @@ -234,7 +234,7 @@ func TestDebeziumTransformer_Next(t *testing.T) {
},
},
)
assert.Equal(t, expected, payload)
assert.Equal(t, expected, *payload)
}
}

Expand Down
4 changes: 2 additions & 2 deletions lib/dynamo/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ func transformNewImage(data map[string]*dynamodb.AttributeValue) map[string]any
return transformed
}

func (m *Message) artieMessage() util.SchemaEventPayload {
return util.SchemaEventPayload{
func (m *Message) artieMessage() *util.SchemaEventPayload {
return &util.SchemaEventPayload{
Payload: util.Payload{
After: m.rowData,
Source: util.Source{
Expand Down
6 changes: 3 additions & 3 deletions lib/kafkalib/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@ import (
)

func newMessage(topicPrefix string, rawMessage lib.RawMessage) (kafka.Message, error) {
valueBytes, err := json.Marshal(rawMessage.GetPayload())
valueBytes, err := json.Marshal(rawMessage.Event())
if err != nil {
return kafka.Message{}, err
}

keyBytes, err := json.Marshal(rawMessage.PartitionKey)
keyBytes, err := json.Marshal(rawMessage.PartitionKey())
if err != nil {
return kafka.Message{}, err
}

return kafka.Message{
Topic: fmt.Sprintf("%s.%s", topicPrefix, rawMessage.TopicSuffix),
Topic: fmt.Sprintf("%s.%s", topicPrefix, rawMessage.TopicSuffix()),
Key: keyBytes,
Value: valueBytes,
}, nil
Expand Down
2 changes: 1 addition & 1 deletion lib/kafkalib/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func TestNewMessage(t *testing.T) {
rawMessage := lib.NewRawMessage(
"topic-suffix",
map[string]any{"key": "value"},
util.SchemaEventPayload{
&util.SchemaEventPayload{
Payload: util.Payload{
After: map[string]any{"a": "b"},
Source: util.Source{
Expand Down
39 changes: 15 additions & 24 deletions lib/types.go
Original file line number Diff line number Diff line change
@@ -1,40 +1,31 @@
package lib

import (
"github.com/artie-labs/transfer/lib/cdc/mongo"
"github.com/artie-labs/transfer/lib/cdc/util"
"github.com/artie-labs/transfer/lib/cdc"
)

type RawMessage struct {
TopicSuffix string
PartitionKey map[string]any
payload util.SchemaEventPayload
mongoPayload mongo.SchemaEventPayload

mongo bool
topicSuffix string
partitionKey map[string]any
event cdc.Event
}

func NewRawMessage(topicSuffix string, partitionKey map[string]any, payload util.SchemaEventPayload) RawMessage {
func NewRawMessage(topicSuffix string, partitionKey map[string]any, event cdc.Event) RawMessage {
return RawMessage{
TopicSuffix: topicSuffix,
PartitionKey: partitionKey,
payload: payload,
topicSuffix: topicSuffix,
partitionKey: partitionKey,
event: event,
}
}

func NewMongoMessage(topicSuffix string, partitionKey map[string]any, payload mongo.SchemaEventPayload) RawMessage {
return RawMessage{
TopicSuffix: topicSuffix,
PartitionKey: partitionKey,
mongoPayload: payload,
mongo: true,
}
func (r RawMessage) TopicSuffix() string {
return r.topicSuffix
}

func (r RawMessage) GetPayload() any {
if r.mongo {
return r.mongoPayload
}
func (r RawMessage) PartitionKey() map[string]any {
return r.partitionKey
}

return r.payload
func (r RawMessage) Event() cdc.Event {
return r.event
}
15 changes: 9 additions & 6 deletions lib/writer/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,26 @@ func TestWriter_Write(t *testing.T) {
destination := &mockDestination{}
writer := New(destination, false)
iter := iterator.ForSlice([][]lib.RawMessage{
{{TopicSuffix: "a"}},
{lib.NewRawMessage("a", nil, nil)},
{},
{{TopicSuffix: "b"}, {TopicSuffix: "c"}},
{
lib.NewRawMessage("b", nil, nil),
lib.NewRawMessage("c", nil, nil),
},
})
count, err := writer.Write(context.Background(), iter)
assert.NoError(t, err)
assert.Equal(t, 3, count)
assert.Len(t, destination.messages, 3)
assert.Equal(t, destination.messages[0].TopicSuffix, "a")
assert.Equal(t, destination.messages[1].TopicSuffix, "b")
assert.Equal(t, destination.messages[2].TopicSuffix, "c")
assert.Equal(t, destination.messages[0].TopicSuffix(), "a")
assert.Equal(t, destination.messages[1].TopicSuffix(), "b")
assert.Equal(t, destination.messages[2].TopicSuffix(), "c")
}
{
// Destination error
destination := &mockDestination{emitError: true}
writer := New(destination, false)
iter := iterator.Once([]lib.RawMessage{{TopicSuffix: "a"}})
iter := iterator.Once([]lib.RawMessage{lib.NewRawMessage("a", nil, nil)})
_, err := writer.Write(context.Background(), iter)
assert.ErrorContains(t, err, "failed to write messages: test write-raw-messages error")
assert.Empty(t, destination.messages)
Expand Down
4 changes: 2 additions & 2 deletions sources/mongo/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type mgoMessage struct {
}

func (m *mgoMessage) toRawMessage(collection config.Collection, database string) (lib.RawMessage, error) {
evt := mongo.SchemaEventPayload{
evt := &mongo.SchemaEventPayload{
Schema: debezium.Schema{},
Payload: mongo.Payload{
After: &m.jsonExtendedString,
Expand All @@ -35,7 +35,7 @@ func (m *mgoMessage) toRawMessage(collection config.Collection, database string)
"payload": m.pkMap,
}

return lib.NewMongoMessage(collection.TopicSuffix(database), pkMap, evt), nil
return lib.NewRawMessage(collection.TopicSuffix(database), pkMap, evt), nil
}

func parseMessage(result bson.M) (*mgoMessage, error) {
Expand Down
6 changes: 3 additions & 3 deletions sources/mongo/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestParseMessagePartitionKey(t *testing.T) {
rawMsg, err := msg.toRawMessage(config.Collection{}, "database")
assert.NoError(t, err)

rawMsgBytes, err := json.Marshal(rawMsg.PartitionKey)
rawMsgBytes, err := json.Marshal(rawMsg.PartitionKey())
assert.NoError(t, err)

var dbz transferMongo.Debezium
Expand Down Expand Up @@ -67,14 +67,14 @@ func TestParseMessage(t *testing.T) {
rawMsg, err := msg.toRawMessage(config.Collection{}, "database")
assert.NoError(t, err)

rawPkBytes, err := json.Marshal(rawMsg.PartitionKey)
rawPkBytes, err := json.Marshal(rawMsg.PartitionKey())
assert.NoError(t, err)

var dbz transferMongo.Debezium
pkMap, err := dbz.GetPrimaryKey(rawPkBytes, &kafkalib.TopicConfig{CDCKeyFormat: kafkalib.JSONKeyFmt})
assert.NoError(t, err)

rawMsgBytes, err := json.Marshal(rawMsg.GetPayload())
rawMsgBytes, err := json.Marshal(rawMsg.Event())
assert.NoError(t, err)
kvMap, err := dbz.GetEventFromBytes(typing.Settings{}, rawMsgBytes)
assert.NoError(t, err)
Expand Down
28 changes: 14 additions & 14 deletions sources/postgres/adapter/transformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package adapter

import (
"fmt"
"github.com/artie-labs/transfer/lib/cdc/util"
"testing"

"github.com/artie-labs/transfer/lib/cdc/util"
"github.com/stretchr/testify/assert"

"github.com/artie-labs/reader/lib/debezium/converters"
Expand Down Expand Up @@ -78,21 +78,21 @@ func TestDebeziumTransformer(t *testing.T) {

msgs1 := results[0]
assert.Len(t, msgs1, 2)
assert.Equal(t, "schema.table", msgs1[0].TopicSuffix)
assert.Equal(t, map[string]any{"a": "1"}, msgs1[0].PartitionKey)
assert.Equal(t, map[string]any{"a": "1", "b": "11"}, msgs1[0].GetPayload().(util.SchemaEventPayload).Payload.After)
assert.Equal(t, "schema.table", msgs1[1].TopicSuffix)
assert.Equal(t, map[string]any{"a": "2"}, msgs1[1].PartitionKey)
assert.Equal(t, map[string]any{"a": "2", "b": "12"}, msgs1[1].GetPayload().(util.SchemaEventPayload).Payload.After)
assert.Equal(t, "schema.table", msgs1[0].TopicSuffix())
assert.Equal(t, map[string]any{"a": "1"}, msgs1[0].PartitionKey())
assert.Equal(t, map[string]any{"a": "1", "b": "11"}, msgs1[0].Event().(*util.SchemaEventPayload).Payload.After)
assert.Equal(t, "schema.table", msgs1[1].TopicSuffix())
assert.Equal(t, map[string]any{"a": "2"}, msgs1[1].PartitionKey())
assert.Equal(t, map[string]any{"a": "2", "b": "12"}, msgs1[1].Event().(*util.SchemaEventPayload).Payload.After)

msgs2 := results[1]
assert.Len(t, msgs2, 2)
assert.Equal(t, "schema.table", msgs2[0].TopicSuffix)
assert.Equal(t, map[string]any{"a": "3"}, msgs2[0].PartitionKey)
assert.Equal(t, map[string]any{"a": "3", "b": "13"}, msgs2[0].GetPayload().(util.SchemaEventPayload).Payload.After)
assert.Equal(t, "schema.table", msgs2[1].TopicSuffix)
assert.Equal(t, map[string]any{"a": "4"}, msgs2[1].PartitionKey)
assert.Equal(t, map[string]any{"a": "4", "b": "14"}, msgs2[1].GetPayload().(util.SchemaEventPayload).Payload.After)
assert.Equal(t, "schema.table", msgs2[0].TopicSuffix())
assert.Equal(t, map[string]any{"a": "3"}, msgs2[0].PartitionKey())
assert.Equal(t, map[string]any{"a": "3", "b": "13"}, msgs2[0].Event().(*util.SchemaEventPayload).Payload.After)
assert.Equal(t, "schema.table", msgs2[1].TopicSuffix())
assert.Equal(t, map[string]any{"a": "4"}, msgs2[1].PartitionKey())
assert.Equal(t, map[string]any{"a": "4", "b": "14"}, msgs2[1].Event().(*util.SchemaEventPayload).Payload.After)
}
}

Expand Down Expand Up @@ -127,7 +127,7 @@ func TestDebeziumTransformer_NilOptionalSchema(t *testing.T) {
assert.Len(t, results, 1)
rows := results[0]
assert.Len(t, rows, 1)
payload := rows[0].GetPayload().(util.SchemaEventPayload)
payload := rows[0].Event().(*util.SchemaEventPayload)

assert.Equal(t, "r", payload.Payload.Operation)
assert.Equal(t, rowData, payload.Payload.After)
Expand Down

0 comments on commit c0a39c8

Please sign in to comment.