diff --git a/ship/handshake.go b/ship/handshake.go index b68522ab..1dc2c553 100644 --- a/ship/handshake.go +++ b/ship/handshake.go @@ -145,19 +145,7 @@ func (c *ShipConnection) handleState(timeout bool, message []byte) { c.handshakeHello_PendingInit() case SmeHelloStatePendingListen: - if timeout { - // The device needs to be in a state for the user to allow trusting the device - // e.g. either the web UI or by other means - if !c.serviceDataProvider.AllowWaitingForTrust(c.remoteShipID) { - c.handshakeHello_PendingTimeout() - return - } - - c.handshakeHello_PendingProlongationRequest() - return - } - - c.handshakeHello_PendingListen(message) + c.handshakeHello_PendingListen(timeout, message) case SmeHelloStateOk: c.handshakeProtocol_Init() diff --git a/ship/hs_hello.go b/ship/hs_hello.go index be7a30d8..3c7c15a0 100644 --- a/ship/hs_hello.go +++ b/ship/hs_hello.go @@ -110,7 +110,19 @@ func (c *ShipConnection) handshakeHello_PendingInit() { } // SME_HELLO_PENDING_LISTEN -func (c *ShipConnection) handshakeHello_PendingListen(message []byte) { +func (c *ShipConnection) handshakeHello_PendingListen(timeout bool, message []byte) { + if timeout { + // The device needs to be in a state for the user to allow trusting the device + // e.g. either the web UI or by other means + if !c.serviceDataProvider.AllowWaitingForTrust(c.remoteShipID) { + c.handshakeHello_PendingTimeout() + } else { + c.handshakeHello_PendingProlongationRequest() + } + + return + } + var helloReturnMsg model.ConnectionHello if err := c.processShipJsonMessage(message, &helloReturnMsg); err != nil { c.setState(SmeHelloStateAbort, nil) diff --git a/ship/hs_hello_test.go b/ship/hs_hello_test.go index 4b22ec07..0babb1ee 100644 --- a/ship/hs_hello_test.go +++ b/ship/hs_hello_test.go @@ -171,7 +171,7 @@ func (s *HelloSuite) Test_PendingListen_Timeout() { time.Sleep(tHelloInit + time.Second) } else { // speed up the test by running the method directly - sut.handshakeHello_PendingTimeout() + sut.handshakeHello_PendingListen(true, nil) } assert.Equal(s.T(), SmeHelloStateAbortDone, sut.getState()) @@ -180,6 +180,23 @@ func (s *HelloSuite) Test_PendingListen_Timeout() { shutdownTest(sut) } +func (s *HelloSuite) Test_PendingListen_Timeout_Prolongation() { + sut, data := initTest(s.role) + + data.allowWaitingForTrust = true + + sut.setState(SmeHelloStatePendingInit, nil) // inits the timer + sut.setState(SmeHelloStatePendingListen, nil) + + // speed up the test by running the method directly, the timer is already checked + sut.handshakeHello_PendingListen(true, nil) + + assert.Equal(s.T(), SmeHelloStatePendingListen, sut.getState()) + assert.NotNil(s.T(), data.lastMessage()) + + shutdownTest(sut) +} + func (s *HelloSuite) Test_PendingListen_ReadyAbort() { sut, data := initTest(s.role) diff --git a/ship/hs_helper_test.go b/ship/hs_helper_test.go index 10f2a8ef..ca86f32d 100644 --- a/ship/hs_helper_test.go +++ b/ship/hs_helper_test.go @@ -13,6 +13,8 @@ type dataHandlerTest struct { mux sync.Mutex + allowWaitingForTrust bool + handleConnectionClosedInvoked bool } @@ -45,8 +47,10 @@ func (s *dataHandlerTest) IsRemoteServiceForSKIPaired(string) bool { return true func (s *dataHandlerTest) HandleConnectionClosed(*ShipConnection, bool) { s.handleConnectionClosedInvoked = true } -func (s *dataHandlerTest) ReportServiceShipID(string, string) {} -func (s *dataHandlerTest) AllowWaitingForTrust(string) bool { return false } +func (s *dataHandlerTest) ReportServiceShipID(string, string) {} +func (s *dataHandlerTest) AllowWaitingForTrust(string) bool { + return s.allowWaitingForTrust +} func (s *dataHandlerTest) HandleShipHandshakeStateUpdate(string, ShipState) {} func initTest(role shipRole) (*ShipConnection, *dataHandlerTest) {