Skip to content

Commit

Permalink
[postgres] Stricter column type handling (#239)
Browse files: browse the repository at this point in the history
  • Loading branch information
nathan-artie authored Mar 7, 2024
1 parent e98415e commit 3953d64
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 107 deletions.
4 changes: 1 addition & 3 deletions lib/postgres/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import (
func castColumn(col schema.Column) (string, error) {
colName := pgx.Identifier{col.Name}.Sanitize()
switch col.Type {
case schema.InvalidDataType:
return colName, nil
case schema.Inet:
return fmt.Sprintf("%s::text", colName), nil
case schema.Time, schema.Interval:
Expand All @@ -38,6 +36,6 @@ func castColumn(col schema.Column) (string, error) {
// These are all the columns that do not need to be escaped.
return colName, nil
default:
return "", fmt.Errorf("unsupported column type DataType(%d)", col.Type)
return "", fmt.Errorf("unsupported column type: DataType(%d)", col.Type)
}
}
16 changes: 13 additions & 3 deletions lib/postgres/cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ func TestCastColumn(t *testing.T) {
name string
dataType schema.DataType

expected string
expected string
expectedErr string
}

var testCases = []_testCase{
Expand Down Expand Up @@ -67,11 +68,20 @@ func TestCastColumn(t *testing.T) {
dataType: schema.VariableNumeric,
expected: `"foo"`,
},
{
name: "unsupported",
dataType: -1,
expectedErr: "unsupported column type: DataType(-1)",
},
}

for _, testCase := range testCases {
actualEscCol, err := castColumn(schema.Column{Name: "foo", Type: testCase.dataType})
assert.NoError(t, err)
assert.Equal(t, testCase.expected, actualEscCol, testCase.name)
if testCase.expectedErr == "" {
assert.NoError(t, err, testCase.name)
assert.Equal(t, testCase.expected, actualEscCol, testCase.name)
} else {
assert.ErrorContains(t, err, testCase.expectedErr, testCase.name)
}
}
}
8 changes: 3 additions & 5 deletions lib/postgres/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ func scanTableQuery(args scanTableQueryArgs) (string, error) {

func shouldQuoteValue(dataType schema.DataType) (bool, error) {
switch dataType {
case schema.InvalidDataType:
return false, fmt.Errorf("invalid data type")
case
schema.Bit, // Fails: operator does not exist: bit >= boolean (SQLSTATE 42883)
schema.Time, // Fails: invalid input syntax for type time: "45296000" (SQLSTATE 22007)
Expand Down Expand Up @@ -147,18 +145,18 @@ func convertToStringForQuery(value any, dataType schema.DataType) (string, error
return fmt.Sprint(value), nil
default:
slog.Error("bool value with non-bool column type",
slog.Any("value", value),
slog.Bool("value", castValue),
slog.Any("dataType", dataType),
)
}
case string:
switch dataType {
case schema.Text, schema.UserDefinedText, schema.Inet, schema.UUID, schema.JSON, schema.VariableNumeric,
schema.Numeric, schema.Money:
return QuoteLiteral(fmt.Sprint(value)), nil
return QuoteLiteral(castValue), nil
default:
slog.Error("string value with non-string column type",
slog.Any("value", value),
slog.String("value", castValue),
slog.Any("dataType", dataType),
)
}
Expand Down
10 changes: 2 additions & 8 deletions lib/postgres/scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ func TestShouldQuoteValue(t *testing.T) {
}
}

_, err := shouldQuoteValue(schema.InvalidDataType)
assert.ErrorContains(t, err, "invalid data type")
_, err := shouldQuoteValue(-1)
assert.ErrorContains(t, err, "unsupported data type: DataType(-1)")
}

func TestConvertToStringForQuery(t *testing.T) {
Expand Down Expand Up @@ -107,12 +107,6 @@ func TestConvertToStringForQuery(t *testing.T) {
dataType: schema.Text,
expected: "'foo'",
},
{
name: "text - invalid data type",
value: "foo",
dataType: schema.InvalidDataType,
expectedErr: "invalid data type",
},
{
name: "text - unsupported data type",
value: "foo",
Expand Down
57 changes: 28 additions & 29 deletions lib/postgres/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ import (
type DataType int

const (
InvalidDataType DataType = iota
VariableNumeric
VariableNumeric DataType = iota
Money
Numeric
Bit
Expand Down Expand Up @@ -78,8 +77,8 @@ func DescribeTable(db *sql.DB, _schema, table string) ([]Column, error) {
return nil, err
}

dataType, opts := ParseColumnDataType(colType, numericPrecision, numericScale, udtName)
if dataType == InvalidDataType {
dataType, opts, err := ParseColumnDataType(colType, numericPrecision, numericScale, udtName)
if err != nil {
return nil, fmt.Errorf("unable to identify type for column %s: %s", colName, colType)
}

Expand All @@ -92,70 +91,70 @@ func DescribeTable(db *sql.DB, _schema, table string) ([]Column, error) {
return cols, nil
}

func ParseColumnDataType(colKind string, precision, scale, udtName *string) (DataType, *Opts) {
func ParseColumnDataType(colKind string, precision, scale, udtName *string) (DataType, *Opts, error) {
colKind = strings.ToLower(colKind)
switch colKind {
case "point":
return Point, nil
return Point, nil, nil
case "real", "double precision":
return Float, nil
return Float, nil, nil
case "smallint":
return Int16, nil
return Int16, nil, nil
case "integer":
return Int32, nil
return Int32, nil, nil
case "bigint", "oid":
return Int64, nil
return Int64, nil, nil
case "array":
return Array, nil
return Array, nil, nil
case "bit":
return Bit, nil
return Bit, nil, nil
case "boolean":
return Boolean, nil
return Boolean, nil, nil
case "date":
return Date, nil
return Date, nil, nil
case "uuid":
return UUID, nil
return UUID, nil, nil
case "user-defined":
if udtName != nil && *udtName == "hstore" {
return HStore, nil
return HStore, nil, nil
} else if udtName != nil && *udtName == "geometry" {
return Geometry, nil
return Geometry, nil, nil
} else if udtName != nil && *udtName == "geography" {
return Geography, nil
return Geography, nil, nil
} else {
return UserDefinedText, nil
return UserDefinedText, nil, nil
}
case "interval":
return Interval, nil
return Interval, nil, nil
case "time with time zone", "time without time zone":
return Time, nil
return Time, nil, nil
case "money":
return Money, &Opts{
Scale: ptr.ToString("2"),
}
}, nil
case "character varying", "text", "character", "xml", "cidr", "macaddr", "macaddr8",
"int4range", "int8range", "numrange", "daterange", "tsrange", "tstzrange":
return Text, nil
return Text, nil, nil
case "inet":
return Inet, nil
return Inet, nil, nil
case "json", "jsonb":
return JSON, nil
return JSON, nil, nil
case "timestamp without time zone", "timestamp with time zone":
return Timestamp, nil
return Timestamp, nil, nil
default:
if strings.Contains(colKind, "numeric") {
if precision == nil && scale == nil {
return VariableNumeric, nil
return VariableNumeric, nil, nil
} else {
return Numeric, &Opts{
Scale: scale,
Precision: precision,
}
}, nil
}
}
}

return InvalidDataType, nil
return -1, nil, fmt.Errorf("unknown data type: %s", colKind)
}

// This is a fork of: https://wiki.postgresql.org/wiki/Retrieve_primary_key_columns
Expand Down
17 changes: 14 additions & 3 deletions lib/postgres/schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func TestParseColumnDataType(t *testing.T) {

expectedDataType DataType
expectedOpts *Opts
expectedErr string
}

var testCases = []_testCase{
Expand Down Expand Up @@ -124,12 +125,22 @@ func TestParseColumnDataType(t *testing.T) {
udtName: ptr.ToString("foo"),
expectedDataType: UserDefinedText,
},
{
name: "unsupported",
colKind: "foo",
expectedErr: "unknown data type: foo",
},
}

for _, testCase := range testCases {
dataType, opts := ParseColumnDataType(testCase.colKind, testCase.precision, testCase.scale, testCase.udtName)
assert.Equal(t, testCase.expectedDataType, dataType, testCase.name)
assert.Equal(t, testCase.expectedOpts, opts, testCase.name)
dataType, opts, err := ParseColumnDataType(testCase.colKind, testCase.precision, testCase.scale, testCase.udtName)
if testCase.expectedErr == "" {
assert.NoError(t, err, testCase.name)
assert.Equal(t, testCase.expectedDataType, dataType, testCase.name)
assert.Equal(t, testCase.expectedOpts, opts, testCase.name)
} else {
assert.ErrorContains(t, err, testCase.expectedErr, testCase.name)
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion sources/mysql/adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func newMySQLAdapter(db *sql.DB, table mysql.Table, scannerCfg scan.ScannerConfi
for i, col := range table.Columns {
converter, err := valueConverterForType(col.Type, col.Opts)
if err != nil {
return mysqlAdapter{}, err
return mysqlAdapter{}, fmt.Errorf("failed to build field for column %s: %w", col.Name, err)
}
fields[i] = converter.ToField(col.Name)
valueConverters[col.Name] = converter
Expand Down
16 changes: 11 additions & 5 deletions sources/postgres/adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const defaultErrorRetries = 10
type postgresAdapter struct {
db *sql.DB
table postgres.Table
fields []transferDbz.Field
scannerCfg scan.ScannerConfig
}

Expand All @@ -29,9 +30,18 @@ func NewPostgresAdapter(db *sql.DB, tableCfg config.PostgreSQLTable) (postgresAd
return postgresAdapter{}, fmt.Errorf("failed to load metadata for table %s.%s: %w", tableCfg.Schema, tableCfg.Name, err)
}

fields := make([]transferDbz.Field, len(table.Columns))
for i, col := range table.Columns {
fields[i], err = ColumnToField(col)
if err != nil {
return postgresAdapter{}, fmt.Errorf("failed to build field for column %s: %w", col.Name, err)
}
}

return postgresAdapter{
db: db,
table: *table,
fields: fields,
scannerCfg: tableCfg.ToScannerConfig(defaultErrorRetries),
}, nil
}
Expand All @@ -45,11 +55,7 @@ func (p postgresAdapter) TopicSuffix() string {
}

func (p postgresAdapter) Fields() []transferDbz.Field {
fields := make([]transferDbz.Field, len(p.table.Columns))
for i, col := range p.table.Columns {
fields[i] = ColumnToField(col)
}
return fields
return p.fields
}

func (p postgresAdapter) NewIterator() (debezium.RowsIterator, error) {
Expand Down
22 changes: 0 additions & 22 deletions sources/postgres/adapter/adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ package adapter
import (
"testing"

"github.com/artie-labs/transfer/lib/debezium"
"github.com/stretchr/testify/assert"

"github.com/artie-labs/reader/lib/postgres"
"github.com/artie-labs/reader/lib/postgres/schema"
)

func TestPostgresAdapter_TableName(t *testing.T) {
Expand Down Expand Up @@ -47,26 +45,6 @@ func TestPostgresAdapter_TopicSuffix(t *testing.T) {
}
}

func TestPostgresAdapter_Fields(t *testing.T) {
table := postgres.Table{
Name: "table1",
Schema: "schema1",
Columns: []schema.Column{
{Name: "col1", Type: schema.Text},
{Name: "col2", Type: schema.Boolean},
{Name: "col3", Type: schema.Array},
},
}
adapter := postgresAdapter{table: table}

expected := []debezium.Field{
{Type: "string", FieldName: "col1"},
{Type: "boolean", FieldName: "col2"},
{Type: "array", FieldName: "col3"},
}
assert.Equal(t, expected, adapter.Fields())
}

func TestPostgresAdapter_PartitionKey(t *testing.T) {
type _tc struct {
name string
Expand Down
Loading

0 comments on commit 3953d64

Please sign in to comment.