Skip to content

Commit

Permalink
Support Microsoft SQL Server - The Finale (#386)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored May 18, 2024
1 parent 5fe95bb commit ca97434
Show file tree
Hide file tree
Showing 20 changed files with 1,130 additions and 35 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Artie Reader reads from databases to perform historical snapshots and also reads
| MongoDB |||
| MySQL |||
| PostgreSQL |||
| SQL Server |||


## Running
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ require (
github.com/jackc/pgx/v5 v5.5.5
github.com/lmittmann/tint v1.0.4
github.com/mattn/go-isatty v0.0.20
github.com/microsoft/go-mssqldb v1.7.1
github.com/samber/slog-multi v1.0.2
github.com/samber/slog-sentry/v2 v2.5.0
github.com/segmentio/kafka-go v0.4.47
Expand Down Expand Up @@ -96,7 +97,6 @@ require (
github.com/lestrrat-go/jwx v1.2.29 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/microsoft/go-mssqldb v1.7.0 // indirect
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 // indirect
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4=
github.com/microsoft/go-mssqldb v1.7.0 h1:sgMPW0HA6Ihd37Yx0MzHyKD726C2kY/8KJsQtXHNaAs=
github.com/microsoft/go-mssqldb v1.7.0/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA=
github.com/microsoft/go-mssqldb v1.7.1 h1:KU/g8aWeM3Hx7IMOFpiwYiUkU+9zeISb4+tx3ScVfsM=
github.com/microsoft/go-mssqldb v1.7.1/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI=
Expand Down
75 changes: 75 additions & 0 deletions lib/mssql/parse/parse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package parse

import (
"fmt"
"time"

mssql "github.com/microsoft/go-mssqldb"

"github.com/artie-labs/reader/lib/mssql/schema"
)

func ParseValue(colKind schema.DataType, value any) (any, error) {
if value == nil {
return nil, nil
}

switch colKind {
case schema.Bit:
if _, isOk := value.(bool); !isOk {
return nil, fmt.Errorf("expected bool got %T with value: %v", value, value)
}

return value, nil
case schema.Bytes:
if _, isOk := value.([]byte); !isOk {
return nil, fmt.Errorf("expected []byte got %T with value: %v", value, value)
}

return value, nil
case schema.Int16, schema.Int32, schema.Int64:
if _, isOk := value.(int64); !isOk {
return nil, fmt.Errorf("expected int64 got %T with value: %v", value, value)
}

return value, nil
case schema.Numeric, schema.Money:
val, isOk := value.([]byte)
if !isOk {
return nil, fmt.Errorf("expected []byte got %T with value: %v", value, value)
}

return string(val), nil
case schema.Float:
if _, isOk := value.(float64); !isOk {
return nil, fmt.Errorf("expected float64 got %T with value: %v", value, value)
}

return value, nil
case schema.String:
if _, isOk := value.(string); !isOk {
return nil, fmt.Errorf("expected string got %T with value: %v", value, value)
}

return value, nil
case schema.UniqueIdentifier:
var uniq mssql.UniqueIdentifier
if err := uniq.Scan(value); err != nil {
return nil, fmt.Errorf("failed to parse unique identifier value %q: %w", value, err)
}

return uniq.String(), nil
case
schema.Date,
schema.Time, schema.TimeMicro, schema.TimeNano,
schema.Datetime2, schema.Datetime2Micro, schema.Datetime2Nano,
schema.DatetimeOffset:
if _, isOk := value.(time.Time); !isOk {
return nil, fmt.Errorf("expected time.Time got %T with value: %v", value, value)
}

return value, nil
}

return value, nil
}
89 changes: 89 additions & 0 deletions lib/mssql/parse/parse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package parse

import (
"testing"
"time"

"github.com/artie-labs/reader/lib/mssql/schema"
"github.com/stretchr/testify/assert"
)

func TestParseValue(t *testing.T) {
{
// Bit
value, err := ParseValue(schema.Bit, true)
assert.NoError(t, err)
assert.Equal(t, true, value)

_, err = ParseValue(schema.Bit, 1234)
assert.ErrorContains(t, err, "expected bool got int with value: 1234")
}
{
// Bytes
value, err := ParseValue(schema.Bytes, []byte("test"))
assert.NoError(t, err)
assert.Equal(t, []byte("test"), value)

_, err = ParseValue(schema.Bytes, 1234)
assert.ErrorContains(t, err, "expected []byte got int with value: 1234")
}
{
for _, schemaDataType := range []schema.DataType{schema.Int16, schema.Int32, schema.Int64} {
// Int16, Int32, Int64
value, err := ParseValue(schemaDataType, int64(1234))
assert.NoError(t, err, schemaDataType)
assert.Equal(t, int64(1234), value, schemaDataType)

_, err = ParseValue(schemaDataType, 1234)
assert.ErrorContains(t, err, "expected int64 got int with value: 1234", schemaDataType)
}
}
{
// Numeric
value, err := ParseValue(schema.Numeric, []uint8("1234"))
assert.NoError(t, err)
assert.Equal(t, "1234", value)
}
{
// Floats
value, err := ParseValue(schema.Float, float64(1234))
assert.NoError(t, err)
assert.Equal(t, float64(1234), value)
}
{
// Money
value, err := ParseValue(schema.Money, []uint8("1234"))
assert.NoError(t, err)
assert.Equal(t, "1234", value)
}
{
// String
value, err := ParseValue(schema.String, "test")
assert.NoError(t, err)
assert.Equal(t, "test", value)
}
{
// UniqueIdentifier
value, err := ParseValue(schema.UniqueIdentifier, []byte{246, 152, 170, 145, 154, 66, 152, 64, 138, 219, 20, 190, 130, 229, 187, 126})
assert.NoError(t, err)
assert.Equal(t, "91AA98F6-429A-4098-8ADB-14BE82E5BB7E", value)
}
{
// Date, Time, TimeMicro, TimeNano, Datetime2, Datetime2Micro, Datetime2Nano, DatetimeOffset
schemaDataTypes := []schema.DataType{
schema.Date,
schema.Time, schema.TimeMicro, schema.TimeNano,
schema.Datetime2, schema.Datetime2Micro, schema.Datetime2Nano,
schema.DatetimeOffset,
}

for _, schemaDataType := range schemaDataTypes {
value, err := ParseValue(schemaDataType, time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC))
assert.NoError(t, err, schemaDataType)
assert.IsType(t, time.Time{}, value, schemaDataType)

_, err = ParseValue(schemaDataType, 1234)
assert.ErrorContains(t, err, "expected time.Time got int with value: 1234", schemaDataType)
}
}
}
135 changes: 135 additions & 0 deletions lib/mssql/scanner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package mssql

import (
"database/sql"
"fmt"
"github.com/artie-labs/transfer/clients/mssql/dialect"
"slices"
"strings"

"github.com/artie-labs/reader/lib/mssql/parse"
"github.com/artie-labs/reader/lib/mssql/schema"
"github.com/artie-labs/reader/lib/rdbms"
"github.com/artie-labs/reader/lib/rdbms/column"
"github.com/artie-labs/reader/lib/rdbms/primary_key"
"github.com/artie-labs/reader/lib/rdbms/scan"
)

var supportedPrimaryKeyDataType = []schema.DataType{
schema.Bit,
schema.Bytes,
schema.Int16,
schema.Int32,
schema.Int64,
schema.Numeric,
schema.Float,
schema.Money,
schema.Date,
schema.String,
schema.Time,
schema.TimeMicro,
schema.TimeNano,
schema.Datetime2,
schema.Datetime2Micro,
schema.Datetime2Nano,
schema.DatetimeOffset,
}

func NewScanner(db *sql.DB, table Table, columns []schema.Column, cfg scan.ScannerConfig) (*scan.Scanner, error) {
for _, key := range table.PrimaryKeys() {
_column, err := column.GetColumnByName(columns, key)
if err != nil {
return nil, fmt.Errorf("missing column with name: %q", key)
}

if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) {
return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name)
}
}

primaryKeyBounds, err := table.GetPrimaryKeysBounds(db)
if err != nil {
return nil, err
}

adapter := scanAdapter{schema: table.Schema, tableName: table.Name, columns: columns}
return scan.NewScanner(db, primaryKeyBounds, cfg, adapter)
}

type scanAdapter struct {
schema string
tableName string
columns []schema.Column
}

func (s scanAdapter) ParsePrimaryKeyValueForOverrides(columnName string, value string) (any, error) {
// TODO: Implement Date, Time, Datetime for primary key types.
columnIdx := slices.IndexFunc(s.columns, func(x schema.Column) bool { return x.Name == columnName })
if columnIdx < 0 {
return nil, fmt.Errorf("primary key column does not exist: %q", columnName)
}

_column := s.columns[columnIdx]
if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) {
return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name)
}

switch _column.Type {
case schema.Bit:
return value == "1", nil
default:
return value, nil
}
}

func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any) {
mssqlDialect := dialect.MSSQLDialect{}
colNames := make([]string, len(s.columns))
for idx, col := range s.columns {
colNames[idx] = mssqlDialect.QuoteIdentifier(col.Name)
}

startingValues := make([]any, len(primaryKeys))
endingValues := make([]any, len(primaryKeys))
for i, pk := range primaryKeys {
startingValues[i] = pk.StartingValue
endingValues[i] = pk.EndingValue
}

quotedKeyNames := make([]string, len(primaryKeys))
for i, key := range primaryKeys {
quotedKeyNames[i] = mssqlDialect.QuoteIdentifier(key.Name)
}

lowerBoundComparison := ">"
if isFirstBatch {
lowerBoundComparison = ">="
}

return fmt.Sprintf(`SELECT TOP %d %s FROM %s.%s WHERE (%s) %s (%s) AND (%s) <= (%s) ORDER BY %s`,
// TOP
batchSize,
// SELECT
strings.Join(colNames, ","),
// FROM
mssqlDialect.QuoteIdentifier(s.schema), mssqlDialect.QuoteIdentifier(s.tableName),
// WHERE (pk) > (123)
strings.Join(quotedKeyNames, ","), lowerBoundComparison, strings.Join(rdbms.QueryPlaceholders("?", len(startingValues)), ","),
strings.Join(quotedKeyNames, ","), strings.Join(rdbms.QueryPlaceholders("?", len(endingValues)), ","),
// ORDER BY
strings.Join(quotedKeyNames, ","),
), slices.Concat(startingValues, endingValues)
}

func (s scanAdapter) ParseRow(values []any) error {
for i, value := range values {
parsedValue, err := parse.ParseValue(s.columns[i].Type, value)
if err != nil {
return fmt.Errorf("failed to parse column: %q: %w", s.columns[i].Name, err)
}

values[i] = parsedValue
}

return nil
}
Loading

0 comments on commit ca97434

Please sign in to comment.