diff --git a/lib/mssql/scanner.go b/lib/mssql/scanner.go index 68e0211d..7174bb79 100644 --- a/lib/mssql/scanner.go +++ b/lib/mssql/scanner.go @@ -5,6 +5,7 @@ import ( "fmt" "slices" "strings" + "time" "github.com/artie-labs/transfer/clients/mssql/dialect" @@ -16,36 +17,19 @@ import ( "github.com/artie-labs/reader/lib/rdbms/scan" ) -var supportedPrimaryKeyDataType = []schema.DataType{ - schema.Bit, - schema.Bytes, - schema.Int16, - schema.Int32, - schema.Int64, - schema.Numeric, - schema.Float, - schema.Money, - schema.Date, - schema.String, - schema.Time, - schema.TimeMicro, - schema.TimeNano, - schema.Datetime2, - schema.Datetime2Micro, - schema.Datetime2Nano, - schema.DatetimeOffset, -} +const ( + TimeMicro = "15:04:05.000000" + TimeNano = "15:04:05.000000000" + DateTimeMicro = "2006-01-02 15:04:05.000000" + DateTimeNano = "2006-01-02 15:04:05.000000000" + DateTimeOffset = "2006-01-02 15:04:05.0000000 -07:00" +) func NewScanner(db *sql.DB, table Table, columns []schema.Column, cfg scan.ScannerConfig) (*scan.Scanner, error) { for _, key := range table.PrimaryKeys() { - _column, err := column.ByName(columns, key) - if err != nil { + if _, err := column.ByName(columns, key); err != nil { return nil, fmt.Errorf("missing column with name: %q", key) } - - if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) { - return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name) - } } primaryKeyBounds, err := table.FetchPrimaryKeysBounds(db) @@ -63,27 +47,89 @@ type scanAdapter struct { columns []schema.Column } -func (s scanAdapter) ParsePrimaryKeyValueForOverrides(columnName string, value string) (any, error) { - // TODO: Implement Date, Time, Datetime for primary key types. +func (s scanAdapter) ParsePrimaryKeyValueForOverrides(_ string, value string) (any, error) { + // We don't need to cast it at all. + return value, nil +} + +// encodePrimaryKeyValue - encodes primary key values based on the column type. +// This is needed because the MSSQL SDK does not support parsing `time.Time`, so we need to do it ourselves. +func (s scanAdapter) encodePrimaryKeyValue(columnName string, value any) (any, error) { columnIdx := slices.IndexFunc(s.columns, func(x schema.Column) bool { return x.Name == columnName }) if columnIdx < 0 { return nil, fmt.Errorf("primary key column does not exist: %q", columnName) } - _column := s.columns[columnIdx] - if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) { - return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name) - } - - switch _column.Type { - case schema.Bit: - return value == "1", nil + switch _columnType := s.columns[columnIdx].Type; _columnType { + case schema.Time: + switch castedVal := value.(type) { + case time.Time: + return castedVal.Format(time.TimeOnly), nil + case string: + return castedVal, nil + default: + return nil, fmt.Errorf("expected time.Time type, received: %T", value) + } + case schema.TimeMicro: + switch castedVal := value.(type) { + case time.Time: + return castedVal.Format(TimeMicro), nil + case string: + return castedVal, nil + default: + return nil, fmt.Errorf("expected time.Time type, received: %T", value) + } + case schema.TimeNano: + switch castedVal := value.(type) { + case time.Time: + return castedVal.Format(TimeNano), nil + case string: + return castedVal, nil + default: + return nil, fmt.Errorf("expected time.Time type, received: %T", value) + } + case schema.Datetime2: + switch castedVal := value.(type) { + case time.Time: + return castedVal.Format(time.DateTime), nil + case string: + return castedVal, nil + default: + return nil, fmt.Errorf("expected time.Time type, received: %T", value) + } + case schema.Datetime2Micro: + switch castedVal := value.(type) { + case time.Time: + return castedVal.Format(DateTimeMicro), nil + case string: + return castedVal, nil + default: + return nil, fmt.Errorf("expected time.Time type, received: %T", value) + } + case schema.Datetime2Nano: + switch castedVal := value.(type) { + case time.Time: + return castedVal.Format(DateTimeNano), nil + case string: + return castedVal, nil + default: + return nil, fmt.Errorf("expected time.Time type, received: %T", value) + } + case schema.DatetimeOffset: + switch castedVal := value.(type) { + case time.Time: + return castedVal.Format(DateTimeOffset), nil + case string: + return castedVal, nil + default: + return nil, fmt.Errorf("expected time.Time type, received: %T", value) + } default: return value, nil } } -func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any) { +func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any, error) { mssqlDialect := dialect.MSSQLDialect{} colNames := make([]string, len(s.columns)) for idx, col := range s.columns { @@ -93,8 +139,18 @@ func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool startingValues := make([]any, len(primaryKeys)) endingValues := make([]any, len(primaryKeys)) for i, pk := range primaryKeys { - startingValues[i] = pk.StartingValue - endingValues[i] = pk.EndingValue + pkStartVal, err := s.encodePrimaryKeyValue(pk.Name, pk.StartingValue) + if err != nil { + return "", nil, fmt.Errorf("failed to encode start primary key val: %w", err) + } + + pkEndVal, err := s.encodePrimaryKeyValue(pk.Name, pk.EndingValue) + if err != nil { + return "", nil, fmt.Errorf("failed to encode end primary key val: %w", err) + } + + startingValues[i] = pkStartVal + endingValues[i] = pkEndVal } quotedKeyNames := make([]string, len(primaryKeys)) @@ -119,7 +175,7 @@ func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool strings.Join(quotedKeyNames, ","), strings.Join(rdbms.QueryPlaceholders("?", len(endingValues)), ","), // ORDER BY strings.Join(quotedKeyNames, ","), - ), slices.Concat(startingValues, endingValues) + ), slices.Concat(startingValues, endingValues), nil } func (s scanAdapter) ParseRow(values []any) error { diff --git a/lib/mssql/scanner_test.go b/lib/mssql/scanner_test.go new file mode 100644 index 00000000..3b3861e3 --- /dev/null +++ b/lib/mssql/scanner_test.go @@ -0,0 +1,165 @@ +package mssql + +import ( + "github.com/artie-labs/reader/lib/mssql/schema" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestScanAdapter_EncodePrimaryKeyValue(t *testing.T) { + { + // schema.Time + adapter := scanAdapter{ + columns: []schema.Column{ + { + Name: "time", + Type: schema.Time, + }, + }, + } + + // Able to use string + val, err := adapter.encodePrimaryKeyValue("time", "12:34:56") + assert.NoError(t, err) + assert.Equal(t, "12:34:56", val) + + // Able to use time.Time + td := time.Date(2021, 1, 1, 12, 34, 56, 0, time.UTC) + val, err = adapter.encodePrimaryKeyValue("time", td) + assert.NoError(t, err) + assert.Equal(t, "12:34:56", val) + } + { + // Schema.TimeMicro + adapter := scanAdapter{ + columns: []schema.Column{ + { + Name: "time_micro", + Type: schema.TimeMicro, + }, + }, + } + + // Able to use string + val, err := adapter.encodePrimaryKeyValue("time_micro", "12:34:56.789012") + assert.NoError(t, err) + assert.Equal(t, "12:34:56.789012", val) + + // Able to use time.Time + td := time.Date(2021, 1, 1, 12, 34, 56, 789012000, time.UTC) + val, err = adapter.encodePrimaryKeyValue("time_micro", td) + assert.NoError(t, err) + assert.Equal(t, "12:34:56.789012", val) + } + { + // schema.TimeNano + adapter := scanAdapter{ + columns: []schema.Column{ + { + Name: "time_nano", + Type: schema.TimeNano, + }, + }, + } + + // Able to use string + val, err := adapter.encodePrimaryKeyValue("time_nano", "12:34:56.789012345") + assert.NoError(t, err) + assert.Equal(t, "12:34:56.789012345", val) + + // Able to use time.Time + td := time.Date(2021, 1, 1, 12, 34, 56, 789012345, time.UTC) + val, err = adapter.encodePrimaryKeyValue("time_nano", td) + assert.NoError(t, err) + assert.Equal(t, "12:34:56.789012345", val) + } + { + // schema.Datetime2 + adapter := scanAdapter{ + columns: []schema.Column{ + { + Name: "datetime2", + Type: schema.Datetime2, + }, + }, + } + + // Able to use string + val, err := adapter.encodePrimaryKeyValue("datetime2", "2021-01-01 12:34:56") + assert.NoError(t, err) + assert.Equal(t, "2021-01-01 12:34:56", val) + + // Able to use time.Time + td := time.Date(2021, 1, 1, 12, 34, 56, 0, time.UTC) + val, err = adapter.encodePrimaryKeyValue("datetime2", td) + assert.NoError(t, err) + assert.Equal(t, "2021-01-01 12:34:56", val) + } + { + // schema.Datetime2Micro + adapter := scanAdapter{ + columns: []schema.Column{ + { + Name: "datetime2_micro", + Type: schema.Datetime2Micro, + }, + }, + } + + // Able to use string + val, err := adapter.encodePrimaryKeyValue("datetime2_micro", "2021-01-01 12:34:56.789012") + assert.NoError(t, err) + assert.Equal(t, "2021-01-01 12:34:56.789012", val) + + // Able to use time.Time + td := time.Date(2021, 1, 1, 12, 34, 56, 789012000, time.UTC) + val, err = adapter.encodePrimaryKeyValue("datetime2_micro", td) + assert.NoError(t, err) + assert.Equal(t, "2021-01-01 12:34:56.789012", val) + } + { + // schema.Datetime2Nano + adapter := scanAdapter{ + columns: []schema.Column{ + { + Name: "datetime2_nano", + Type: schema.Datetime2Nano, + }, + }, + } + + // Able to use string + val, err := adapter.encodePrimaryKeyValue("datetime2_nano", "2021-01-01 12:34:56.789012345") + assert.NoError(t, err) + assert.Equal(t, "2021-01-01 12:34:56.789012345", val) + + // Able to use time.Time + td := time.Date(2021, 1, 1, 12, 34, 56, 789012345, time.UTC) + val, err = adapter.encodePrimaryKeyValue("datetime2_nano", td) + assert.NoError(t, err) + assert.Equal(t, "2021-01-01 12:34:56.789012345", val) + } + { + // schema.DatetimeOffset + adapter := scanAdapter{ + columns: []schema.Column{ + { + Name: "datetimeoffset", + Type: schema.DatetimeOffset, + }, + }, + } + + // Able to use string + val, err := adapter.encodePrimaryKeyValue("datetimeoffset", "2021-01-01 12:34:56.7890123 +00:00") + assert.NoError(t, err) + assert.Equal(t, "2021-01-01 12:34:56.7890123 +00:00", val) + + // Able to use time.Time + td := time.Date(2021, 1, 1, 12, 34, 56, 789012300, time.UTC) + val, err = adapter.encodePrimaryKeyValue("datetimeoffset", td) + assert.NoError(t, err) + assert.Equal(t, "2021-01-01 12:34:56.7890123 +00:00", val) + } +} diff --git a/lib/mysql/scanner/scanner.go b/lib/mysql/scanner/scanner.go index 45d95975..d9705227 100644 --- a/lib/mysql/scanner/scanner.go +++ b/lib/mysql/scanner/scanner.go @@ -125,7 +125,7 @@ func (s scanAdapter) ParsePrimaryKeyValueForOverrides(columnName string, value s } } -func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any) { +func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any, error) { colNames := make([]string, len(s.columns)) for idx, col := range s.columns { colNames[idx] = schema.QuoteIdentifier(col.Name) @@ -160,7 +160,7 @@ func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool strings.Join(quotedKeyNames, ","), // LIMIT batchSize, - ), slices.Concat(startingValues, endingValues) + ), slices.Concat(startingValues, endingValues), nil } func (s scanAdapter) ParseRow(values []any) error { diff --git a/lib/mysql/scanner/scanner_test.go b/lib/mysql/scanner/scanner_test.go index 175cc7d0..a5965b39 100644 --- a/lib/mysql/scanner/scanner_test.go +++ b/lib/mysql/scanner/scanner_test.go @@ -243,13 +243,15 @@ func TestScanAdapter_BuildQuery(t *testing.T) { } { // exclusive lower bound - query, parameters := adapter.BuildQuery(keys, false, 12) + query, parameters, err := adapter.BuildQuery(keys, false, 12) + assert.NoError(t, err) assert.Equal(t, "SELECT `foo`,`bar` FROM `table` WHERE (`foo`) > (?) AND (`foo`) <= (?) ORDER BY `foo` LIMIT 12", query) assert.Equal(t, []any{"a", "b"}, parameters) } { // inclusive upper and lower bounds - query, parameters := adapter.BuildQuery(keys, true, 12) + query, parameters, err := adapter.BuildQuery(keys, true, 12) + assert.NoError(t, err) assert.Equal(t, "SELECT `foo`,`bar` FROM `table` WHERE (`foo`) >= (?) AND (`foo`) <= (?) ORDER BY `foo` LIMIT 12", query) assert.Equal(t, []any{"a", "b"}, parameters) } diff --git a/lib/postgres/scanner.go b/lib/postgres/scanner.go index 22154a75..f7c5a11f 100644 --- a/lib/postgres/scanner.go +++ b/lib/postgres/scanner.go @@ -162,7 +162,7 @@ func queryPlaceholders(offset, count int) []string { return result } -func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any) { +func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any, error) { castedColumns := make([]string, len(s.columns)) for i, col := range s.columns { castedColumns[i] = castColumn(col) @@ -198,7 +198,7 @@ func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool strings.Join(quotedKeyNames, ","), // LIMIT batchSize, - ), slices.Concat(startingValues, endingValues) + ), slices.Concat(startingValues, endingValues), nil } func (s scanAdapter) ParseRow(values []any) error { diff --git a/lib/postgres/scanner_test.go b/lib/postgres/scanner_test.go index 2e4e1dc2..2d9fb4fc 100644 --- a/lib/postgres/scanner_test.go +++ b/lib/postgres/scanner_test.go @@ -65,13 +65,15 @@ func TestScanAdapter_BuildQuery(t *testing.T) { { // inclusive lower bound - query, parameters := adapter.BuildQuery(primaryKeys, true, 1) + query, parameters, err := adapter.BuildQuery(primaryKeys, true, 1) + assert.NoError(t, err) assert.Equal(t, `SELECT "a","b","c","e","f",ARRAY_TO_JSON("g")::TEXT as "g" FROM "schema"."table" WHERE row("a","b","c") >= row($1,$2,$3) AND row("a","b","c") <= row($4,$5,$6) ORDER BY "a","b","c" LIMIT 1`, query) assert.Equal(t, []any{int64(1), int64(2), "3", int64(4), int64(5), "6"}, parameters) } { // exclusive lower bound - query, parameters := adapter.BuildQuery(primaryKeys, false, 2) + query, parameters, err := adapter.BuildQuery(primaryKeys, false, 2) + assert.NoError(t, err) assert.Equal(t, `SELECT "a","b","c","e","f",ARRAY_TO_JSON("g")::TEXT as "g" FROM "schema"."table" WHERE row("a","b","c") > row($1,$2,$3) AND row("a","b","c") <= row($4,$5,$6) ORDER BY "a","b","c" LIMIT 2`, query) assert.Equal(t, []any{int64(1), int64(2), "3", int64(4), int64(5), "6"}, parameters) } diff --git a/lib/rdbms/scan/scan.go b/lib/rdbms/scan/scan.go index b1f9a8ab..7edff985 100644 --- a/lib/rdbms/scan/scan.go +++ b/lib/rdbms/scan/scan.go @@ -23,7 +23,7 @@ type ScannerConfig struct { type ScanAdapter interface { ParsePrimaryKeyValueForOverrides(columnName string, value string) (any, error) - BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any) + BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any, error) ParseRow(row []any) error } @@ -116,7 +116,11 @@ func (s *Scanner) Next() ([]map[string]any, error) { } func (s *Scanner) scan() ([]map[string]any, error) { - query, parameters := s.adapter.BuildQuery(s.primaryKeys.Keys(), s.isFirstBatch, s.batchSize) + query, parameters, err := s.adapter.BuildQuery(s.primaryKeys.Keys(), s.isFirstBatch, s.batchSize) + if err != nil { + return nil, fmt.Errorf("failed to build query: %w", err) + } + slog.Info("Scan query", slog.String("query", query), slog.Any("parameters", parameters)) rows, err := retry.WithRetriesAndResult(s.retryCfg, func(_ int, _ error) (*sql.Rows, error) { diff --git a/lib/rdbms/scan/scan_test.go b/lib/rdbms/scan/scan_test.go index 490256eb..7daa9597 100644 --- a/lib/rdbms/scan/scan_test.go +++ b/lib/rdbms/scan/scan_test.go @@ -20,7 +20,7 @@ func (m mockAdapter) ParsePrimaryKeyValueForOverrides(columnName string, value s } } -func (mockAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any) { +func (mockAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any, error) { panic("not implemented") }