Skip to content

Commit

Permalink
Unify packetize code (#40)
Browse files Browse the repository at this point in the history
* Unify packetize code
* Add some more tests
  • Loading branch information
at-wat authored Dec 24, 2019
1 parent 1f811b8 commit ced04de
Show file tree
Hide file tree
Showing 13 changed files with 338 additions and 188 deletions.
58 changes: 0 additions & 58 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,54 +15,12 @@
package mqtt

import (
"bytes"
"context"
"errors"
"net"
"testing"
"time"
)

func TestConnect(t *testing.T) {
ca, cb := net.Pipe()
cli := &BaseClient{Transport: cb}

go func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, _ = cli.Connect(ctx, "cli",
WithUserNamePassword("user", "pass"),
WithKeepAlive(0x0123),
WithCleanSession(true),
WithProtocolLevel(ProtocolLevel4),
WithWill(&Message{QoS: QoS1, Topic: "topic", Payload: []byte{0x01}}),
)
}()

b := make([]byte, 100)
n, err := ca.Read(b)
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

expected := []byte{
0x10, // CONNECT
0x25,
0x00, 0x04, 0x4D, 0x51, 0x54, 0x54, // MQTT
0x04, // 3.1.1
0xCE, 0x01, 0x23, // flags, keepalive
0x00, 0x03, 0x63, 0x6C, 0x69, // cli
0x00, 0x05, 0x74, 0x6F, 0x70, 0x69, 0x63, // topic
0x00, 0x01, 0x01, // payload
0x00, 0x04, 0x75, 0x73, 0x65, 0x72, // user
0x00, 0x04, 0x70, 0x61, 0x73, 0x73, // pass
}
if !bytes.Equal(expected, b[:n]) {
t.Fatalf("Expected CONNECT packet: \n '%v',\ngot: \n '%v'", expected, b[:n])
}
cli.Close()
}

func TestProtocolViolation(t *testing.T) {
ca, cb := net.Pipe()
cli := &BaseClient{Transport: cb}
Expand Down Expand Up @@ -112,19 +70,3 @@ func TestProtocolViolation(t *testing.T) {
t.Error("Timeout")
}
}

func TestConnect_OptionsError(t *testing.T) {
errExpected := errors.New("an error")
sessionPresent, err := (&BaseClient{}).Connect(
context.Background(), "cli",
func(*ConnectOptions) error {
return errExpected
},
)
if err != errExpected {
t.Errorf("Expected error: ''%v'', got: ''%v''", errExpected, err)
}
if sessionPresent {
t.Errorf("SessionPresent flag must not be set on options error")
}
}
113 changes: 69 additions & 44 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,82 +41,102 @@ const (
connectFlagUserName connectFlag = 0x80
)

// Connect to the broker.
func (c *BaseClient) Connect(ctx context.Context, clientID string, opts ...ConnectOption) (sessionPresent bool, err error) {
o := &ConnectOptions{
ProtocolLevel: ProtocolLevel4,
}
for _, opt := range opts {
if err := opt(o); err != nil {
return false, err
}
}
c.mu.Lock()
c.sig = &signaller{}
c.connClosed = make(chan struct{})
c.initID()
c.mu.Unlock()
type pktConnect struct {
ProtocolLevel ProtocolLevel
CleanSession bool
KeepAlive uint16
ClientID string
UserName string
Password string
Will *Message
}

go func() {
err := c.serve()
if errConn := c.Close(); errConn != nil && err == nil {
err = errConn
}
c.mu.Lock()
if c.connState != StateDisconnected {
c.err = err
}
c.mu.Unlock()
c.connStateUpdate(StateClosed)
close(c.connClosed)
}()
payload := packString(clientID)
func (p *pktConnect) pack() []byte {
payload := packString(p.ClientID)

var flag byte
if o.CleanSession {
if p.CleanSession {
flag |= byte(connectFlagCleanSession)
}
if o.Will != nil {
if p.Will != nil {
flag |= byte(connectFlagWill)
switch o.Will.QoS {
switch p.Will.QoS {
case QoS0:
flag |= byte(connectFlagWillQoS0)
case QoS1:
flag |= byte(connectFlagWillQoS1)
case QoS2:
flag |= byte(connectFlagWillQoS2)
default:
panic("invalid QoS")
}
if o.Will.Retain {
if p.Will.Retain {
flag |= byte(connectFlagWillRetain)
}
payload = append(payload, packString(o.Will.Topic)...)
payload = append(payload, packBytes(o.Will.Payload)...)
payload = append(payload, packString(p.Will.Topic)...)
payload = append(payload, packBytes(p.Will.Payload)...)
}
if o.UserName != "" {
if p.UserName != "" {
flag |= byte(connectFlagUserName)
payload = append(payload, packString(o.UserName)...)
payload = append(payload, packString(p.UserName)...)
}
if o.Password != "" {
if p.Password != "" {
flag |= byte(connectFlagPassword)
payload = append(payload, packString(o.Password)...)
payload = append(payload, packString(p.Password)...)
}
pkt := pack(
return pack(
packetConnect.b(),
[]byte{
0x00, 0x04, 0x4D, 0x51, 0x54, 0x54,
byte(o.ProtocolLevel),
byte(p.ProtocolLevel),
flag,
},
packUint16(o.KeepAlive),
packUint16(p.KeepAlive),
payload,
)
}

// Connect to the broker.
func (c *BaseClient) Connect(ctx context.Context, clientID string, opts ...ConnectOption) (sessionPresent bool, err error) {
o := &ConnectOptions{
ProtocolLevel: ProtocolLevel4,
}
for _, opt := range opts {
if err := opt(o); err != nil {
return false, err
}
}
c.sig = &signaller{}
c.connClosed = make(chan struct{})
c.initID()

go func() {
err := c.serve()
if errConn := c.Close(); errConn != nil && err == nil {
err = errConn
}
c.mu.Lock()
if c.connState != StateDisconnected {
c.err = err
}
c.mu.Unlock()
c.connStateUpdate(StateClosed)
close(c.connClosed)
}()

chConnAck := make(chan *pktConnAck, 1)
c.mu.Lock()
c.sig.chConnAck = chConnAck
c.mu.Unlock()

pkt := (&pktConnect{
ProtocolLevel: o.ProtocolLevel,
CleanSession: o.CleanSession,
KeepAlive: o.KeepAlive,
ClientID: clientID,
UserName: o.UserName,
Password: o.Password,
Will: o.Will,
}).pack()

if err := c.write(pkt); err != nil {
return false, err
}
Expand Down Expand Up @@ -175,6 +195,11 @@ func WithCleanSession(cleanSession bool) ConnectOption {
// WithWill sets will message.
func WithWill(will *Message) ConnectOption {
return func(o *ConnectOptions) error {
switch will.QoS {
case QoS0, QoS1, QoS2:
default:
return ErrInvalidPacket
}
o.Will = will
return nil
}
Expand Down
139 changes: 139 additions & 0 deletions connect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright 2019 The mqtt-go authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mqtt

import (
"bytes"
"context"
"errors"
"net"
"testing"
"time"
)

func TestConnect(t *testing.T) {
cases := map[string]struct {
opts []ConnectOption
expected []byte
}{
"UserPassCleanWill": {
opts: []ConnectOption{
WithUserNamePassword("user", "pass"),
WithKeepAlive(0x0123),
WithCleanSession(true),
WithProtocolLevel(ProtocolLevel4),
WithWill(&Message{QoS: QoS1, Topic: "topic", Payload: []byte{0x01}}),
},
expected: []byte{
0x10, // CONNECT
0x25,
0x00, 0x04, 0x4D, 0x51, 0x54, 0x54, // MQTT
0x04, // 3.1.1
0xCE, 0x01, 0x23, // flags, keepalive
0x00, 0x03, 0x63, 0x6C, 0x69, // cli
0x00, 0x05, 0x74, 0x6F, 0x70, 0x69, 0x63, // topic
0x00, 0x01, 0x01, // payload
0x00, 0x04, 0x75, 0x73, 0x65, 0x72, // user
0x00, 0x04, 0x70, 0x61, 0x73, 0x73, // pass
},
},
"WillQoS0": {
opts: []ConnectOption{
WithKeepAlive(0x0123),
WithWill(&Message{QoS: QoS0, Topic: "topic", Payload: []byte{0x01}}),
},
expected: []byte{
0x10, // CONNECT
0x19,
0x00, 0x04, 0x4D, 0x51, 0x54, 0x54, // MQTT
0x04, // 3.1.1
0x04, 0x01, 0x23, // flags, keepalive
0x00, 0x03, 0x63, 0x6C, 0x69, // cli
0x00, 0x05, 0x74, 0x6F, 0x70, 0x69, 0x63, // topic
0x00, 0x01, 0x01, // payload
},
},
"WillQoS2Retain": {
opts: []ConnectOption{
WithKeepAlive(0x0123),
WithWill(&Message{QoS: QoS2, Retain: true, Topic: "topic", Payload: []byte{0x01}}),
},
expected: []byte{
0x10, // CONNECT
0x19,
0x00, 0x04, 0x4D, 0x51, 0x54, 0x54, // MQTT
0x04, // 3.1.1
0x34, 0x01, 0x23, // flags, keepalive
0x00, 0x03, 0x63, 0x6C, 0x69, // cli
0x00, 0x05, 0x74, 0x6F, 0x70, 0x69, 0x63, // topic
0x00, 0x01, 0x01, // payload
},
},
"ProtocolLv3": {
opts: []ConnectOption{
WithKeepAlive(0x0123),
WithProtocolLevel(ProtocolLevel3),
},
expected: []byte{
0x10, // CONNECT
0x0F,
0x00, 0x04, 0x4D, 0x51, 0x54, 0x54, // MQTT
0x03, // 3.1.1
0x00, 0x01, 0x23, // flags, keepalive
0x00, 0x03, 0x63, 0x6C, 0x69, // cli
},
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
ca, cb := net.Pipe()
cli := &BaseClient{Transport: cb}

go func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, _ = cli.Connect(ctx, "cli", c.opts...)
}()

b := make([]byte, 100)
n, err := ca.Read(b)
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

if !bytes.Equal(c.expected, b[:n]) {
t.Fatalf("Expected CONNECT packet: \n '%v',\ngot: \n '%v'", c.expected, b[:n])
}
cli.Close()
})
}
}

func TestConnect_OptionsError(t *testing.T) {
errExpected := errors.New("an error")
sessionPresent, err := (&BaseClient{}).Connect(
context.Background(), "cli",
func(*ConnectOptions) error {
return errExpected
},
)
if err != errExpected {
t.Errorf("Expected error: ''%v'', got: ''%v''", errExpected, err)
}
if sessionPresent {
t.Errorf("SessionPresent flag must not be set on options error")
}
}
6 changes: 1 addition & 5 deletions disconnect.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@ import (

// Disconnect from the broker.
func (c *BaseClient) Disconnect(ctx context.Context) error {
pkt := pack(
packetDisconnect.b(),
[]byte{},
[]byte{},
)
pkt := pack(packetDisconnect.b())
c.connStateUpdate(StateDisconnected)
if err := c.write(pkt); err != nil {
return err
Expand Down
Loading

0 comments on commit ced04de

Please sign in to comment.