[MSSQL] To support date, time, datetime primary keys (#398)
Tang8330 authored May 25, 2024
1 parent b9619ba commit 7be4d42
Showing 8 changed files with 279 additions and 50 deletions.
134 changes: 95 additions & 39 deletions lib/mssql/scanner.go
@@ -5,6 +5,7 @@ import (
"fmt"
"slices"
"strings"
"time"

"github.com/artie-labs/transfer/clients/mssql/dialect"

@@ -16,36 +17,19 @@ import (
"github.com/artie-labs/reader/lib/rdbms/scan"
)

var supportedPrimaryKeyDataType = []schema.DataType{
schema.Bit,
schema.Bytes,
schema.Int16,
schema.Int32,
schema.Int64,
schema.Numeric,
schema.Float,
schema.Money,
schema.Date,
schema.String,
schema.Time,
schema.TimeMicro,
schema.TimeNano,
schema.Datetime2,
schema.Datetime2Micro,
schema.Datetime2Nano,
schema.DatetimeOffset,
}
const (
TimeMicro = "15:04:05.000000"
TimeNano = "15:04:05.000000000"
DateTimeMicro = "2006-01-02 15:04:05.000000"
DateTimeNano = "2006-01-02 15:04:05.000000000"
DateTimeOffset = "2006-01-02 15:04:05.0000000 -07:00"
)

func NewScanner(db *sql.DB, table Table, columns []schema.Column, cfg scan.ScannerConfig) (*scan.Scanner, error) {
for _, key := range table.PrimaryKeys() {
_column, err := column.ByName(columns, key)
if err != nil {
if _, err := column.ByName(columns, key); err != nil {
return nil, fmt.Errorf("missing column with name: %q", key)
}

if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) {
return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name)
}
}

primaryKeyBounds, err := table.FetchPrimaryKeysBounds(db)
@@ -63,27 +47,89 @@ type scanAdapter struct {
columns []schema.Column
}

func (s scanAdapter) ParsePrimaryKeyValueForOverrides(columnName string, value string) (any, error) {
// TODO: Implement Date, Time, Datetime for primary key types.
func (s scanAdapter) ParsePrimaryKeyValueForOverrides(_ string, value string) (any, error) {
// We don't need to cast it at all.
return value, nil
}

// encodePrimaryKeyValue - encodes primary key values based on the column type.
// This is needed because the MSSQL SDK does not support parsing `time.Time`, so we need to do it ourselves.
func (s scanAdapter) encodePrimaryKeyValue(columnName string, value any) (any, error) {
columnIdx := slices.IndexFunc(s.columns, func(x schema.Column) bool { return x.Name == columnName })
if columnIdx < 0 {
return nil, fmt.Errorf("primary key column does not exist: %q", columnName)
}

_column := s.columns[columnIdx]
if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) {
return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name)
}

switch _column.Type {
case schema.Bit:
return value == "1", nil
switch _columnType := s.columns[columnIdx].Type; _columnType {
case schema.Time:
switch castedVal := value.(type) {
case time.Time:
return castedVal.Format(time.TimeOnly), nil
case string:
return castedVal, nil
default:
return nil, fmt.Errorf("expected time.Time type, received: %T", value)
}
case schema.TimeMicro:
switch castedVal := value.(type) {
case time.Time:
return castedVal.Format(TimeMicro), nil
case string:
return castedVal, nil
default:
return nil, fmt.Errorf("expected time.Time type, received: %T", value)
}
case schema.TimeNano:
switch castedVal := value.(type) {
case time.Time:
return castedVal.Format(TimeNano), nil
case string:
return castedVal, nil
default:
return nil, fmt.Errorf("expected time.Time type, received: %T", value)
}
case schema.Datetime2:
switch castedVal := value.(type) {
case time.Time:
return castedVal.Format(time.DateTime), nil
case string:
return castedVal, nil
default:
return nil, fmt.Errorf("expected time.Time type, received: %T", value)
}
case schema.Datetime2Micro:
switch castedVal := value.(type) {
case time.Time:
return castedVal.Format(DateTimeMicro), nil
case string:
return castedVal, nil
default:
return nil, fmt.Errorf("expected time.Time type, received: %T", value)
}
case schema.Datetime2Nano:
switch castedVal := value.(type) {
case time.Time:
return castedVal.Format(DateTimeNano), nil
case string:
return castedVal, nil
default:
return nil, fmt.Errorf("expected time.Time type, received: %T", value)
}
case schema.DatetimeOffset:
switch castedVal := value.(type) {
case time.Time:
return castedVal.Format(DateTimeOffset), nil
case string:
return castedVal, nil
default:
return nil, fmt.Errorf("expected time.Time type, received: %T", value)
}
default:
return value, nil
}
}

func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any) {
func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any, error) {
mssqlDialect := dialect.MSSQLDialect{}
colNames := make([]string, len(s.columns))
for idx, col := range s.columns {
@@ -93,8 +139,18 @@ func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool
startingValues := make([]any, len(primaryKeys))
endingValues := make([]any, len(primaryKeys))
for i, pk := range primaryKeys {
startingValues[i] = pk.StartingValue
endingValues[i] = pk.EndingValue
pkStartVal, err := s.encodePrimaryKeyValue(pk.Name, pk.StartingValue)
if err != nil {
return "", nil, fmt.Errorf("failed to encode start primary key val: %w", err)
}

pkEndVal, err := s.encodePrimaryKeyValue(pk.Name, pk.EndingValue)
if err != nil {
return "", nil, fmt.Errorf("failed to encode end primary key val: %w", err)
}

startingValues[i] = pkStartVal
endingValues[i] = pkEndVal
}

quotedKeyNames := make([]string, len(primaryKeys))
@@ -119,7 +175,7 @@ func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool
strings.Join(quotedKeyNames, ","), strings.Join(rdbms.QueryPlaceholders("?", len(endingValues)), ","),
// ORDER BY
strings.Join(quotedKeyNames, ","),
), slices.Concat(startingValues, endingValues)
), slices.Concat(startingValues, endingValues), nil
}

func (s scanAdapter) ParseRow(values []any) error {
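The per-type encoding above leans on Go's reference-time layouts: each constant is a layout string that time.Time.Format applies verbatim, padding the fractional seconds to the fixed width MSSQL expects, while the Go 1.20+ stdlib layouts time.TimeOnly and time.DateTime cover schema.Time and schema.Datetime2. A minimal standalone sketch (not part of the commit; the constant values are copied from the hunk above, everything else is illustrative):

package main

import (
	"fmt"
	"time"
)

// Copied from the constants added in lib/mssql/scanner.go above.
const (
	TimeMicro      = "15:04:05.000000"
	DateTimeOffset = "2006-01-02 15:04:05.0000000 -07:00"
)

func main() {
	ts := time.Date(2021, 1, 1, 12, 34, 56, 789012300, time.UTC)

	fmt.Println(ts.Format(time.TimeOnly))  // 12:34:56
	fmt.Println(ts.Format(time.DateTime))  // 2021-01-01 12:34:56
	fmt.Println(ts.Format(TimeMicro))      // 12:34:56.789012
	fmt.Println(ts.Format(DateTimeOffset)) // 2021-01-01 12:34:56.7890123 +00:00
}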
165 changes: 165 additions & 0 deletions lib/mssql/scanner_test.go
@@ -0,0 +1,165 @@
package mssql

import (
"github.com/artie-labs/reader/lib/mssql/schema"
"github.com/stretchr/testify/assert"
"testing"
"time"
)

func TestScanAdapter_EncodePrimaryKeyValue(t *testing.T) {
{
// schema.Time
adapter := scanAdapter{
columns: []schema.Column{
{
Name: "time",
Type: schema.Time,
},
},
}

// Able to use string
val, err := adapter.encodePrimaryKeyValue("time", "12:34:56")
assert.NoError(t, err)
assert.Equal(t, "12:34:56", val)

// Able to use time.Time
td := time.Date(2021, 1, 1, 12, 34, 56, 0, time.UTC)
val, err = adapter.encodePrimaryKeyValue("time", td)
assert.NoError(t, err)
assert.Equal(t, "12:34:56", val)
}
{
// schema.TimeMicro
adapter := scanAdapter{
columns: []schema.Column{
{
Name: "time_micro",
Type: schema.TimeMicro,
},
},
}

// Able to use string
val, err := adapter.encodePrimaryKeyValue("time_micro", "12:34:56.789012")
assert.NoError(t, err)
assert.Equal(t, "12:34:56.789012", val)

// Able to use time.Time
td := time.Date(2021, 1, 1, 12, 34, 56, 789012000, time.UTC)
val, err = adapter.encodePrimaryKeyValue("time_micro", td)
assert.NoError(t, err)
assert.Equal(t, "12:34:56.789012", val)
}
{
// schema.TimeNano
adapter := scanAdapter{
columns: []schema.Column{
{
Name: "time_nano",
Type: schema.TimeNano,
},
},
}

// Able to use string
val, err := adapter.encodePrimaryKeyValue("time_nano", "12:34:56.789012345")
assert.NoError(t, err)
assert.Equal(t, "12:34:56.789012345", val)

// Able to use time.Time
td := time.Date(2021, 1, 1, 12, 34, 56, 789012345, time.UTC)
val, err = adapter.encodePrimaryKeyValue("time_nano", td)
assert.NoError(t, err)
assert.Equal(t, "12:34:56.789012345", val)
}
{
// schema.Datetime2
adapter := scanAdapter{
columns: []schema.Column{
{
Name: "datetime2",
Type: schema.Datetime2,
},
},
}

// Able to use string
val, err := adapter.encodePrimaryKeyValue("datetime2", "2021-01-01 12:34:56")
assert.NoError(t, err)
assert.Equal(t, "2021-01-01 12:34:56", val)

// Able to use time.Time
td := time.Date(2021, 1, 1, 12, 34, 56, 0, time.UTC)
val, err = adapter.encodePrimaryKeyValue("datetime2", td)
assert.NoError(t, err)
assert.Equal(t, "2021-01-01 12:34:56", val)
}
{
// schema.Datetime2Micro
adapter := scanAdapter{
columns: []schema.Column{
{
Name: "datetime2_micro",
Type: schema.Datetime2Micro,
},
},
}

// Able to use string
val, err := adapter.encodePrimaryKeyValue("datetime2_micro", "2021-01-01 12:34:56.789012")
assert.NoError(t, err)
assert.Equal(t, "2021-01-01 12:34:56.789012", val)

// Able to use time.Time
td := time.Date(2021, 1, 1, 12, 34, 56, 789012000, time.UTC)
val, err = adapter.encodePrimaryKeyValue("datetime2_micro", td)
assert.NoError(t, err)
assert.Equal(t, "2021-01-01 12:34:56.789012", val)
}
{
// schema.Datetime2Nano
adapter := scanAdapter{
columns: []schema.Column{
{
Name: "datetime2_nano",
Type: schema.Datetime2Nano,
},
},
}

// Able to use string
val, err := adapter.encodePrimaryKeyValue("datetime2_nano", "2021-01-01 12:34:56.789012345")
assert.NoError(t, err)
assert.Equal(t, "2021-01-01 12:34:56.789012345", val)

// Able to use time.Time
td := time.Date(2021, 1, 1, 12, 34, 56, 789012345, time.UTC)
val, err = adapter.encodePrimaryKeyValue("datetime2_nano", td)
assert.NoError(t, err)
assert.Equal(t, "2021-01-01 12:34:56.789012345", val)
}
{
// schema.DatetimeOffset
adapter := scanAdapter{
columns: []schema.Column{
{
Name: "datetimeoffset",
Type: schema.DatetimeOffset,
},
},
}

// Able to use string
val, err := adapter.encodePrimaryKeyValue("datetimeoffset", "2021-01-01 12:34:56.7890123 +00:00")
assert.NoError(t, err)
assert.Equal(t, "2021-01-01 12:34:56.7890123 +00:00", val)

// Able to use time.Time
td := time.Date(2021, 1, 1, 12, 34, 56, 789012300, time.UTC)
val, err = adapter.encodePrimaryKeyValue("datetimeoffset", td)
assert.NoError(t, err)
assert.Equal(t, "2021-01-01 12:34:56.7890123 +00:00", val)
}
}
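Not covered by the commit's tests, but worth noting: encodePrimaryKeyValue also has two error paths, an unknown column name and a value that is neither time.Time nor string. A hypothetical negative-path subtest for the same package (same imports as the file above; the function name and literal values are illustrative):

func TestScanAdapter_EncodePrimaryKeyValue_Errors(t *testing.T) {
	adapter := scanAdapter{
		columns: []schema.Column{
			{
				Name: "time",
				Type: schema.Time,
			},
		},
	}

	// Unknown column name.
	_, err := adapter.encodePrimaryKeyValue("does_not_exist", "12:34:56")
	assert.ErrorContains(t, err, "primary key column does not exist")

	// Unexpected value type for a time-based key.
	_, err = adapter.encodePrimaryKeyValue("time", 1234)
	assert.ErrorContains(t, err, "expected time.Time type")
}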
4 changes: 2 additions & 2 deletions lib/mysql/scanner/scanner.go
@@ -125,7 +125,7 @@ func (s scanAdapter) ParsePrimaryKeyValueForOverrides(columnName string, value s
}
}

func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any) {
func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any, error) {
colNames := make([]string, len(s.columns))
for idx, col := range s.columns {
colNames[idx] = schema.QuoteIdentifier(col.Name)
@@ -160,7 +160,7 @@ func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool
strings.Join(quotedKeyNames, ","),
// LIMIT
batchSize,
), slices.Concat(startingValues, endingValues)
), slices.Concat(startingValues, endingValues), nil
}

func (s scanAdapter) ParseRow(values []any) error {
6 changes: 4 additions & 2 deletions lib/mysql/scanner/scanner_test.go
@@ -243,13 +243,15 @@ func TestScanAdapter_BuildQuery(t *testing.T) {
}
{
// exclusive lower bound
query, parameters := adapter.BuildQuery(keys, false, 12)
query, parameters, err := adapter.BuildQuery(keys, false, 12)
assert.NoError(t, err)
assert.Equal(t, "SELECT `foo`,`bar` FROM `table` WHERE (`foo`) > (?) AND (`foo`) <= (?) ORDER BY `foo` LIMIT 12", query)
assert.Equal(t, []any{"a", "b"}, parameters)
}
{
// inclusive upper and lower bounds
query, parameters := adapter.BuildQuery(keys, true, 12)
query, parameters, err := adapter.BuildQuery(keys, true, 12)
assert.NoError(t, err)
assert.Equal(t, "SELECT `foo`,`bar` FROM `table` WHERE (`foo`) >= (?) AND (`foo`) <= (?) ORDER BY `foo` LIMIT 12", query)
assert.Equal(t, []any{"a", "b"}, parameters)
}
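Both adapters now surface encoding failures through BuildQuery's new error return, so every in-package caller moves from a two-value to a three-value assignment (in practice the shared scan.Scanner constructed in NewScanner is presumably the caller). A hypothetical helper, purely to illustrate the new contract:

// Hypothetical in-package helper (not part of this commit); assumes the usual
// imports: database/sql, fmt, and lib/rdbms/primary_key.
func runScan(db *sql.DB, adapter scanAdapter, keys []primary_key.Key, isFirstBatch bool) (*sql.Rows, error) {
	query, parameters, err := adapter.BuildQuery(keys, isFirstBatch, 1_000)
	if err != nil {
		return nil, fmt.Errorf("failed to build scan query: %w", err)
	}
	return db.Query(query, parameters...)
}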
