Skip to content

Commit

Permalink
Add a stateless postgres describe table function (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored Feb 15, 2024
1 parent cb3c59d commit 45f8eda
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 53 deletions.
31 changes: 7 additions & 24 deletions lib/postgres/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,23 @@ import (
)

func (t *Table) RetrieveColumns(db *sql.DB) error {
describeQuery, describeArgs := schema.DescribeTableQuery(schema.DescribeTableArgs{
Name: t.Name,
Schema: t.Schema,
})

rows, err := db.Query(describeQuery, describeArgs...)
cols, err := schema.DescribeTable(db, t.Schema, t.Name)
if err != nil {
return fmt.Errorf("failed to query: %s, args: %v, err: %w", describeQuery, describeArgs, err)
return fmt.Errorf("failed to describe table %s.%s: %w", t.Schema, t.Name, err)
}

for rows.Next() {
var colName string
var colKind string
var numericPrecision *string
var numericScale *string
var udtName *string
err = rows.Scan(&colName, &colKind, &numericPrecision, &numericScale, &udtName)
if err != nil {
return err
}

dataType, opts := schema.ColKindToDataType(colKind, numericPrecision, numericScale, udtName)
if dataType == schema.InvalidDataType {
for _, col := range cols {
if col.Type == schema.InvalidDataType {
slog.Warn("Column type did not get mapped in our message schema, so it will not be automatically created by transfer",
slog.String("colName", colName),
slog.String("colKind", colKind),
slog.String("colName", col.Name),
)
} else {
t.Fields.AddField(colName, dataType, opts)
t.Fields.AddField(col.Name, col.Type, col.Opts)
}
}

query := fmt.Sprintf("SELECT * FROM %s LIMIT 1", pgx.Identifier{t.Schema, t.Name}.Sanitize())
rows, err = db.Query(query)
rows, err := db.Query(query)
if err != nil {
return fmt.Errorf("failed to query, query: %v, err: %w", query, err)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/postgres/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestParse(t *testing.T) {

for _, tc := range tcs {
fields := pgDebezium.NewFields()
dataType, opts := schema.ColKindToDataType(tc.colKind, nil, nil, tc.udtName)
dataType, opts := schema.ParseColumnDataType(tc.colKind, nil, nil, tc.udtName)
fields.AddField(tc.colName, dataType, opts)

value, err := ParseValue(fields, ParseValueArgs{
Expand Down
48 changes: 41 additions & 7 deletions lib/postgres/schema/schema.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package schema

import (
"database/sql"
"fmt"
"log/slog"
"strings"

"github.com/artie-labs/transfer/lib/ptr"
Expand Down Expand Up @@ -57,21 +60,52 @@ type Opts struct {
Precision *string
}

type DescribeTableArgs struct {
Name string
Schema string
// Column describes a single table column as returned by DescribeTable.
type Column struct {
// Name is the column_name value read from information_schema.columns.
Name string
// Type is the data type produced by ParseColumnDataType for this column;
// DescribeTable leaves it as InvalidDataType when the type is not recognized.
Type DataType
// Opts is the second value returned by ParseColumnDataType for this column.
Opts *Opts
}

const describeTableQuery = `
SELECT column_name, data_type, numeric_precision, numeric_scale, udt_name
FROM information_schema.columns
WHERE table_name = $1 AND table_schema = $2`
WHERE table_schema = $1 AND table_name = $2`

func DescribeTableQuery(args DescribeTableArgs) (string, []any) {
return strings.TrimSpace(describeTableQuery), []any{args.Name, args.Schema}
// DescribeTable looks up the columns of the given table in
// information_schema.columns and returns one Column per row, with the
// reported type mapped through ParseColumnDataType.
//
// Columns whose type cannot be mapped are still included in the result with
// Type set to InvalidDataType; a warning is logged for each so that callers
// can decide whether to skip them.
//
// An error is returned if the query fails, if a row cannot be scanned, or if
// iteration over the result set stops early.
func DescribeTable(db *sql.DB, _schema, table string) ([]Column, error) {
	query := strings.TrimSpace(describeTableQuery)
	rows, err := db.Query(query, _schema, table)
	if err != nil {
		return nil, fmt.Errorf("failed to run query: %s: %w", query, err)
	}
	defer rows.Close()

	var cols []Column
	for rows.Next() {
		var colName string
		var colType string
		// These are pointers because the values may be NULL in the result set.
		var numericPrecision *string
		var numericScale *string
		var udtName *string
		if err = rows.Scan(&colName, &colType, &numericPrecision, &numericScale, &udtName); err != nil {
			return nil, fmt.Errorf("failed to scan column description for %s.%s: %w", _schema, table, err)
		}

		dataType, opts := ParseColumnDataType(colType, numericPrecision, numericScale, udtName)
		if dataType == InvalidDataType {
			slog.Warn("Unable to identify column type", slog.String("colName", colName), slog.String("colType", colType))
		}

		cols = append(cols, Column{
			Name: colName,
			Type: dataType,
			Opts: opts,
		})
	}

	// rows.Next returns false both when the result set is exhausted and when
	// iteration fails, so the error must be checked explicitly; otherwise a
	// partial column list would be returned as if it were complete.
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate over columns of %s.%s: %w", _schema, table, err)
	}
	return cols, nil
}

func ColKindToDataType(colKind string, precision, scale, udtName *string) (DataType, *Opts) {
func ParseColumnDataType(colKind string, precision, scale, udtName *string) (DataType, *Opts) {
colKind = strings.ToLower(colKind)
switch colKind {
case "point":
Expand Down
23 changes: 2 additions & 21 deletions lib/postgres/schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,7 @@ import (
"github.com/stretchr/testify/assert"
)

func TestDescribeTableQuery(t *testing.T) {
{
query, args := DescribeTableQuery(DescribeTableArgs{
Name: "name",
Schema: "schema",
})
assert.Equal(t, "SELECT column_name, data_type, numeric_precision, numeric_scale, udt_name\nFROM information_schema.columns\nWHERE table_name = $1 AND table_schema = $2", query)
assert.Equal(t, []any{"name", "schema"}, args)
}
// test quotes in table name or schema are left alone
{
_, args := DescribeTableQuery(DescribeTableArgs{
Name: `na"me`,
Schema: `s'ch"em'a`,
})
assert.Equal(t, []any{`na"me`, `s'ch"em'a`}, args)
}
}

func TestColKindToDataType(t *testing.T) {
func TestParseColumnDataType(t *testing.T) {
type _testCase struct {
name string
colKind string
Expand Down Expand Up @@ -140,7 +121,7 @@ func TestColKindToDataType(t *testing.T) {
}

for _, testCase := range testCases {
dataType, opts := ColKindToDataType(testCase.colKind, testCase.precision, testCase.scale, testCase.udtName)
dataType, opts := ParseColumnDataType(testCase.colKind, testCase.precision, testCase.scale, testCase.udtName)
assert.Equal(t, testCase.expectedDataType, dataType, testCase.name)
assert.Equal(t, testCase.expectedOpts, opts, testCase.name)
}
Expand Down

0 comments on commit 45f8eda

Please sign in to comment.