Skip to content

Commit

Permalink
mysql: do not allocate in parseOKPacket (#15067)
Browse files Browse the repository at this point in the history
Signed-off-by: Vicent Marti <[email protected]>
  • Loading branch information
vmg authored Jan 29, 2024
1 parent 6ac1596 commit ebf7869
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 26 deletions.
12 changes: 5 additions & 7 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1525,15 +1525,13 @@ type PacketOK struct {
sessionStateData string
}

func (c *Conn) parseOKPacket(in []byte) (*PacketOK, error) {
func (c *Conn) parseOKPacket(packetOK *PacketOK, in []byte) error {
data := &coder{
data: in,
pos: 1, // We already read the type.
}
packetOK := &PacketOK{}

fail := func(format string, args ...any) (*PacketOK, error) {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, format, args...)
fail := func(format string, args ...any) error {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, format, args...)
}

// Affected rows.
Expand Down Expand Up @@ -1578,7 +1576,7 @@ func (c *Conn) parseOKPacket(in []byte) (*PacketOK, error) {
if !ok || length == 0 {
// In case we have no more data or a zero length string, there's no additional information so
// we can return the packet.
return packetOK, nil
return nil
}

// Alright, now we need to read each sub packet from the session state change.
Expand Down Expand Up @@ -1615,7 +1613,7 @@ func (c *Conn) parseOKPacket(in []byte) (*PacketOK, error) {
}
}

return packetOK, nil
return nil
}

// isErrorPacket determines whether or not the packet is an error packet. Mostly here for
Expand Down
12 changes: 7 additions & 5 deletions go/mysql/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ func TestBasicPackets(t *testing.T) {
require.NotEmpty(data)
assert.EqualValues(data[0], OKPacket, "OKPacket")

packetOk, err := cConn.parseOKPacket(data)
var packetOk PacketOK
err = cConn.parseOKPacket(&packetOk, data)
require.NoError(err)
assert.EqualValues(12, packetOk.affectedRows)
assert.EqualValues(34, packetOk.lastInsertID)
Expand All @@ -272,7 +273,7 @@ func TestBasicPackets(t *testing.T) {
require.NotEmpty(data)
assert.EqualValues(data[0], OKPacket, "OKPacket")

packetOk, err = cConn.parseOKPacket(data)
err = cConn.parseOKPacket(&packetOk, data)
require.NoError(err)
assert.EqualValues(23, packetOk.affectedRows)
assert.EqualValues(45, packetOk.lastInsertID)
Expand All @@ -295,7 +296,7 @@ func TestBasicPackets(t *testing.T) {
require.NotEmpty(data)
assert.True(cConn.isEOFPacket(data), "expected EOF")

packetOk, err = cConn.parseOKPacket(data)
err = cConn.parseOKPacket(&packetOk, data)
require.NoError(err)
assert.EqualValues(12, packetOk.affectedRows)
assert.EqualValues(34, packetOk.lastInsertID)
Expand Down Expand Up @@ -690,7 +691,8 @@ func TestOkPackets(t *testing.T) {
cConn.Capabilities = testCase.cc
sConn.Capabilities = testCase.cc
// parse the packet
packetOk, err := cConn.parseOKPacket(data)
var packetOk PacketOK
err := cConn.parseOKPacket(&packetOk, data)
if testCase.expectedErr != "" {
require.Error(t, err)
require.Equal(t, testCase.expectedErr, err.Error())
Expand All @@ -699,7 +701,7 @@ func TestOkPackets(t *testing.T) {
require.NoError(t, err, "failed to parse OK packet")

// write the ok packet from server
err = sConn.writeOKPacket(packetOk)
err = sConn.writeOKPacket(&packetOk)
require.NoError(t, err, "failed to write OK packet")

// receive the ok packet on client
Expand Down
25 changes: 12 additions & 13 deletions go/mysql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,9 @@ func (c *Conn) ExecuteFetchWithWarningCount(query string, maxrows int, wantfield

// ReadQueryResult gets the result from the last written query.
func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, bool, uint16, error) {
var packetOk PacketOK
// Get the result.
colNumber, packetOk, err := c.readComQueryResponse()
colNumber, err := c.readComQueryResponse(&packetOk)
if err != nil {
return nil, false, 0, err
}
Expand Down Expand Up @@ -441,8 +442,7 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result,
more = (statusFlags & ServerMoreResultsExists) != 0
result.StatusFlags = statusFlags
} else {
packetOk, err := c.parseOKPacket(data)
if err != nil {
if err := c.parseOKPacket(&packetOk, data); err != nil {
return nil, false, 0, err
}
warnings = packetOk.warnings
Expand Down Expand Up @@ -497,35 +497,34 @@ func (c *Conn) drainResults() error {
}
}

func (c *Conn) readComQueryResponse() (int, *PacketOK, error) {
func (c *Conn) readComQueryResponse(packetOk *PacketOK) (int, error) {
data, err := c.readEphemeralPacket()
if err != nil {
return 0, nil, sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err)
return 0, sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err)
}
defer c.recycleReadPacket()
if len(data) == 0 {
return 0, nil, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "invalid empty COM_QUERY response packet")
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "invalid empty COM_QUERY response packet")
}

switch data[0] {
case OKPacket:
packetOk, err := c.parseOKPacket(data)
return 0, packetOk, err
return 0, c.parseOKPacket(packetOk, data)
case ErrPacket:
// Error
return 0, nil, ParseErrorPacket(data)
return 0, ParseErrorPacket(data)
case 0xfb:
// Local infile
return 0, nil, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented")
return 0, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented")
}
n, pos, ok := readLenEncInt(data, 0)
if !ok {
return 0, nil, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "cannot get column number")
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "cannot get column number")
}
if pos != len(data) {
return 0, nil, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extra data in COM_QUERY response")
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extra data in COM_QUERY response")
}
return int(n), &PacketOK{}, nil
return int(n), nil
}

//
Expand Down
3 changes: 2 additions & 1 deletion go/mysql/streaming_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ func (c *Conn) ExecuteStreamFetch(query string) (err error) {
}

// Get the result.
colNumber, _, err := c.readComQueryResponse()
var packetOk PacketOK
colNumber, err := c.readComQueryResponse(&packetOk)
if err != nil {
return err
}
Expand Down

0 comments on commit ebf7869

Please sign in to comment.