Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed INT8 handling, support for testing statements, proper connection error handling #55

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions postgres/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,20 @@ TopLevelLoop:
return allPossibleMessages, nil
}

// DiscardToSync discards all messages in the buffer until a Sync has been reached. If a Sync was never sent, then this
// may cause the connection to lock until the client send a Sync, as their request structure was malformed.
func DiscardToSync(conn net.Conn) error {
for {
message, err := Receive(conn)
if err != nil {
return err
}
if message.DefaultMessage().Name == "Sync" {
return nil
}
}
}

// Send sends the given message over the connection.
func Send(conn net.Conn, message Message) error {
encodedMessage, err := message.Encode()
Expand Down
13 changes: 11 additions & 2 deletions postgres/messages/row_description.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,22 @@ func (m RowDescription) DefaultMessage() *connection.MessageFormat {
}

// VitessFieldToDataTypeObjectID returns a type, as defined by Vitess, into a type as defined by Postgres.
// OIDs can be obtained with the following query: `SELECT oid, typname FROM pg_type ORDER BY 1;`
func VitessFieldToDataTypeObjectID(field *query.Field) (int32, error) {
switch field.Type {
case query.Type_INT8:
return 17, nil
// Postgres doesn't make use of a small integer type for integer returns, which presents a bit of a conundrum.
// GMS defines boolean operations as the smallest integer type, while Postgres has an explicit bool type.
// We can't always assume that `INT8` means bool, since it could just be a small integer. As a result, we'll
// always return this as though it's an `INT32`, which also means that we can't support bools right now.
// OIDs 16 (bool) and 18 (char, ASCII only?) are the only single-byte types as far as I'm aware.
return 23, nil
case query.Type_INT16:
return 21, nil
// The technically correct OID is 21 (2-byte integer), however it seems like some clients don't actually expect
// this, so I'm not sure when it's actually used by Postgres. Because of this, we'll just pretend it's an `INT32`.
return 23, nil
case query.Type_INT24:
// Postgres doesn't have a 3-byte integer type, so just pretend it's `INT32`.
return 23, nil
case query.Type_INT32:
return 23, nil
Expand Down
27 changes: 26 additions & 1 deletion server/implicit_commit.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@

package server

import "strings"
import (
"fmt"
"strings"

"github.com/dolthub/doltgresql/postgres/parser/parser"
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
)

// implicitCommitStatements are a collection of statements that perform an implicit COMMIT before executing. Such
// statements cannot have their effects reversed by rolling back a transaction or rolling back to a savepoint.
Expand All @@ -39,3 +45,22 @@ func ImplicitlyCommits(statement string) bool {
}
return false
}

// HandleImplicitCommitStatement returns a statement that can reverse the given statement, such that it appears to have
// never executed. This only applies to statements that implicitly commit, as determined by ImplicitlyCommits.
func HandleImplicitCommitStatement(statement string) (reverseStatement string, handled bool) {
s, err := parser.Parse(statement)
if err != nil || len(s) != 1 {
return "", false
}
switch node := s[0].AST.(type) {
case *tree.CreateDatabase:
return fmt.Sprintf("DROP DATABASE %s", string(node.Name)), true
case *tree.CreateTable:
return fmt.Sprintf("DROP TABLE %s", node.Table.String()), true
case *tree.CreateView:
return fmt.Sprintf("DROP VIEW %s", node.Name.String()), true
default:
return "", false
}
}
34 changes: 27 additions & 7 deletions server/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,15 @@ func (l *Listener) HandleConnection(conn net.Conn) {
}

stop, endOfMessages, err := l.handleMessage(message, conn, mysqlConn, preparedStatements, portals)
if err != nil || endOfMessages {
// TODO: do we need to clear out the connection here? If so, we need to read from it without blocking
if err != nil {
if !endOfMessages {
if syncErr := connection.DiscardToSync(conn); syncErr != nil {
fmt.Println(syncErr.Error())
}
}
l.endOfMessages(conn, err)
} else if endOfMessages {
l.endOfMessages(conn, nil)
}

if stop {
Expand Down Expand Up @@ -262,7 +268,7 @@ func (l *Listener) handleMessage(

query, err := l.convertQuery(message.String)
if err != nil {
return false, false, err
return false, true, err
}

// The Deallocate message must not get passed to the engine, since we handle allocation / deallocation of
Expand All @@ -271,7 +277,7 @@ func (l *Listener) handleMessage(
case *sqlparser.Deallocate:
_, ok := preparedStatements[stmt.Name]
if !ok {
return false, false, fmt.Errorf("prepared statement %s does not exist", stmt.Name)
return false, true, fmt.Errorf("prepared statement %s does not exist", stmt.Name)
}
delete(preparedStatements, stmt.Name)

Expand Down Expand Up @@ -309,7 +315,7 @@ func (l *Listener) handleMessage(
portals[message.DestinationPortal] = preparedStatements[message.SourcePreparedStatement]
return false, false, connection.Send(conn, messages.BindComplete{})
default:
return false, false, fmt.Errorf(`Unhandled message "%s"`, message.DefaultMessage().Name)
return false, true, fmt.Errorf(`Unhandled message "%s"`, message.DefaultMessage().Name)
}
}

Expand Down Expand Up @@ -410,7 +416,7 @@ func (l *Listener) execute(conn net.Conn, mysqlConn *mysql.Conn, query Converted
}

// describe handles the description of the given query. This will post the ParameterDescription and RowDescription messages.
func (l *Listener) describe(conn net.Conn, mysqlConn *mysql.Conn, message messages.Describe, statement ConvertedQuery) error {
func (l *Listener) describe(conn net.Conn, mysqlConn *mysql.Conn, message messages.Describe, statement ConvertedQuery) (err error) {
//TODO: fully support prepared statements
if err := connection.Send(conn, messages.ParameterDescription{
ObjectIDs: nil,
Expand All @@ -420,7 +426,21 @@ func (l *Listener) describe(conn net.Conn, mysqlConn *mysql.Conn, message messag

//TODO: properly handle these statements
if ImplicitlyCommits(statement.String) {
return fmt.Errorf("We do not yet support the Describe message for the given statement")
if reverseStatement, ok := HandleImplicitCommitStatement(statement.String); ok {
// We have a reverse statement that can function as a workaround for the lack of proper rollback support.
// This does mean that we'll still create an implicit commit, but we can fix that whenever we add proper
// transaction support.
defer func() {
// If there's an error, then we don't want to execute the reverse statement
if err == nil {
_ = l.cfg.Handler.ComQuery(mysqlConn, reverseStatement, func(_ *sqltypes.Result, _ bool) error {
return nil
})
}
}()
} else {
return fmt.Errorf("We do not yet support the Describe message for the given statement")
}
}
// We'll start a transaction, so that we can later rollback any changes that were made.
//TODO: handle the case where we are already in a transaction (SAVEPOINT will sometimes fail it seems?)
Expand Down
22 changes: 16 additions & 6 deletions testing/go/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package _go
import (
"context"
"encoding/json"
"errors"
"fmt"
"math"
"net"
Expand Down Expand Up @@ -106,8 +107,9 @@ func RunScript(t *testing.T, script ScriptTest) {
} else {
rows, err := conn.Query(ctx, assertion.Query)
require.NoError(t, err)
defer rows.Close()
assert.Equal(t, NormalizeRows(assertion.Expected), ReadRows(t, rows))
readRows, err := ReadRows(rows)
require.NoError(t, err)
assert.Equal(t, NormalizeRows(assertion.Expected), readRows)
}
})
}
Expand Down Expand Up @@ -179,19 +181,27 @@ func CreateServer(t *testing.T, database string) (context.Context, *pgx.Conn, *s
return ctx, conn, serverClosed
}

// ReadRows reads all of the given rows into a slice. This also normalizes all of the rows. Does not call Close() on the rows.
func ReadRows(t *testing.T, rows pgx.Rows) []sql.Row {
// ReadRows reads all of the given rows into a slice, then closes the rows. This also normalizes all of the rows.
func ReadRows(rows pgx.Rows) (readRows []sql.Row, err error) {
defer func() {
err = errors.Join(err, rows.Err())
}()
var slice []sql.Row
for rows.Next() {
row, err := rows.Values()
require.NoError(t, err)
if err != nil {
return nil, err
}
slice = append(slice, row)
}
return NormalizeRows(slice)
return NormalizeRows(slice), nil
}

// NormalizeRow normalizes each value's type, as the tests only want to compare values. Returns a new row.
func NormalizeRow(row sql.Row) sql.Row {
if len(row) == 0 {
return nil
}
newRow := make(sql.Row, len(row))
for i := range row {
switch val := row[i].(type) {
Expand Down
40 changes: 36 additions & 4 deletions testing/go/smoke_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,54 @@ import (
func TestSmokeTests(t *testing.T) {
RunScripts(t, []ScriptTest{
{
Name: "simple statements",
Name: "Simple statements",
SetUpScript: []string{
"CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 BIGINT);",
},
Assertions: []ScriptTestAssertion{
{
Query: "insert into test values (1, 1), (2, 2);",
SkipResultsCheck: true,
Query: "CREATE TABLE test2 (pk BIGINT PRIMARY KEY, v1 BIGINT);",
Expected: []sql.Row{},
},
{
Query: "INSERT INTO test VALUES (1, 1), (2, 2);",
Expected: []sql.Row{},
},
{
Query: "select * from test;",
Query: "INSERT INTO test2 VALUES (3, 3), (4, 4);",
Expected: []sql.Row{},
},
{
Query: "SELECT * FROM test;",
Expected: []sql.Row{
{1, 1},
{2, 2},
},
},
{
Query: "SELECT * FROM test2;",
Expected: []sql.Row{
{3, 3},
{4, 4},
},
},
},
},
{
Name: "Boolean results",
Assertions: []ScriptTestAssertion{
{
Query: "SELECT 1 IN (2);",
Expected: []sql.Row{
{0},
},
},
{
Query: "SELECT 2 IN (2);",
Expected: []sql.Row{
{1},
},
},
},
},
{
Expand Down
4 changes: 3 additions & 1 deletion testing/go/ssl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,7 @@ func TestSSL(t *testing.T) {
require.NoError(t, err)
rows, err := conn.Query(ctx, "SELECT * FROM test;")
require.NoError(t, err)
assert.Equal(t, NormalizeRows([]sql.Row{{3645, 37643}}), ReadRows(t, rows))
readRows, err := ReadRows(rows)
require.NoError(t, err)
assert.Equal(t, NormalizeRows([]sql.Row{{3645, 37643}}), readRows)
}