From 943108533f9cc360f4e543e03578f10ff0b395bc Mon Sep 17 00:00:00 2001 From: Atsushi Watanabe Date: Sat, 8 Feb 2020 14:40:25 +0900 Subject: [PATCH] Add test for packet read io error (#94) --- serve.go | 2 +- serve_test.go | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/serve.go b/serve.go index 0725e21..ea6d220 100644 --- a/serve.go +++ b/serve.go @@ -28,7 +28,7 @@ func readPacket(r io.Reader) (packetType, byte, []byte, error) { var remainingLength int for shift := uint(0); ; shift += 7 { remainingLength |= (int(buf[1]) & 0x7F) << shift - if !(buf[1]&0x80 != 0) { + if buf[1]&0x80 == 0 { break } if _, err := io.ReadFull(r, buf[1:]); err != nil { diff --git a/serve_test.go b/serve_test.go index 9b2d486..9ecc73c 100644 --- a/serve_test.go +++ b/serve_test.go @@ -15,7 +15,9 @@ package mqtt import ( + "bytes" "context" + "io" "net" "testing" "time" @@ -116,3 +118,24 @@ func TestServeParseError(t *testing.T) { }) } } + +func TestReadPacketError(t *testing.T) { + pkt := []byte{0x10, 0x80, 0x01} + pkt = append(pkt, make([]byte, 128)...) + + // Ensure full packet doesn't error + _, _, _, err := readPacket(bytes.NewReader(pkt)) + if err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + for i := 1; i < len(pkt)-1; i++ { + _, _, _, err := readPacket(bytes.NewReader(pkt[:i])) + if err != io.ErrUnexpectedEOF && err != io.EOF { + t.Fatalf( + "Expected error for %d: '%v' or %v'', got: '%v'", + i, io.ErrUnexpectedEOF, io.EOF, err, + ) + } + } +}