diff --git a/go.mod b/go.mod index 52dcd16c..745b1d27 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/DataDog/datadog-go/v5 v5.5.0 - github.com/artie-labs/transfer v1.25.18 + github.com/artie-labs/transfer v1.25.23 github.com/aws/aws-sdk-go v1.44.327 github.com/aws/aws-sdk-go-v2 v1.18.1 github.com/aws/aws-sdk-go-v2/config v1.18.19 @@ -149,6 +149,6 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20240415180920-8c6c420018be // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240415180920-8c6c420018be // indirect google.golang.org/grpc v1.63.2 // indirect - google.golang.org/protobuf v1.33.0 // indirect + google.golang.org/protobuf v1.34.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index b7f8ec80..5a6356d9 100644 --- a/go.sum +++ b/go.sum @@ -95,8 +95,8 @@ github.com/apache/thrift v0.0.0-20181112125854-24918abba929/go.mod h1:cp2SuWMxlE github.com/apache/thrift v0.14.2/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/apache/thrift v0.17.0 h1:cMd2aj52n+8VoAtvSvLn4kDC3aZ6IAkBuqWQ2IDu7wo= github.com/apache/thrift v0.17.0/go.mod h1:OLxhMRJxomX+1I/KUw03qoV3mMz16BwaKI+d4fPBx7Q= -github.com/artie-labs/transfer v1.25.18 h1:TMYAKPn1PDFQ9HieogdrKuAnE7aDmxAEKtVtcLdOEM0= -github.com/artie-labs/transfer v1.25.18/go.mod h1:lv9NtzWvCcG4haLbdUAvcvEKfChbeam5fDATr26ve88= +github.com/artie-labs/transfer v1.25.23 h1:9pL6wO/87H6vmFyN8NfOuLAP28w3CjmxeELIrvWKePk= +github.com/artie-labs/transfer v1.25.23/go.mod h1:PxZjjW1+OnZDgRRJwVXUoiGY2iPsLnY2TUMrdcY3zfg= github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go v1.44.327 h1:ZS8oO4+7MOBLhkdwIhgtVeDzCeWOlTfKJS7EgggbIEY= github.com/aws/aws-sdk-go v1.44.327/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= @@ -832,8 +832,8 @@ google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/lib/debezium/converters/decimal_test.go b/lib/debezium/converters/decimal_test.go index b947baca..64161890 100644 --- a/lib/debezium/converters/decimal_test.go +++ b/lib/debezium/converters/decimal_test.go @@ -1,133 +1,133 @@ package converters import ( - "fmt" - "testing" + "fmt" + "testing" - "github.com/artie-labs/transfer/lib/debezium" - "github.com/artie-labs/transfer/lib/ptr" - "github.com/stretchr/testify/assert" + "github.com/artie-labs/transfer/lib/debezium" + "github.com/artie-labs/transfer/lib/ptr" + "github.com/stretchr/testify/assert" ) func TestDecimalConverter_ToField(t *testing.T) { - { - // Without precision - converter := NewDecimalConverter(2, nil) - expected := debezium.Field{ - Type: "bytes", - FieldName: "col", - DebeziumType: "org.apache.kafka.connect.data.Decimal", - Parameters: map[string]any{ - "scale": "2", - }, - } - assert.Equal(t, expected, converter.ToField("col")) - } - { - // With precision - converter := NewDecimalConverter(2, ptr.ToInt(3)) - expected := debezium.Field{ - Type: "bytes", - FieldName: "col", - DebeziumType: "org.apache.kafka.connect.data.Decimal", - Parameters: map[string]any{ - "connect.decimal.precision": "3", - "scale": "2", - }, - } - assert.Equal(t, expected, converter.ToField("col")) - } + { + // Without precision + converter := NewDecimalConverter(2, nil) + expected := debezium.Field{ + Type: "bytes", + FieldName: "col", + DebeziumType: "org.apache.kafka.connect.data.Decimal", + Parameters: map[string]any{ + "scale": "2", + }, + } + assert.Equal(t, expected, converter.ToField("col")) + } + { + // With precision + converter := NewDecimalConverter(2, ptr.ToInt(3)) + expected := debezium.Field{ + Type: "bytes", + FieldName: "col", + DebeziumType: "org.apache.kafka.connect.data.Decimal", + Parameters: map[string]any{ + "connect.decimal.precision": "3", + "scale": "2", + }, + } + assert.Equal(t, expected, converter.ToField("col")) + } } func TestDecimalConverter_Convert(t *testing.T) { - converter := NewDecimalConverter(2, nil) - { - // Malformed value - empty string. - _, err := converter.Convert("") - assert.ErrorContains(t, err, "unable to use '' as a floating-point number") - } - { - // Malformed value - not a floating-point. - _, err := converter.Convert("11qwerty00") - assert.ErrorContains(t, err, "unable to use '11qwerty00' as a floating-point number") - } - { - // Happy path. - converted, err := converter.Convert("1.23") - assert.NoError(t, err) - bytes, ok := converted.([]byte) - assert.True(t, ok) - actualValue, err := converter.ToField("").DecodeDecimal(bytes) - assert.NoError(t, err) - assert.Equal(t, "1.23", fmt.Sprint(actualValue)) - } + converter := NewDecimalConverter(2, nil) + { + // Malformed value - empty string. + _, err := converter.Convert("") + assert.ErrorContains(t, err, `unable to use "" as a floating-point number`) + } + { + // Malformed value - not a floating-point. + _, err := converter.Convert("11qwerty00") + assert.ErrorContains(t, err, `unable to use "11qwerty00" as a floating-point number`) + } + { + // Happy path. + converted, err := converter.Convert("1.23") + assert.NoError(t, err) + bytes, ok := converted.([]byte) + assert.True(t, ok) + actualValue, err := converter.ToField("").DecodeDecimal(bytes) + assert.NoError(t, err) + assert.Equal(t, "1.23", fmt.Sprint(actualValue)) + } } func TestGetScale(t *testing.T) { - type _testCase struct { - name string - value string - expectedScale int - } + type _testCase struct { + name string + value string + expectedScale int + } - testCases := []_testCase{ - { - name: "0 scale", - value: "5", - expectedScale: 0, - }, - { - name: "2 scale", - value: "9.99", - expectedScale: 2, - }, - { - name: "5 scale", - value: "9.12345", - expectedScale: 5, - }, - } + testCases := []_testCase{ + { + name: "0 scale", + value: "5", + expectedScale: 0, + }, + { + name: "2 scale", + value: "9.99", + expectedScale: 2, + }, + { + name: "5 scale", + value: "9.12345", + expectedScale: 5, + }, + } - for _, testCase := range testCases { - actualScale := getScale(testCase.value) - assert.Equal(t, testCase.expectedScale, actualScale, testCase.name) - } + for _, testCase := range testCases { + actualScale := getScale(testCase.value) + assert.Equal(t, testCase.expectedScale, actualScale, testCase.name) + } } func TestVariableNumericConverter_ToField(t *testing.T) { - converter := VariableNumericConverter{} - expected := debezium.Field{ - FieldName: "col", - Type: "struct", - DebeziumType: "io.debezium.data.VariableScaleDecimal", - } - assert.Equal(t, expected, converter.ToField("col")) + converter := VariableNumericConverter{} + expected := debezium.Field{ + FieldName: "col", + Type: "struct", + DebeziumType: "io.debezium.data.VariableScaleDecimal", + } + assert.Equal(t, expected, converter.ToField("col")) } func TestVariableNumericConverter_Convert(t *testing.T) { - converter := VariableNumericConverter{} - { - // Wrong type - _, err := converter.Convert(1234) - assert.ErrorContains(t, err, "expected string got int with value: 1234") - } - { - // Malformed value - emty string. - _, err := converter.Convert("") - assert.ErrorContains(t, err, "unable to use '' as a floating-point number") - } - { - // Malformed value - not a floating point. - _, err := converter.Convert("malformed") - assert.ErrorContains(t, err, "unable to use 'malformed' as a floating-point number") - } - { - // Happy path - converted, err := converter.Convert("12.34") - assert.NoError(t, err) - assert.Equal(t, map[string]any{"scale": int32(2), "value": []byte{0x4, 0xd2}}, converted) - actualValue, err := converter.ToField("").DecodeDebeziumVariableDecimal(converted) - assert.NoError(t, err) - assert.Equal(t, "12.34", actualValue.String()) - } + converter := VariableNumericConverter{} + { + // Wrong type + _, err := converter.Convert(1234) + assert.ErrorContains(t, err, "expected string got int with value: 1234") + } + { + // Malformed value - empty string. + _, err := converter.Convert("") + assert.ErrorContains(t, err, `unable to use "" as a floating-point number`) + } + { + // Malformed value - not a floating point. + _, err := converter.Convert("malformed") + assert.ErrorContains(t, err, `unable to use "malformed" as a floating-point number`) + } + { + // Happy path + converted, err := converter.Convert("12.34") + assert.NoError(t, err) + assert.Equal(t, map[string]any{"scale": int32(2), "value": []byte{0x4, 0xd2}}, converted) + actualValue, err := converter.ToField("").DecodeDebeziumVariableDecimal(converted) + assert.NoError(t, err) + assert.Equal(t, "12.34", actualValue.String()) + } } diff --git a/lib/debezium/converters/money_test.go b/lib/debezium/converters/money_test.go index f476e075..dc72bc5c 100644 --- a/lib/debezium/converters/money_test.go +++ b/lib/debezium/converters/money_test.go @@ -1,97 +1,97 @@ package converters import ( - transferDbz "github.com/artie-labs/transfer/lib/debezium" - "github.com/artie-labs/transfer/lib/ptr" - "github.com/stretchr/testify/assert" - "testing" + transferDbz "github.com/artie-labs/transfer/lib/debezium" + "github.com/artie-labs/transfer/lib/ptr" + "github.com/stretchr/testify/assert" + "testing" ) func TestMoney_Scale(t *testing.T) { - { - // Not specified - converter := MoneyConverter{} - assert.Equal(t, defaultScale, converter.Scale()) - } - { - // Specified - converter := MoneyConverter{ - ScaleOverride: ptr.ToInt(3), - } - assert.Equal(t, 3, converter.Scale()) - } + { + // Not specified + converter := MoneyConverter{} + assert.Equal(t, defaultScale, converter.Scale()) + } + { + // Specified + converter := MoneyConverter{ + ScaleOverride: ptr.ToInt(3), + } + assert.Equal(t, 3, converter.Scale()) + } } func TestMoneyConverter_ToField(t *testing.T) { - converter := MoneyConverter{} - expected := transferDbz.Field{ - FieldName: "col", - Type: "bytes", - DebeziumType: "org.apache.kafka.connect.data.Decimal", - Parameters: map[string]any{ - "scale": "2", - }, - } - assert.Equal(t, expected, converter.ToField("col")) + converter := MoneyConverter{} + expected := transferDbz.Field{ + FieldName: "col", + Type: "bytes", + DebeziumType: "org.apache.kafka.connect.data.Decimal", + Parameters: map[string]any{ + "scale": "2", + }, + } + assert.Equal(t, expected, converter.ToField("col")) } func TestMoneyConverter_Convert(t *testing.T) { - decimalField := NewDecimalConverter(defaultScale, nil).ToField("") - decodeValue := func(value any) string { - bytes, ok := value.([]byte) - assert.True(t, ok) - val, err := decimalField.DecodeDecimal(bytes) - assert.NoError(t, err) - return val.String() - } - { - // Converter where mutateString is true - converter := MoneyConverter{ - StripCommas: true, - CurrencySymbol: "$", - } - { - // string - converted, err := converter.Convert("1234.56") - assert.NoError(t, err) - assert.Equal(t, []byte{0x1, 0xe2, 0x40}, converted) - assert.Equal(t, "1234.56", decodeValue(converted)) - } - { - // string with $ and comma - converted, err := converter.Convert("$1,234.567") - assert.NoError(t, err) - assert.Equal(t, []byte{0x1, 0xe2, 0x40}, converted) - assert.Equal(t, "1234.56", decodeValue(converted)) - } - { - // string with $, comma, and no cents - converted, err := converter.Convert("$1000,234") - assert.NoError(t, err) - assert.Equal(t, []byte{0x5, 0xf6, 0x3c, 0x68}, converted) - assert.Equal(t, "1000234.00", decodeValue(converted)) - } - { - // Malformed string - empty string. - _, err := converter.Convert("") - assert.ErrorContains(t, err, "unable to use '' as a floating-point number") - } - { - // Malformed string - not a floating-point. - _, err := converter.Convert("malformed") - assert.ErrorContains(t, err, "unable to use 'malformed' as a floating-point number") - } - } - { - // Converter where mutateString is false - converter := MoneyConverter{} - { - // int - converted, err := converter.Convert("1234.567") - assert.NoError(t, err) - assert.Equal(t, []byte{0x1, 0xe2, 0x40}, converted) - assert.Equal(t, "1234.56", decodeValue(converted)) - } - } + decimalField := NewDecimalConverter(defaultScale, nil).ToField("") + decodeValue := func(value any) string { + bytes, ok := value.([]byte) + assert.True(t, ok) + val, err := decimalField.DecodeDecimal(bytes) + assert.NoError(t, err) + return val.String() + } + { + // Converter where mutateString is true + converter := MoneyConverter{ + StripCommas: true, + CurrencySymbol: "$", + } + { + // string + converted, err := converter.Convert("1234.56") + assert.NoError(t, err) + assert.Equal(t, []byte{0x1, 0xe2, 0x40}, converted) + assert.Equal(t, "1234.56", decodeValue(converted)) + } + { + // string with $ and comma + converted, err := converter.Convert("$1,234.56") + assert.NoError(t, err) + assert.Equal(t, []byte{0x1, 0xe2, 0x40}, converted) + assert.Equal(t, "1234.56", decodeValue(converted)) + } + { + // string with $, comma, and no cents + converted, err := converter.Convert("$1000,234") + assert.NoError(t, err) + assert.Equal(t, []byte{0x5, 0xf6, 0x3c, 0x68}, converted) + assert.Equal(t, "1000234.00", decodeValue(converted)) + } + { + // Malformed string - empty string. + _, err := converter.Convert("") + assert.ErrorContains(t, err, `unable to use "" as a floating-point number`) + } + { + // Malformed string - not a floating-point. + _, err := converter.Convert("malformed") + assert.ErrorContains(t, err, `unable to use "malformed" as a floating-point number`) + } + } + { + // Converter where mutateString is false + converter := MoneyConverter{} + { + // int + converted, err := converter.Convert("1234.56") + assert.NoError(t, err) + assert.Equal(t, []byte{0x1, 0xe2, 0x40}, converted) + assert.Equal(t, "1234.56", decodeValue(converted)) + } + } } diff --git a/writers/transfer/writer.go b/writers/transfer/writer.go index d125e7bd..0876e60b 100644 --- a/writers/transfer/writer.go +++ b/writers/transfer/writer.go @@ -1,266 +1,266 @@ package transfer import ( - "context" - "encoding/json" - "fmt" - "log/slog" - "time" - - "github.com/artie-labs/transfer/clients/mssql/dialect" - "github.com/artie-labs/transfer/lib/artie" - "github.com/artie-labs/transfer/lib/cdc/mongo" - "github.com/artie-labs/transfer/lib/config" - "github.com/artie-labs/transfer/lib/config/constants" - "github.com/artie-labs/transfer/lib/destination" - "github.com/artie-labs/transfer/lib/destination/utils" - "github.com/artie-labs/transfer/lib/kafkalib" - "github.com/artie-labs/transfer/models" - "github.com/artie-labs/transfer/models/event" - - "github.com/artie-labs/reader/lib" - "github.com/artie-labs/reader/lib/mtr" + "context" + "encoding/json" + "fmt" + "log/slog" + "time" + + "github.com/artie-labs/transfer/clients/mssql/dialect" + "github.com/artie-labs/transfer/lib/artie" + "github.com/artie-labs/transfer/lib/cdc/mongo" + "github.com/artie-labs/transfer/lib/config" + "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/destination" + "github.com/artie-labs/transfer/lib/destination/utils" + "github.com/artie-labs/transfer/lib/kafkalib" + "github.com/artie-labs/transfer/models" + "github.com/artie-labs/transfer/models/event" + + "github.com/artie-labs/reader/lib" + "github.com/artie-labs/reader/lib/mtr" ) type Writer struct { - cfg config.Config - statsD mtr.Client - inMemDB *models.DatabaseData - tc *kafkalib.TopicConfig - destination destination.Baseline + cfg config.Config + statsD mtr.Client + inMemDB *models.DatabaseData + tc *kafkalib.TopicConfig + destination destination.Baseline - primaryKeys []string + primaryKeys []string } func NewWriter(cfg config.Config, statsD mtr.Client) (*Writer, error) { - if cfg.Kafka == nil { - return nil, fmt.Errorf("kafka config should not be nil") - } - - if len(cfg.Kafka.TopicConfigs) != 1 { - return nil, fmt.Errorf("kafka config should have exactly one topic config") - } - - writer := &Writer{ - cfg: cfg, - statsD: statsD, - inMemDB: models.NewMemoryDB(), - tc: cfg.Kafka.TopicConfigs[0], - } - - if utils.IsOutputBaseline(cfg) { - baseline, err := utils.LoadBaseline(cfg) - if err != nil { - return nil, err - } - - writer.destination = baseline - } else { - _destination, err := utils.LoadDataWarehouse(cfg, nil) - if err != nil { - return nil, err - } - - writer.destination = _destination - } - - return writer, nil + if cfg.Kafka == nil { + return nil, fmt.Errorf("kafka config should not be nil") + } + + if len(cfg.Kafka.TopicConfigs) != 1 { + return nil, fmt.Errorf("kafka config should have exactly one topic config") + } + + writer := &Writer{ + cfg: cfg, + statsD: statsD, + inMemDB: models.NewMemoryDB(), + tc: cfg.Kafka.TopicConfigs[0], + } + + if utils.IsOutputBaseline(cfg) { + baseline, err := utils.LoadBaseline(cfg) + if err != nil { + return nil, err + } + + writer.destination = baseline + } else { + _destination, err := utils.LoadDataWarehouse(cfg, nil) + if err != nil { + return nil, err + } + + writer.destination = _destination + } + + return writer, nil } func (w *Writer) messageToEvent(message lib.RawMessage) (event.Event, error) { - evt := message.Event() - if mongoEvt, ok := evt.(*mongo.SchemaEventPayload); ok { - bytes, err := json.Marshal(mongoEvt) - if err != nil { - return event.Event{}, err - } - - var dbz mongo.Debezium - evt, err = dbz.GetEventFromBytes(w.cfg.SharedTransferConfig.TypingSettings, bytes) - if err != nil { - return event.Event{}, err - } - - partitionKeyBytes, err := json.Marshal(message.PartitionKey()) - if err != nil { - return event.Event{}, err - } - - partitionKey, err := dbz.GetPrimaryKey(partitionKeyBytes, w.tc) - if err != nil { - return event.Event{}, err - } - - return event.ToMemoryEvent(evt, partitionKey, w.tc, config.Replication) - } - - memoryEvent, err := event.ToMemoryEvent(evt, message.PartitionKey(), w.tc, config.Replication) - if err != nil { - return event.Event{}, err - } - - // Setting the deleted column flag. - memoryEvent.Data[constants.DeleteColumnMarker] = false - return memoryEvent, nil + evt := message.Event() + if mongoEvt, ok := evt.(*mongo.SchemaEventPayload); ok { + bytes, err := json.Marshal(mongoEvt) + if err != nil { + return event.Event{}, err + } + + var dbz mongo.Debezium + evt, err = dbz.GetEventFromBytes(w.cfg.SharedTransferConfig.TypingSettings, bytes) + if err != nil { + return event.Event{}, err + } + + partitionKeyBytes, err := json.Marshal(message.PartitionKey()) + if err != nil { + return event.Event{}, err + } + + partitionKey, err := dbz.GetPrimaryKey(partitionKeyBytes, w.tc) + if err != nil { + return event.Event{}, err + } + + return event.ToMemoryEvent(evt, partitionKey, w.tc, config.Replication) + } + + memoryEvent, err := event.ToMemoryEvent(evt, message.PartitionKey(), w.tc, config.Replication) + if err != nil { + return event.Event{}, err + } + + // Setting the deleted column flag. + memoryEvent.Data[constants.DeleteColumnMarker] = false + return memoryEvent, nil } func (w *Writer) Write(_ context.Context, messages []lib.RawMessage) error { - if len(messages) == 0 { - return nil - } - - var events []event.Event - for _, message := range messages { - evt, err := w.messageToEvent(message) - if err != nil { - return err - } - events = append(events, evt) - } - - tags := map[string]string{ - "mode": w.cfg.Mode.String(), - "op": "r", - "what": "success", - "database": w.tc.Database, - "schema": w.tc.Schema, - "table": events[0].Table, - } - defer func() { - if w.statsD != nil { - w.statsD.Count("process.message", int64(len(events)), tags) - } - }() - - for _, evt := range events { - // Set the primary keys if it's not set already. - if len(w.primaryKeys) == 0 { - var pks []string - for key := range evt.PrimaryKeyMap { - pks = append(pks, key) - } - - w.primaryKeys = pks - } - - shouldFlush, flushReason, err := evt.Save(w.cfg, w.inMemDB, w.tc, artie.Message{}) - if err != nil { - return fmt.Errorf("failed to save event: %w", err) - } - - if shouldFlush { - if err = w.flush(flushReason); err != nil { - return err - } - } - } - - return nil + if len(messages) == 0 { + return nil + } + + var events []event.Event + for _, message := range messages { + evt, err := w.messageToEvent(message) + if err != nil { + return err + } + events = append(events, evt) + } + + tags := map[string]string{ + "mode": w.cfg.Mode.String(), + "op": "r", + "what": "success", + "database": w.tc.Database, + "schema": w.tc.Schema, + "table": events[0].Table, + } + defer func() { + if w.statsD != nil { + w.statsD.Count("process.message", int64(len(events)), tags) + } + }() + + for _, evt := range events { + // Set the primary keys if it's not set already. + if len(w.primaryKeys) == 0 { + var pks []string + for key := range evt.PrimaryKeyMap { + pks = append(pks, key) + } + + w.primaryKeys = pks + } + + shouldFlush, flushReason, err := evt.Save(w.cfg, w.inMemDB, w.tc, artie.Message{}) + if err != nil { + return fmt.Errorf("failed to save event: %w", err) + } + + if shouldFlush { + if err = w.flush(flushReason); err != nil { + return err + } + } + } + + return nil } func (w *Writer) getTableData() (string, *models.TableData, error) { - tableData := w.inMemDB.TableData() - if len(tableData) != 1 { - return "", nil, fmt.Errorf("expected exactly one table") - } - for k, v := range tableData { - return k, v, nil - } - return "", nil, fmt.Errorf("expected exactly one table") + tableData := w.inMemDB.TableData() + if len(tableData) != 1 { + return "", nil, fmt.Errorf("expected exactly one table") + } + for k, v := range tableData { + return k, v, nil + } + return "", nil, fmt.Errorf("expected exactly one table") } func (w *Writer) flush(reason string) error { - tableName, tableData, err := w.getTableData() - if err != nil { - return err - } - - if tableData.ShouldSkipUpdate() { - return nil // No need to flush. - } - - start := time.Now() - tags := map[string]string{ - "what": "success", - "mode": tableData.Mode().String(), - "table": tableName, - "database": tableData.TopicConfig().Database, - "schema": tableData.TopicConfig().Schema, - "reason": reason, - } - defer func() { - if w.statsD != nil { - w.statsD.Timing("flush", time.Since(start), tags) - } - }() - - tableData.ResetTempTableSuffix() - if isMicrosoftSQLServer(w.destination) { - // Microsoft SQL Server uses MERGE not append - if err = w.destination.Merge(tableData.TableData); err != nil { - tags["what"] = "merge_fail" - tags["retryable"] = fmt.Sprint(w.destination.IsRetryableError(err)) - return fmt.Errorf("failed to merge data to destination: %w", err) - } - } else { - // We should hide this column from getting added - if !tableData.TopicConfig().SoftDelete { - tableData.InMemoryColumns().DeleteColumn(constants.DeleteColumnMarker) - } - - if err = w.destination.Append(tableData.TableData); err != nil { - tags["what"] = "merge_fail" - tags["retryable"] = fmt.Sprint(w.destination.IsRetryableError(err)) - return fmt.Errorf("failed to append data to destination: %w", err) - } - } - - w.inMemDB.ClearTableConfig(tableName) - return nil + tableName, tableData, err := w.getTableData() + if err != nil { + return err + } + + if tableData.ShouldSkipUpdate() { + return nil // No need to flush. + } + + start := time.Now() + tags := map[string]string{ + "what": "success", + "mode": tableData.Mode().String(), + "table": tableName, + "database": tableData.TopicConfig().Database, + "schema": tableData.TopicConfig().Schema, + "reason": reason, + } + defer func() { + if w.statsD != nil { + w.statsD.Timing("flush", time.Since(start), tags) + } + }() + + tableData.ResetTempTableSuffix() + if isMicrosoftSQLServer(w.destination) { + // Microsoft SQL Server uses MERGE not append + if err = w.destination.Merge(tableData.TableData); err != nil { + tags["what"] = "merge_fail" + tags["retryable"] = fmt.Sprint(w.destination.IsRetryableError(err)) + return fmt.Errorf("failed to merge data to destination: %w", err) + } + } else { + // We should hide this column from getting added + if !tableData.TopicConfig().SoftDelete { + tableData.InMemoryColumns().DeleteColumn(constants.DeleteColumnMarker) + } + + if err = w.destination.Append(tableData.TableData); err != nil { + tags["what"] = "merge_fail" + tags["retryable"] = fmt.Sprint(w.destination.IsRetryableError(err)) + return fmt.Errorf("failed to append data to destination: %w", err) + } + } + + w.inMemDB.ClearTableConfig(tableName) + return nil } func (w *Writer) OnComplete() error { - if len(w.primaryKeys) == 0 { - return fmt.Errorf("primary keys not set") - } - - if err := w.flush("complete"); err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - - tableName, _, err := w.getTableData() - if err != nil { - return err - } - - if isMicrosoftSQLServer(w.destination) { - // We don't need to run dedupe because it's just merging. - return nil - } - - slog.Info("Running dedupe...", slog.String("table", tableName)) - tableID := w.destination.IdentifierFor(*w.tc, tableName) - start := time.Now() - - dwh, isOk := w.destination.(destination.DataWarehouse) - if !isOk { - return nil - } - - if err = dwh.Dedupe(tableID, w.primaryKeys, *w.tc); err != nil { - return err - } - - slog.Info("Dedupe complete", slog.String("table", tableName), slog.Duration("duration", time.Since(start))) - return nil + if len(w.primaryKeys) == 0 { + return fmt.Errorf("primary keys not set") + } + + if err := w.flush("complete"); err != nil { + return fmt.Errorf("failed to flush: %w", err) + } + + tableName, _, err := w.getTableData() + if err != nil { + return err + } + + if isMicrosoftSQLServer(w.destination) { + // We don't need to run dedupe because it's just merging. + return nil + } + + slog.Info("Running dedupe...", slog.String("table", tableName)) + tableID := w.destination.IdentifierFor(*w.tc, tableName) + start := time.Now() + + dwh, isOk := w.destination.(destination.DataWarehouse) + if !isOk { + return nil + } + + if err = dwh.Dedupe(tableID, w.primaryKeys, w.tc.IncludeArtieUpdatedAt); err != nil { + return err + } + + slog.Info("Dedupe complete", slog.String("table", tableName), slog.Duration("duration", time.Since(start))) + return nil } func isMicrosoftSQLServer(baseline destination.Baseline) bool { - dwh, isOk := baseline.(destination.DataWarehouse) - if !isOk { - return false - } + dwh, isOk := baseline.(destination.DataWarehouse) + if !isOk { + return false + } - _, isOk = dwh.Dialect().(dialect.MSSQLDialect) - return isOk + _, isOk = dwh.Dialect().(dialect.MSSQLDialect) + return isOk }