diff --git a/ship/hs_prot.go b/ship/hs_prot.go index bcacf6f7..b1b58568 100644 --- a/ship/hs_prot.go +++ b/ship/hs_prot.go @@ -112,38 +112,38 @@ func (c *ShipConnection) handshakeProtocol_smeProtHStateClientListenChoice(messa msgHandshake := messageProtocolHandshake.MessageProtocolHandshake + abort := false if msgHandshake.HandshakeType != model.ProtocolHandshakeTypeTypeSelect { logging.Log.Debug("invalid protocol handshake response") - c.abortProtocolHandshake(model.MessageProtocolHandshakeErrorErrorTypeSelectionMismatch) - return + abort = true } if msgHandshake.Version.Major != 1 { logging.Log.Debug("unsupported protocol major version") - c.abortProtocolHandshake(model.MessageProtocolHandshakeErrorErrorTypeSelectionMismatch) - return + abort = true } if msgHandshake.Version.Minor != 0 { logging.Log.Debug("unsupported protocol minor version") - c.abortProtocolHandshake(model.MessageProtocolHandshakeErrorErrorTypeSelectionMismatch) - return + abort = true } if msgHandshake.Formats.Format == nil || len(msgHandshake.Formats.Format) == 0 { logging.Log.Debug("format is missing") - c.abortProtocolHandshake(model.MessageProtocolHandshakeErrorErrorTypeSelectionMismatch) - return + abort = true } if len(msgHandshake.Formats.Format) != 1 { logging.Log.Debug("unsupported format response") - c.abortProtocolHandshake(model.MessageProtocolHandshakeErrorErrorTypeSelectionMismatch) - return + abort = true } - if msgHandshake.Formats.Format[0] != model.MessageProtocolFormatTypeUTF8 { + if msgHandshake.Formats.Format != nil && msgHandshake.Formats.Format[0] != model.MessageProtocolFormatTypeUTF8 { logging.Log.Debug("unsupported format") + abort = true + } + + if abort { c.abortProtocolHandshake(model.MessageProtocolHandshakeErrorErrorTypeSelectionMismatch) return } @@ -171,5 +171,7 @@ func (c *ShipConnection) abortProtocolHandshake(err model.MessageProtocolHandsha _ = c.sendShipModel(model.MsgTypeControl, msg) + c.setState(SmeStateError, errors.New("handshake error")) + c.CloseConnection(false, 0, "") } diff --git a/ship/hs_prot_client_test.go b/ship/hs_prot_client_test.go index 826d37c9..dd8d074c 100644 --- a/ship/hs_prot_client_test.go +++ b/ship/hs_prot_client_test.go @@ -65,3 +65,63 @@ func (s *ProClientSuite) Test_ListenChoice() { shutdownTest(sut) } + +func (s *ProClientSuite) Test_ListenChoice_Failures() { + sut, data := initTest(s.role) + + sut.setState(SmeProtHStateClientListenChoice, nil) + + protMsg := model.MessageProtocolHandshake{ + MessageProtocolHandshake: model.MessageProtocolHandshakeType{ + HandshakeType: model.ProtocolHandshakeTypeTypeAnnounceMax, + Version: model.Version{Major: 0, Minor: 1}, + }, + } + + msg, err := sut.shipMessage(model.MsgTypeControl, protMsg) + assert.Nil(s.T(), err) + assert.NotNil(s.T(), msg) + + sut.handleState(false, msg) + + sut.setState(SmeProtHStateClientListenChoice, nil) + + protMsg = model.MessageProtocolHandshake{ + MessageProtocolHandshake: model.MessageProtocolHandshakeType{ + HandshakeType: model.ProtocolHandshakeTypeTypeAnnounceMax, + Version: model.Version{Major: 0, Minor: 1}, + Formats: model.MessageProtocolFormatsType{ + Format: []model.MessageProtocolFormatType{model.MessageProtocolFormatTypeUTF16}, + }, + }, + } + + msg, err = sut.shipMessage(model.MsgTypeControl, protMsg) + assert.Nil(s.T(), err) + assert.NotNil(s.T(), msg) + + sut.handleState(false, msg) + + assert.Equal(s.T(), false, sut.handshakeTimerRunning) + + assert.Equal(s.T(), SmeStateError, sut.getState()) + assert.NotNil(s.T(), data.lastMessage()) + + shutdownTest(sut) +} + +func (s *ProClientSuite) Test_Abort() { + sut, data := initTest(s.role) + + sut.setState(SmeProtHStateClientListenChoice, nil) + + sut.abortProtocolHandshake(model.MessageProtocolHandshakeErrorErrorTypeTimeout) + + assert.Equal(s.T(), SmeStateError, sut.getState()) + assert.NotNil(s.T(), data.lastMessage()) + + timer := sut.getHandshakeTimerRunnging() + assert.Equal(s.T(), false, timer) + + shutdownTest(sut) +} diff --git a/ship/hs_prot_server_test.go b/ship/hs_prot_server_test.go index 78fca7eb..a42fdb0a 100644 --- a/ship/hs_prot_server_test.go +++ b/ship/hs_prot_server_test.go @@ -66,6 +66,30 @@ func (s *ProServerSuite) Test_ListenProposal() { shutdownTest(sut) } +func (s *ProServerSuite) Test_ListenProposal_Failure() { + sut, _ := initTest(s.role) + + sut.setState(SmeProtHStateServerListenProposal, nil) + + protMsg := model.MessageProtocolHandshake{ + MessageProtocolHandshake: model.MessageProtocolHandshakeType{ + HandshakeType: model.ProtocolHandshakeTypeTypeSelect, + }, + } + + msg, err := sut.shipMessage(model.MsgTypeControl, protMsg) + assert.Nil(s.T(), err) + assert.NotNil(s.T(), msg) + + sut.handleState(false, msg) + + assert.Equal(s.T(), false, sut.handshakeTimerRunning) + + assert.Equal(s.T(), SmeStateError, sut.getState()) + + shutdownTest(sut) +} + func (s *ProServerSuite) Test_ListenConfirm() { sut, data := initTest(s.role) @@ -95,3 +119,28 @@ func (s *ProServerSuite) Test_ListenConfirm() { shutdownTest(sut) } + +func (s *ProServerSuite) Test_ListenConfirm_Failures() { + sut, data := initTest(s.role) + + sut.setState(SmeProtHStateServerListenConfirm, nil) + + protMsg := model.MessageProtocolHandshake{ + MessageProtocolHandshake: model.MessageProtocolHandshakeType{ + HandshakeType: model.ProtocolHandshakeTypeTypeAnnounceMax, + }, + } + + msg, err := sut.shipMessage(model.MsgTypeControl, protMsg) + assert.Nil(s.T(), err) + assert.NotNil(s.T(), msg) + + sut.handleState(false, msg) + + assert.Equal(s.T(), false, sut.handshakeTimerRunning) + + assert.Equal(s.T(), SmeStateError, sut.getState()) + assert.NotNil(s.T(), data.lastMessage()) + + shutdownTest(sut) +}