diff --git a/connack.go b/connack.go index 61c6152..3638283 100644 --- a/connack.go +++ b/connack.go @@ -34,19 +34,19 @@ const ( func (c ConnectionReturnCode) String() string { switch c { case ConnectionAccepted: - return "ConnectionAccepted" + return "connection accepted" case UnacceptableProtocolVersion: - return "Connection Refused, unacceptable protocol version" + return "connection refused, unacceptable protocol version" case IdentifierRejected: - return "Connection Refused, identifier rejected" + return "connection refused, identifier rejected" case ServerUnavailable: - return "Connection Refused, Server unavailable" + return "connection refused, Server unavailable" case BadUserNameOrPassword: - return "Connection Refused, bad user name or password" + return "connection refused, bad user name or password" case NotAuthorized: - return "Connection Refused, not authorized" + return "connection refused, not authorized" } - return fmt.Sprintf("Unknown ConnectionReturnCode %x", int(c)) + return fmt.Sprintf("unknown ConnectionReturnCode %x", int(c)) } type pktConnAck struct { diff --git a/connect.go b/connect.go index cc560d5..a8c9e8f 100644 --- a/connect.go +++ b/connect.go @@ -147,13 +147,34 @@ func (c *BaseClient) Connect(ctx context.Context, clientID string, opts ...Conne return false, ctx.Err() case connAck := <-chConnAck: if connAck.Code != ConnectionAccepted { - return false, errors.New(connAck.Code.String()) + return false, &ConnectionError{ + Err: ErrConnectionFailed, + Code: connAck.Code, + } } c.connStateUpdate(StateActive) return connAck.SessionPresent, nil } } +// ErrConnectionFailed means the connection is not established. +var ErrConnectionFailed = errors.New("connection failed") + +// ConnectionError ia a error storing connection return code. +type ConnectionError struct { + Err error + Code ConnectionReturnCode +} + +func (e *ConnectionError) Error() string { + return e.Code.String() + ": " + e.Err.Error() +} + +// Unwrap returns base error of ConnectionError. (for Go1.13 error unwrapping.) +func (e *ConnectionError) Unwrap() error { + return e.Err +} + // ConnectOptions represents options for Connect. type ConnectOptions struct { UserName string diff --git a/connect_test.go b/connect_test.go index d3fc0d1..ff09abd 100644 --- a/connect_test.go +++ b/connect_test.go @@ -137,3 +137,41 @@ func TestConnect_OptionsError(t *testing.T) { t.Errorf("SessionPresent flag must not be set on options error") } } + +func TestConnect_Error(t *testing.T) { + ca, cb := net.Pipe() + cli := &BaseClient{Transport: cb} + + go func() { + if _, err := ca.Read(make([]byte, 100)); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + // Send CONNACK. + if _, err := ca.Write([]byte{ + 0x20, 0x02, 0x00, 0x04, + }); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err := cli.Connect(ctx, "cli") + if err == nil { + t.Fatal("Error is not returned on connection refuse") + } + + conErr, ok := err.(*ConnectionError) + if !ok { + t.Fatal("Returned error type is not ConnectionError") + } + if conErr.Unwrap() != ErrConnectionFailed { + t.Errorf("Connection error must be unwrapped to: '%v', got: '%v'", + ErrConnectionFailed, conErr.Unwrap(), + ) + } + if conErr.Code != BadUserNameOrPassword { + t.Errorf("Server returned: '%v', parsed as: '%v'", BadUserNameOrPassword, conErr.Code) + } +} diff --git a/paho/paho.go b/paho/paho.go index c2c645f..e0e798c 100644 --- a/paho/paho.go +++ b/paho/paho.go @@ -24,7 +24,8 @@ import ( paho "github.com/eclipse/paho.mqtt.golang" ) -var errNotConnected = errors.New("not connected") +// ErrNotConnected means that the command was requested on the closed connection. +var ErrNotConnected = errors.New("not connected") type pahoWrapper struct { cli mqtt.Client @@ -229,7 +230,8 @@ func (c *pahoWrapper) Publish(topic string, qos byte, retained bool, payload int cli := c.cli c.mu.Unlock() if cli == nil { - token.err = errNotConnected + token.err = ErrNotConnected + token.release() return } @@ -253,7 +255,8 @@ func (c *pahoWrapper) Subscribe(topic string, qos byte, callback paho.MessageHan cli := c.cli c.mu.Unlock() if cli == nil { - token.err = errNotConnected + token.err = ErrNotConnected + token.release() return } @@ -286,7 +289,8 @@ func (c *pahoWrapper) SubscribeMultiple(filters map[string]byte, callback paho.M cli := c.cli c.mu.Unlock() if cli == nil { - token.err = errNotConnected + token.err = ErrNotConnected + token.release() return } @@ -303,7 +307,8 @@ func (c *pahoWrapper) Unsubscribe(topics ...string) paho.Token { cli := c.cli c.mu.Unlock() if cli == nil { - token.err = errNotConnected + token.err = ErrNotConnected + token.release() return } diff --git a/paho/paho_test.go b/paho/paho_test.go new file mode 100644 index 0000000..211c4c9 --- /dev/null +++ b/paho/paho_test.go @@ -0,0 +1,81 @@ +// Copyright 2019 The mqtt-go authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mqtt + +import ( + "net/url" + "testing" + "time" + + paho "github.com/eclipse/paho.mqtt.golang" +) + +func TestNotConnected(t *testing.T) { + cli := NewClient(&paho.ClientOptions{Servers: []*url.URL{{}}}) + + if ok := cli.IsConnected(); ok { + t.Error("IsConnected must return false on disconnected client.") + } + if ok := cli.IsConnectionOpen(); ok { + t.Error("IsConnectionOpen must return false on disconnected client.") + } + t.Run("Publish", func(t *testing.T) { + token := cli.Publish("a", 0, false, []byte{}) + if ok := token.WaitTimeout(time.Second); !ok { + t.Fatal("Timeout") + } + if token.Error() != ErrNotConnected { + t.Errorf("'%v' must be returned on disconnected client, got: '%v'", + ErrNotConnected, token.Error(), + ) + } + }) + t.Run("Subscribe", func(t *testing.T) { + token := cli.Subscribe("a", 0, func(paho.Client, paho.Message) {}) + if ok := token.WaitTimeout(time.Second); !ok { + t.Fatal("Timeout") + } + if token.Error() != ErrNotConnected { + t.Errorf("'%v' must be returned on disconnected client, got: '%v'", + ErrNotConnected, token.Error(), + ) + } + }) + t.Run("SubscribeMultiple", func(t *testing.T) { + token := cli.SubscribeMultiple( + map[string]byte{"a": 0}, + func(paho.Client, paho.Message) {}, + ) + if ok := token.WaitTimeout(time.Second); !ok { + t.Fatal("Timeout") + } + if token.Error() != ErrNotConnected { + t.Errorf("'%v' must be returned on disconnected client, got: '%v'", + ErrNotConnected, token.Error(), + ) + } + }) + t.Run("Unsubscribe", func(t *testing.T) { + token := cli.Unsubscribe("a") + if ok := token.WaitTimeout(time.Second); !ok { + t.Fatal("Timeout") + } + if token.Error() != ErrNotConnected { + t.Errorf("'%v' must be returned on disconnected client, got: '%v'", + ErrNotConnected, token.Error(), + ) + } + }) +}