Skip to content

Commit

Permalink
Clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 committed Oct 3, 2024
1 parent 839311e commit 5271894
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 39 deletions.
30 changes: 22 additions & 8 deletions lib/debezium/converters/bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@ import (
"github.com/artie-labs/transfer/lib/typing"
)

type BitConverter struct{}
func NewBitConverter(charMaxLength int) BitConverter {
return BitConverter{
charMaxLength: charMaxLength,
}
}

type BitConverter struct {
charMaxLength int
}

func (BitConverter) ToField(name string) debezium.Field {
return debezium.Field{
Expand All @@ -15,17 +23,23 @@ func (BitConverter) ToField(name string) debezium.Field {
}
}

func (BitConverter) Convert(value any) (any, error) {
// This will be 0 (false) or 1 (true)
func (b BitConverter) Convert(value any) (any, error) {
stringValue, err := typing.AssertType[string](value)
if err != nil {
return nil, err
}

if stringValue == "0" {
return false, nil
} else if stringValue == "1" {
return true, nil
switch b.charMaxLength {
case 0:
return nil, fmt.Errorf("bit converter failed: invalid char max length")
case 1:
if stringValue == "0" {
return false, nil
} else if stringValue == "1" {
return true, nil
}
return nil, fmt.Errorf(`string value %q is not in ["0", "1"]`, value)
default:
return stringValue, nil
}
return nil, fmt.Errorf(`string value %q is not in ["0", "1"]`, value)
}
20 changes: 13 additions & 7 deletions lib/postgres/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,15 @@ const (
)

type Opts struct {
Scale uint16
Precision int
Scale uint16
Precision int
CharMaxLength int
}

type Column = column.Column[DataType, Opts]

const describeTableQuery = `
SELECT column_name, data_type, numeric_precision, numeric_scale, udt_name
SELECT column_name, data_type, numeric_precision, numeric_scale, udt_name, character_maximum_length
FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2`

Expand All @@ -73,12 +74,13 @@ func DescribeTable(db *sql.DB, _schema, table string) ([]Column, error) {
var numericPrecision *int
var numericScale *uint16
var udtName *string
err = rows.Scan(&colName, &colType, &numericPrecision, &numericScale, &udtName)
var charMaxLength *int
err = rows.Scan(&colName, &colType, &numericPrecision, &numericScale, &udtName, &charMaxLength)
if err != nil {
return nil, err
}

dataType, opts, err := ParseColumnDataType(colType, numericPrecision, numericScale, udtName)
dataType, opts, err := parseColumnDataType(colType, numericPrecision, numericScale, charMaxLength, udtName)
if err != nil {
return nil, fmt.Errorf("unable to identify type %q for column %q", colType, colName)
}
Expand All @@ -92,11 +94,15 @@ func DescribeTable(db *sql.DB, _schema, table string) ([]Column, error) {
return cols, nil
}

func ParseColumnDataType(colKind string, precision *int, scale *uint16, udtName *string) (DataType, *Opts, error) {
func parseColumnDataType(colKind string, precision *int, scale *uint16, charMaxLength *int, udtName *string) (DataType, *Opts, error) {
colKind = strings.ToLower(colKind)
switch colKind {
case "bit":
return Bit, nil, nil
if charMaxLength == nil {
return -1, nil, fmt.Errorf("invalid bit column: missing character maximum length")
}

return Bit, &Opts{CharMaxLength: *charMaxLength}, nil
case "boolean":
return Boolean, nil, nil
case "smallint":
Expand Down
62 changes: 39 additions & 23 deletions lib/postgres/schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
func TestParseColumnDataType(t *testing.T) {
{
// Array
dataType, opts, err := ParseColumnDataType("ARRAY", nil, nil, nil)
dataType, opts, err := parseColumnDataType("ARRAY", nil, nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, Array, dataType)
assert.Nil(t, opts)
Expand All @@ -18,127 +18,143 @@ func TestParseColumnDataType(t *testing.T) {
// String
{
// Character varying
dataType, opts, err := ParseColumnDataType("character varying", nil, nil, nil)
dataType, opts, err := parseColumnDataType("character varying", nil, nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, Text, dataType)
assert.Nil(t, opts)
}
{
// Character
dataType, opts, err := ParseColumnDataType("character", nil, nil, nil)
dataType, opts, err := parseColumnDataType("character", nil, nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, Text, dataType)
assert.Nil(t, opts)
}
}
{
// bit
dataType, opts, err := ParseColumnDataType("bit", nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, Bit, dataType)
assert.Nil(t, opts)
{
// bit (char max length not specified)
dataType, opts, err := parseColumnDataType("bit", nil, nil, nil, nil)
assert.ErrorContains(t, err, "invalid bit column: missing character maximum length")
assert.Equal(t, -1, int(dataType))
assert.Nil(t, opts)
}
{
// bit (1)
dataType, opts, err := parseColumnDataType("bit", nil, nil, typing.ToPtr(1), nil)
assert.NoError(t, err)
assert.Equal(t, Bit, dataType)
assert.Equal(t, 1, opts.CharMaxLength)
}
{
// bit (5)
dataType, opts, err := parseColumnDataType("bit", nil, nil, typing.ToPtr(5), nil)
assert.NoError(t, err)
assert.Equal(t, Bit, dataType)
assert.Equal(t, 5, opts.CharMaxLength)
}
}
{
// boolean
dataType, opts, err := ParseColumnDataType("boolean", nil, nil, nil)
dataType, opts, err := parseColumnDataType("boolean", nil, nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, Boolean, dataType)
assert.Nil(t, opts)
}
{
// interval
dataType, opts, err := ParseColumnDataType("interval", nil, nil, nil)
dataType, opts, err := parseColumnDataType("interval", nil, nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, Interval, dataType)
assert.Nil(t, opts)
}
{
// time with time zone
dataType, opts, err := ParseColumnDataType("time with time zone", nil, nil, nil)
dataType, opts, err := parseColumnDataType("time with time zone", nil, nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, TimeWithTimeZone, dataType)
assert.Nil(t, opts)
}
{
// time without time zone
dataType, opts, err := ParseColumnDataType("time without time zone", nil, nil, nil)
dataType, opts, err := parseColumnDataType("time without time zone", nil, nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, Time, dataType)
assert.Nil(t, opts)
}
{
// date
dataType, opts, err := ParseColumnDataType("date", nil, nil, nil)
dataType, opts, err := parseColumnDataType("date", nil, nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, Date, dataType)
assert.Nil(t, opts)
}
{
// inet
dataType, opts, err := ParseColumnDataType("inet", nil, nil, nil)
dataType, opts, err := parseColumnDataType("inet", nil, nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, Text, dataType)
assert.Nil(t, opts)
}
{
// numeric
dataType, opts, err := ParseColumnDataType("numeric", nil, nil, nil)
dataType, opts, err := parseColumnDataType("numeric", nil, nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, VariableNumeric, dataType)
assert.Nil(t, opts)
}
{
// numeric - with scale + precision
dataType, opts, err := ParseColumnDataType("numeric", typing.ToPtr(3), typing.ToPtr(uint16(2)), nil)
dataType, opts, err := parseColumnDataType("numeric", typing.ToPtr(3), typing.ToPtr(uint16(2)), nil, nil)
assert.NoError(t, err)
assert.Equal(t, Numeric, dataType)
assert.Equal(t, &Opts{Scale: 2, Precision: 3}, opts)
}
{
// Variable numeric
dataType, opts, err := ParseColumnDataType("variable numeric", nil, nil, nil)
dataType, opts, err := parseColumnDataType("variable numeric", nil, nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, VariableNumeric, dataType)
assert.Nil(t, opts)
}
{
// Money
dataType, opts, err := ParseColumnDataType("money", nil, nil, nil)
dataType, opts, err := parseColumnDataType("money", nil, nil, nil, nil)
assert.NoError(t, err)
assert.Equal(t, Money, dataType)
assert.Nil(t, opts)
}
{
// hstore
dataType, opts, err := ParseColumnDataType("user-defined", nil, nil, typing.ToPtr("hstore"))
dataType, opts, err := parseColumnDataType("user-defined", nil, nil, nil, typing.ToPtr("hstore"))
assert.NoError(t, err)
assert.Equal(t, HStore, dataType)
assert.Nil(t, opts)
}
{
// geometry
dataType, opts, err := ParseColumnDataType("user-defined", nil, nil, typing.ToPtr("geometry"))
dataType, opts, err := parseColumnDataType("user-defined", nil, nil, nil, typing.ToPtr("geometry"))
assert.NoError(t, err)
assert.Equal(t, Geometry, dataType)
assert.Nil(t, opts)
}
{
// geography
dataType, opts, err := ParseColumnDataType("user-defined", nil, nil, typing.ToPtr("geography"))
dataType, opts, err := parseColumnDataType("user-defined", nil, nil, nil, typing.ToPtr("geography"))
assert.NoError(t, err)
assert.Equal(t, Geography, dataType)
assert.Nil(t, opts)
}
{
// user-defined text
dataType, opts, err := ParseColumnDataType("user-defined", nil, nil, typing.ToPtr("foo"))
dataType, opts, err := parseColumnDataType("user-defined", nil, nil, nil, typing.ToPtr("foo"))
assert.NoError(t, err)
assert.Equal(t, UserDefinedText, dataType)
assert.Nil(t, opts)
}
{
// unsupported
dataType, opts, err := ParseColumnDataType("foo", nil, nil, nil)
dataType, opts, err := parseColumnDataType("foo", nil, nil, nil, nil)
assert.ErrorContains(t, err, `unknown data type: "foo"`)
assert.Equal(t, -1, int(dataType))
assert.Nil(t, opts)
Expand Down
6 changes: 5 additions & 1 deletion sources/postgres/adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ func (p PostgresAdapter) PartitionKeys() []string {
func valueConverterForType(dataType schema.DataType, opts *schema.Opts) (converters.ValueConverter, error) {
switch dataType {
case schema.Bit:
return converters.BitConverter{}, nil
if opts == nil {
return nil, fmt.Errorf("missing options for bit data type")
}

return converters.NewBitConverter(opts.CharMaxLength), nil
case schema.Boolean:
return converters.BooleanPassthrough{}, nil
case schema.Int16:
Expand Down

0 comments on commit 5271894

Please sign in to comment.