From 45f8edadabfd2d64a4db8e5a2eb02256bbefb1d6 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Wed, 14 Feb 2024 16:19:46 -0800 Subject: [PATCH] Add a stateless postgres describe table function (#106) --- lib/postgres/columns.go | 31 +++++-------------- lib/postgres/parse_test.go | 2 +- lib/postgres/schema/schema.go | 48 +++++++++++++++++++++++++----- lib/postgres/schema/schema_test.go | 23 ++------------ 4 files changed, 51 insertions(+), 53 deletions(-) diff --git a/lib/postgres/columns.go b/lib/postgres/columns.go index d9adf1dc..88a753e3 100644 --- a/lib/postgres/columns.go +++ b/lib/postgres/columns.go @@ -10,40 +10,23 @@ import ( ) func (t *Table) RetrieveColumns(db *sql.DB) error { - describeQuery, describeArgs := schema.DescribeTableQuery(schema.DescribeTableArgs{ - Name: t.Name, - Schema: t.Schema, - }) - - rows, err := db.Query(describeQuery, describeArgs...) + cols, err := schema.DescribeTable(db, t.Schema, t.Name) if err != nil { - return fmt.Errorf("failed to query: %s, args: %v, err: %w", describeQuery, describeArgs, err) + return fmt.Errorf("failed to describe table %s.%s: %w", t.Schema, t.Name, err) } - for rows.Next() { - var colName string - var colKind string - var numericPrecision *string - var numericScale *string - var udtName *string - err = rows.Scan(&colName, &colKind, &numericPrecision, &numericScale, &udtName) - if err != nil { - return err - } - - dataType, opts := schema.ColKindToDataType(colKind, numericPrecision, numericScale, udtName) - if dataType == schema.InvalidDataType { + for _, col := range cols { + if col.Type == schema.InvalidDataType { slog.Warn("Column type did not get mapped in our message schema, so it will not be automatically created by transfer", - slog.String("colName", colName), - slog.String("colKind", colKind), + slog.String("colName", col.Name), ) } else { - t.Fields.AddField(colName, dataType, opts) + t.Fields.AddField(col.Name, col.Type, col.Opts) } } query := fmt.Sprintf("SELECT * FROM %s LIMIT 1", pgx.Identifier{t.Schema, t.Name}.Sanitize()) - rows, err = db.Query(query) + rows, err := db.Query(query) if err != nil { return fmt.Errorf("failed to query, query: %v, err: %w", query, err) } diff --git a/lib/postgres/parse_test.go b/lib/postgres/parse_test.go index e97c42f5..8b62aeb7 100644 --- a/lib/postgres/parse_test.go +++ b/lib/postgres/parse_test.go @@ -97,7 +97,7 @@ func TestParse(t *testing.T) { for _, tc := range tcs { fields := pgDebezium.NewFields() - dataType, opts := schema.ColKindToDataType(tc.colKind, nil, nil, tc.udtName) + dataType, opts := schema.ParseColumnDataType(tc.colKind, nil, nil, tc.udtName) fields.AddField(tc.colName, dataType, opts) value, err := ParseValue(fields, ParseValueArgs{ diff --git a/lib/postgres/schema/schema.go b/lib/postgres/schema/schema.go index 7a725b1b..036fc040 100644 --- a/lib/postgres/schema/schema.go +++ b/lib/postgres/schema/schema.go @@ -1,6 +1,9 @@ package schema import ( + "database/sql" + "fmt" + "log/slog" "strings" "github.com/artie-labs/transfer/lib/ptr" @@ -57,21 +60,52 @@ type Opts struct { Precision *string } -type DescribeTableArgs struct { - Name string - Schema string +type Column struct { + Name string + Type DataType + Opts *Opts } const describeTableQuery = ` SELECT column_name, data_type, numeric_precision, numeric_scale, udt_name FROM information_schema.columns -WHERE table_name = $1 AND table_schema = $2` +WHERE table_schema = $1 AND table_name = $2` -func DescribeTableQuery(args DescribeTableArgs) (string, []any) { - return strings.TrimSpace(describeTableQuery), []any{args.Name, args.Schema} +func DescribeTable(db *sql.DB, _schema, table string) ([]Column, error) { + query := strings.TrimSpace(describeTableQuery) + rows, err := db.Query(query, _schema, table) + if err != nil { + return nil, fmt.Errorf("failed to run query: %s: %w", query, err) + } + defer rows.Close() + + var cols []Column + for rows.Next() { + var colName string + var colType string + var numericPrecision *string + var numericScale *string + var udtName *string + err = rows.Scan(&colName, &colType, &numericPrecision, &numericScale, &udtName) + if err != nil { + return nil, err + } + + dataType, opts := ParseColumnDataType(colType, numericPrecision, numericScale, udtName) + if dataType == InvalidDataType { + slog.Warn("Unable to identify column type", slog.String("colName", colName), slog.String("colType", colType)) + } + + cols = append(cols, Column{ + Name: colName, + Type: dataType, + Opts: opts, + }) + } + return cols, nil } -func ColKindToDataType(colKind string, precision, scale, udtName *string) (DataType, *Opts) { +func ParseColumnDataType(colKind string, precision, scale, udtName *string) (DataType, *Opts) { colKind = strings.ToLower(colKind) switch colKind { case "point": diff --git a/lib/postgres/schema/schema_test.go b/lib/postgres/schema/schema_test.go index 82346a0d..f90be752 100644 --- a/lib/postgres/schema/schema_test.go +++ b/lib/postgres/schema/schema_test.go @@ -7,26 +7,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestDescribeTableQuery(t *testing.T) { - { - query, args := DescribeTableQuery(DescribeTableArgs{ - Name: "name", - Schema: "schema", - }) - assert.Equal(t, "SELECT column_name, data_type, numeric_precision, numeric_scale, udt_name\nFROM information_schema.columns\nWHERE table_name = $1 AND table_schema = $2", query) - assert.Equal(t, []any{"name", "schema"}, args) - } - // test quotes in table name or schema are left alone - { - _, args := DescribeTableQuery(DescribeTableArgs{ - Name: `na"me`, - Schema: `s'ch"em'a`, - }) - assert.Equal(t, []any{`na"me`, `s'ch"em'a`}, args) - } -} - -func TestColKindToDataType(t *testing.T) { +func TestParseColumnDataType(t *testing.T) { type _testCase struct { name string colKind string @@ -140,7 +121,7 @@ func TestColKindToDataType(t *testing.T) { } for _, testCase := range testCases { - dataType, opts := ColKindToDataType(testCase.colKind, testCase.precision, testCase.scale, testCase.udtName) + dataType, opts := ParseColumnDataType(testCase.colKind, testCase.precision, testCase.scale, testCase.udtName) assert.Equal(t, testCase.expectedDataType, dataType, testCase.name) assert.Equal(t, testCase.expectedOpts, opts, testCase.name) }