Skip to content

Commit

Permalink
remove some use of errBadConnNoWrite
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed Nov 21, 2024
1 parent 41a5fa2 commit 7d18f3e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 44 deletions.
5 changes: 5 additions & 0 deletions buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func newBuffer(nc net.Conn) buffer {
}
}

// busy retruns true if the buffer contains some read data.
func (b *buffer) busy() bool {
return b.length > 0
}

// flip replaces the active buffer with the background buffer
// this is a delayed flip that simply increases the buffer counter;
// the actual flip will be performed the next time we call `buffer.fill`
Expand Down
10 changes: 7 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,14 @@ func (mc *mysqlConn) Close() (err error) {
if !mc.closed.Load() {
err = mc.writeCommandPacket(comQuit)
}
mc.close()
return
}

// close closes the network connection and cleare results without sending COM_QUIT.
func (mc *mysqlConn) close() {
mc.cleanup()
mc.clearResult()
return
}

// Closes the network connection and unsets internal variables. Do not call this
Expand Down Expand Up @@ -637,7 +641,7 @@ func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
// ResetSession implements driver.SessionResetter.
// (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.Load() {
if mc.closed.Load() || mc.buf.busy() {
return driver.ErrBadConn
}

Expand Down Expand Up @@ -671,7 +675,7 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
// IsValid implements driver.Validator interface
// (From Go 1.15)
func (mc *mysqlConn) IsValid() bool {
return !mc.closed.Load()
return !mc.closed.Load() && !mc.buf.busy()
}

var _ driver.SessionResetter = &mysqlConn{}
Expand Down
50 changes: 16 additions & 34 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,19 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// read packet header
data, err := mc.buf.readNext(4)
if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
mc.log(err)
mc.Close()
return nil, ErrInvalidConn
return nil, err
}

// packet length [24 bit]
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)

// check packet sync [8 bit]
if data[3] != mc.sequence {
mc.Close()
mc.close()
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
}
Expand All @@ -59,7 +58,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// there was no previous packet
if prevData == nil {
mc.log(ErrMalformPkt)
mc.Close()
mc.close()
return nil, ErrInvalidConn
}

Expand All @@ -69,12 +68,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// read packet body [pktLen bytes]
data, err = mc.buf.readNext(pktLen)
if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
mc.log(err)
mc.Close()
return nil, ErrInvalidConn
return nil, err
}

// return data if this was the last packet
Expand Down Expand Up @@ -125,10 +123,10 @@ func (mc *mysqlConn) writePacket(data []byte) error {

n, err := mc.netConn.Write(data[:4+size])
if err != nil {
mc.cleanup()
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
mc.cleanup()
if n == 0 && pktLen == len(data)-4 {
// only for the first loop iteration when nothing was written yet
mc.log(err)
Expand Down Expand Up @@ -162,11 +160,6 @@ func (mc *mysqlConn) writePacket(data []byte) error {
func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
data, err = mc.readPacket()
if err != nil {
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
// in connection initialization we don't risk retrying non-idempotent actions.
if err == ErrInvalidConn {
return nil, "", driver.ErrBadConn
}
return
}

Expand Down Expand Up @@ -312,9 +305,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// Calculate packet length and get buffer with that size
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// ClientFlags [32 bit]
Expand Down Expand Up @@ -404,9 +396,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
pktLen := 4 + len(authData)
data, err := mc.buf.takeBuffer(pktLen)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// Add the auth data [EOF]
Expand All @@ -424,9 +415,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {

data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
return err
}

// Add command byte
Expand All @@ -443,9 +432,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
pktLen := 1 + len(arg)
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
return err
}

// Add command byte
Expand All @@ -464,9 +451,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {

data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
return err
}

// Add command byte
Expand Down Expand Up @@ -1007,9 +992,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In this case the len(data) == cap(data) which is used to optimise the flow below.
}
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
return err
}

// command [1 byte]
Expand Down Expand Up @@ -1207,8 +1190,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
if err = mc.buf.store(data); err != nil {
mc.log(err)
return errBadConnNoWrite
return err
}
}

Expand Down
13 changes: 6 additions & 7 deletions packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
)

var (
errConnClosed = errors.New("connection is closed")
errConnTooManyReads = errors.New("too many reads")
errConnTooManyWrites = errors.New("too many writes")
)
Expand All @@ -39,7 +38,7 @@ type mockConn struct {

func (m *mockConn) Read(b []byte) (n int, err error) {
if m.closed {
return 0, errConnClosed
return 0, net.ErrClosed
}

m.reads++
Expand All @@ -54,7 +53,7 @@ func (m *mockConn) Read(b []byte) (n int, err error) {
}
func (m *mockConn) Write(b []byte) (n int, err error) {
if m.closed {
return 0, errConnClosed
return 0, net.ErrClosed
}

m.writes++
Expand Down Expand Up @@ -290,8 +289,8 @@ func TestReadPacketFail(t *testing.T) {
// fail to read header
conn.closed = true
_, err = mc.readPacket()
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %v", err)
if err != net.ErrClosed {
t.Errorf("expected ErrClosed, got %v", err)
}

// reset
Expand All @@ -303,8 +302,8 @@ func TestReadPacketFail(t *testing.T) {
// fail to read body
conn.maxReads = 1
_, err = mc.readPacket()
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %v", err)
if err != errConnTooManyReads {
t.Errorf("expected errConnTooManyReads, got %#v", err)
}
}

Expand Down

0 comments on commit 7d18f3e

Please sign in to comment.