Skip to content

Commit

Permalink
Formatted repository
Browse files Browse the repository at this point in the history
  • Loading branch information
zachmu committed Nov 18, 2023
1 parent a547a36 commit 6b4e0ad
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 30 deletions.
13 changes: 6 additions & 7 deletions postgres/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (

var BufferSize = 2048

// connBuffers maintains a pool of buffers, reusable between connections. These are only used for processing
// connBuffers maintains a pool of buffers, reusable between connections. These are only used for processing
// fixed-length messages where we know we won't exceed the buffer size on a single read.
var connBuffers = sync.Pool{
New: func() any {
Expand All @@ -45,18 +45,17 @@ var headerBuffers = sync.Pool{
},
}


var sliceOfZeroes = make([]byte, BufferSize)

// Receive returns all messages that were sent from the given connection. This checks with all messages that have a
// Header, and have called AddMessageHeader within their init() function. Returns a nil slice if no messages were
// Receive returns all messages that were sent from the given connection. This checks with all messages that have a
// Header, and have called AddMessageHeader within their init() function. Returns a nil slice if no messages were
// matched. This is the recommended way to check for messages when a specific message is not expected.
// Use ReceiveInto or ReceiveIntoAny when expecting specific messages, where it would be an error to receive messages
// different from the expectation.
func Receive(conn net.Conn) (Message, error) {
header := headerBuffers.Get().([]byte)
defer headerBuffers.Put(header)

n, err := conn.Read(header)
if err != nil {
return nil, err
Expand All @@ -73,10 +72,10 @@ func Receive(conn net.Conn) (Message, error) {

// TODO: possibly not every message has a length in this position, need an easy interface to tell us if so
messageLen := int(binary.BigEndian.Uint32(header[1:])) - 4

buffer := iobufpool.Get(messageLen)
defer iobufpool.Put(buffer)

msgBuffer := (*buffer)[:messageLen]
n, err = conn.Read(msgBuffer)
if err != nil {
Expand Down
9 changes: 5 additions & 4 deletions postgres/connection/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ import (
"testing"
"time"

"github.com/dolthub/doltgresql/postgres/connection"
"github.com/dolthub/doltgresql/postgres/messages"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/require"

"github.com/dolthub/doltgresql/postgres/connection"
"github.com/dolthub/doltgresql/postgres/messages"
)

func TestReceive(t *testing.T) {
Expand Down Expand Up @@ -59,7 +60,7 @@ func TestReceive(t *testing.T) {
t.Run("Receive Query larger than buffer", func(t *testing.T) {
mockBuffer := bytes.NewBuffer([]byte{})
mockConn := &MockConn{buffer: mockBuffer}

message := &pgproto3.Query{
String: "SELECT abc, def, ghi, jkl, mno, pqr, stuv, wxyz, abc, def, ghi, jkl, mno, pqr, stuv, wxyz FROM example",
}
Expand Down Expand Up @@ -112,4 +113,4 @@ func (m *MockConn) SetReadDeadline(t time.Time) error {

func (m *MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
}
2 changes: 1 addition & 1 deletion postgres/connection/message_decode_encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func decode(buffer *decodeBuffer, fields []FieldGroup, iterations int32) error {
// Some calls to decode will have already processed the message header and length to determine the message type,
// so skip those fields when decoding.
if buffer.skipHeader &&
(field.Flags&Header != 0 || field.Flags&MessageLengthInclusive != 0) {
(field.Flags&Header != 0 || field.Flags&MessageLengthInclusive != 0) {
continue
}

Expand Down
30 changes: 15 additions & 15 deletions server/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,12 @@ func (l *Listener) HandleConnection(conn net.Conn) {

// Postgres has a two-stage procedure for prepared queries. First the query is parsed via a |Parse| message, and
// the result is stored in the |preparedStatements| map by the name provided. Then one or more |Bind| messages
// provide parameters for the query, and the result is stored in |portals|. Finally, a call to |Execute| executes
// the named portal.
// provide parameters for the query, and the result is stored in |portals|. Finally, a call to |Execute| executes
// the named portal.
preparedStatements := make(map[string]ConvertedQuery)
portals := make(map[string]ConvertedQuery)
// Main session loop: read messages one at a time off the connection until we receive a |Terminate| message, in

// Main session loop: read messages one at a time off the connection until we receive a |Terminate| message, in
// which case we hang up, or the connection is closed by the client, which generates an io.EOF from the connection.
for {
message, err := connection.Receive(conn)
Expand All @@ -160,7 +160,7 @@ func (l *Listener) HandleConnection(conn net.Conn) {
}
}

// receiveStarupMessage reads a startup message from the connection given and returns it. Some startup messages will
// receiveStarupMessage reads a startup message from the connection given and returns it. Some startup messages will
// result in the establishment of a new connection, which is also returned.
func (l *Listener) receiveStartupMessage(conn net.Conn, mysqlConn *mysql.Conn) (messages.StartupMessage, net.Conn, error) {
var startupMessage messages.StartupMessage
Expand All @@ -173,15 +173,15 @@ InitialMessageLoop:
messages.GSSENCRequest{})
if err != nil {
if err == io.EOF {
return messages.StartupMessage{}, nil, nil
return messages.StartupMessage{}, nil, nil
}
return messages.StartupMessage{}, nil, err
}

if len(initialMessages) != 1 {
return messages.StartupMessage{}, nil, fmt.Errorf("expected a single message upon starting connection, terminating connection")
}

initialMessage := initialMessages[0]
switch initialMessage := initialMessage.(type) {
case messages.StartupMessage:
Expand Down Expand Up @@ -212,7 +212,7 @@ InitialMessageLoop:
return messages.StartupMessage{}, nil, fmt.Errorf("unexpected initial message, terminating connection")
}
}

return startupMessage, conn, nil
}

Expand All @@ -233,7 +233,7 @@ func (l *Listener) chooseInitialDatabase(conn net.Conn, startupMessage messages.
return err
}
} else {
// If a database isn't specified, then we attempt to connect to a database with the same name as the user,
// If a database isn't specified, then we attempt to connect to a database with the same name as the user,
// ignoring any error
_ = l.cfg.Handler.ComQuery(mysqlConn, fmt.Sprintf("USE `%s`;", mysqlConn.User), func(*sqltypes.Result, bool) error {
return nil
Expand All @@ -243,10 +243,10 @@ func (l *Listener) chooseInitialDatabase(conn net.Conn, startupMessage messages.
}

func (l *Listener) handleMessage(
message connection.Message,
conn net.Conn,
mysqlConn *mysql.Conn,
preparedStatements, portals map[string]ConvertedQuery,
message connection.Message,
conn net.Conn,
mysqlConn *mysql.Conn,
preparedStatements, portals map[string]ConvertedQuery,
) (stop, endOfMessages bool, err error) {
switch message := message.(type) {
case messages.Terminate:
Expand All @@ -259,7 +259,7 @@ func (l *Listener) handleMessage(
if handled || err != nil {
return false, false, err
}

query, err := l.convertQuery(message.String)
if err != nil {
return false, false, err
Expand Down
4 changes: 2 additions & 2 deletions testing/go/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func RunScript(t *testing.T, script ScriptTest) {
_, err := conn.Exec(ctx, query)
require.NoError(t, err)
}

// Run the assertions
for _, assertion := range script.Assertions {
t.Run(assertion.Query, func(t *testing.T) {
Expand Down Expand Up @@ -149,7 +149,7 @@ func CreateServer(t *testing.T, database string) (context.Context, *pgx.Conn, *s
require.Equal(t, 0, *code)

fmt.Printf("port is %d\n", port)

ctx := context.Background()
err := func() error {
// The connection attempt may be made before the server has grabbed the port, so we'll retry the first
Expand Down
2 changes: 1 addition & 1 deletion testing/go/smoke_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestSmokeTests(t *testing.T) {
},
Assertions: []ScriptTestAssertion{
{
Query: "insert into test values (1, 1), (2, 2);",
Query: "insert into test values (1, 1), (2, 2);",
SkipResultsCheck: true,
},
{
Expand Down

0 comments on commit 6b4e0ad

Please sign in to comment.