Skip to content

Commit

Permalink
Merge pull request #851 from dolthub/fulghum/doltgres-843
Browse files Browse the repository at this point in the history
Bug fixes for double-quoted relation names
  • Loading branch information
fulghum authored Oct 14, 2024
2 parents e17a304 + 49f98a0 commit 4b25d07
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 26 deletions.
26 changes: 4 additions & 22 deletions server/functions/nextval.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
package functions

import (
"fmt"
"strings"

"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core"
Expand All @@ -38,26 +35,11 @@ var nextval_text = framework.Function1{
IsNonDeterministic: true,
Strict: true,
Callable: func(ctx *sql.Context, _ [2]pgtypes.DoltgresType, val any) (any, error) {
var schema, sequence string
var err error
pathElems := strings.Split(val.(string), ".")
switch len(pathElems) {
case 1:
schema, err = core.GetCurrentSchema(ctx)
if err != nil {
return nil, err
}
sequence = pathElems[0]
case 2:
schema = pathElems[0]
sequence = pathElems[1]
case 3:
// database is not used atm
schema = pathElems[1]
sequence = pathElems[2]
default:
return nil, fmt.Errorf(`cannot find sequence "%s" to get its nextval`, val.(string))
schema, sequence, err := parseRelationName(ctx, val.(string))
if err != nil {
return nil, err
}

collection, err := core.GetCollectionFromContext(ctx)
if err != nil {
return nil, err
Expand Down
36 changes: 33 additions & 3 deletions server/functions/setval.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
package functions

import (
"fmt"
"strings"

"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core"
Expand Down Expand Up @@ -54,11 +57,38 @@ var setval_text_int64_boolean = framework.Function3{
return nil, err
}
// TODO: this should take a regclass as the parameter to determine the schema
schema, err := core.GetCurrentSchema(ctx)
schema, relation, err := parseRelationName(ctx, val1.(string))
if err != nil {
return nil, err
}

return val2.(int64), collection.SetVal(schema, val1.(string), val2.(int64), val3.(bool))
return val2.(int64), collection.SetVal(schema, relation, val2.(int64), val3.(bool))
},
}

// parseRelationName parses the schema and relation name from a relation name string, including trimming any
// identifier quotes used in the name. For example, passing in 'public."MyTable"' would return 'public' and 'MyTable'.
func parseRelationName(ctx *sql.Context, name string) (schema string, relation string, err error) {
pathElems := strings.Split(name, ".")
switch len(pathElems) {
case 1:
schema, err = core.GetCurrentSchema(ctx)
if err != nil {
return "", "", err
}
relation = pathElems[0]
case 2:
schema = pathElems[0]
relation = pathElems[1]
case 3:
// database is not used atm
schema = pathElems[1]
relation = pathElems[2]
default:
return "", "", fmt.Errorf(`cannot parse relation: %s`, name)
}

// Trim any quotes from the relation name
relation = strings.Trim(relation, `"`)

return schema, relation, nil
}
4 changes: 3 additions & 1 deletion server/node/copy_from.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bufio"
"fmt"
"os"
"strings"

"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/go-mysql-server/sql"
Expand Down Expand Up @@ -109,7 +110,8 @@ func (cf *CopyFrom) Validate(ctx *sql.Context) error {

for i, col := range table.Schema() {
name := cf.Columns[i]
if name.String() != col.Name {
nameString := strings.Trim(name.String(), `"`)
if nameString != col.Name {
return fmt.Errorf("invalid column name list for table %s: %v", table.Name(), cf.Columns)
}
}
Expand Down
21 changes: 21 additions & 0 deletions testing/bats/dataloading.bats
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,27 @@ teardown() {
[[ "$output" =~ "3 | 03 | 97302 | Guyane" ]] || false
}

# Tests that we can load tabular data dump files that contain quoted column names
@test 'dataloading: tabular import, with quoted column names' {
# Import the data dump and assert the expected output
run query_server -f $BATS_TEST_DIRNAME/dataloading/tab-load-with-quoted-column-names.sql
[ "$status" -eq 0 ]
[[ "$output" =~ "COPY 3" ]] || false
[[ ! "$output" =~ "ERROR" ]] || false

# Check the row count of imported tables
run query_server -c "SELECT count(*) from Regions;"
[ "$status" -eq 0 ]
[[ "$output" =~ "3" ]] || false

# Check the inserted rows
run query_server -c "SELECT * from Regions;"
[ "$status" -eq 0 ]
[[ "$output" =~ "1 | 01 | 97105 | Guadeloupe" ]] || false
[[ "$output" =~ "2 | 02 | 97209 | Martinique" ]] || false
[[ "$output" =~ "3 | 03 | 97302 | Guyane" ]] || false
}

# Tests that we can load tabular data dump files that do not explicitly manage the session's transaction.
@test 'dataloading: tabular import, no explicit tx management' {
# Import the data dump and assert the expected output
Expand Down
16 changes: 16 additions & 0 deletions testing/bats/dataloading/tab-load-with-quoted-column-names.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
BEGIN;

CREATE TABLE Regions (
"Id" SERIAL UNIQUE NOT NULL,
"Code" VARCHAR(4) UNIQUE NOT NULL,
"Capital" VARCHAR(10) NOT NULL,
"Name" VARCHAR(255) UNIQUE NOT NULL
);

COPY regions ("Id", "Code", "Capital", "Name") FROM stdin;
1 01 97105 Guadeloupe
2 02 97209 Martinique
3 03 97302 Guyane
\.

COMMIT;
45 changes: 45 additions & 0 deletions testing/go/sequences_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,38 @@ func TestSequences(t *testing.T) {
},
},
},
{
Name: "nextval() with double-quoted identifiers",
SetUpScript: []string{
"CREATE SEQUENCE test_sequence;",
},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT nextval('test_sequence');",
Expected: []sql.Row{
{1},
},
},
{
Query: "SELECT nextval('public.test_sequence');",
Expected: []sql.Row{
{2},
},
},
{
Query: `SELECT nextval('"test_sequence"');`,
Expected: []sql.Row{
{3},
},
},
{
Query: `SELECT nextval('public."test_sequence"');`,
Expected: []sql.Row{
{4},
},
},
},
},
{
Name: "nextval() in filter",
Skip: true, // GMS seems to call nextval once and cache the value, which is incorrect here
Expand Down Expand Up @@ -539,6 +571,19 @@ func TestSequences(t *testing.T) {
Query: "SELECT nextval('test4');",
Expected: []sql.Row{{7}},
},
{
Query: "CREATE SEQUENCE test5;",
Expected: []sql.Row{},
},
{
// test with a double-quoted identifier
Query: `SELECT setval('public."test5"', 100, true);`,
Expected: []sql.Row{{100}},
},
{
Query: "SELECT nextval('test5');",
Expected: []sql.Row{{101}},
},
},
},
{
Expand Down

0 comments on commit 4b25d07

Please sign in to comment.