diff --git a/src/transport/PeerConnections.cpp b/src/transport/PeerConnections.cpp deleted file mode 100644 index b3a0dad8fc054a..00000000000000 --- a/src/transport/PeerConnections.cpp +++ /dev/null @@ -1,114 +0,0 @@ -/* - * - * Copyright (c) 2020 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include - -namespace chip { -namespace Transport { - -CHIP_ERROR PeerConnectionsBase::CreateNewPeerConnectionState(const PeerAddress & address, PeerConnectionState ** state) -{ - CHIP_ERROR err = CHIP_ERROR_NO_MEMORY; - - if (state) - { - *state = nullptr; - } - - for (size_t i = 0; i < mConnectionStateArraySize; i++) - { - if (!mConnectionStateArray[i].GetPeerAddress().IsInitialized()) - { - mConnectionStateArray[i] = PeerConnectionState(address); - mConnectionStateArray[i].SetLastActivityTimeMs(GetCurrentMonotonicTimeMs()); - - if (state) - { - *state = &mConnectionStateArray[i]; - } - - err = CHIP_NO_ERROR; - break; - } - } - - return err; -} - -bool PeerConnectionsBase::FindPeerConnectionState(const PeerAddress & address, PeerConnectionState ** state) -{ - *state = nullptr; - for (size_t i = 0; i < mConnectionStateArraySize; i++) - { - if (mConnectionStateArray[i].GetPeerAddress() == address) - { - *state = &mConnectionStateArray[i]; - break; - } - } - return *state != nullptr; -} - -bool PeerConnectionsBase::FindPeerConnectionState(NodeId nodeId, PeerConnectionState ** state) -{ - *state = nullptr; - for (size_t i = 0; i < mConnectionStateArraySize; i++) - { - if (!mConnectionStateArray[i].GetPeerAddress().IsInitialized()) - { - continue; - } - if (mConnectionStateArray[i].GetPeerNodeId() == nodeId) - { - *state = &mConnectionStateArray[i]; - break; - } - } - return *state != nullptr; -} - -void PeerConnectionsBase::ExpireInactiveConnections(uint64_t maxIdleTimeMs) -{ - const uint64_t currentTime = GetCurrentMonotonicTimeMs(); - - for (size_t i = 0; i < mConnectionStateArraySize; i++) - { - if (!mConnectionStateArray[i].GetPeerAddress().IsInitialized()) - { - continue; // not an active connection - } - - uint64_t connectionActiveTime = mConnectionStateArray[i].GetLastActivityTimeMs(); - if (connectionActiveTime + maxIdleTimeMs >= currentTime) - { - continue; // not expired - } - - if (OnConnectionExpired) - { - OnConnectionExpired(mConnectionStateArray[i], mConnectionExpiredArgument); - } - - // Connection is assumed expired, marking it as invalid - mConnectionStateArray[i].Reset(); - } -} - -} // namespace Transport -} // namespace chip diff --git a/src/transport/PeerConnections.h b/src/transport/PeerConnections.h index ec7b2b515a913e..4b5ed043358303 100644 --- a/src/transport/PeerConnections.h +++ b/src/transport/PeerConnections.h @@ -17,8 +17,8 @@ #ifndef PEER_CONNECTIONS_H_ #define PEER_CONNECTIONS_H_ -#include #include +#include #include #include @@ -32,17 +32,10 @@ namespace Transport { * - handle connection active time and expiration * - allocate and free space for connection states. */ -class PeerConnectionsBase +template +class PeerConnections { public: - /** - * Construct a PeerConnectionsBase object using a preallocated array used for connection state storage. - */ - PeerConnectionsBase(PeerConnectionState * storageArray, size_t arraySize) : - mConnectionStateArray(storageArray), mConnectionStateArraySize(arraySize) - {} - virtual ~PeerConnectionsBase() {} - /** * Allocates a new peer connection state state object out of the internal resource pool. * @@ -54,17 +47,58 @@ class PeerConnectionsBase * @returns CHIP_NO_ERROR if state could be initialized. May fail if maximum connection count * has been reached (with CHIP_ERROR_NO_MEMORY). */ - CHIP_ERROR CreateNewPeerConnectionState(const PeerAddress & address, PeerConnectionState ** state); + CHECK_RETURN_VALUE + CHIP_ERROR CreateNewPeerConnectionState(const PeerAddress & address, PeerConnectionState ** state) + { + CHIP_ERROR err = CHIP_ERROR_NO_MEMORY; + + if (state) + { + *state = nullptr; + } + + for (size_t i = 0; i < kMaxConnectionCount; i++) + { + if (!mStates[i].GetPeerAddress().IsInitialized()) + { + mStates[i] = PeerConnectionState(address); + mStates[i].SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); + + if (state) + { + *state = &mStates[i]; + } + + err = CHIP_NO_ERROR; + break; + } + } + + return err; + } /** - * Get a peer connection state given a peer address. + * Get a peer connection state given a Peer address. * * @param address is the connection to find (based on address) * @param state [out] the connection if found, null otherwise. MUST not be null. * * @return true if a corresponding state was found. */ - bool FindPeerConnectionState(const PeerAddress & address, PeerConnectionState ** state); + CHECK_RETURN_VALUE + bool FindPeerConnectionState(const PeerAddress & address, PeerConnectionState ** state) + { + *state = nullptr; + for (size_t i = 0; i < kMaxConnectionCount; i++) + { + if (mStates[i].GetPeerAddress() == address) + { + *state = &mStates[i]; + break; + } + } + return *state != nullptr; + } /** * Get a peer connection state given a Node Id. @@ -75,10 +109,30 @@ class PeerConnectionsBase * * @return true if a corresponding state was found. */ - bool FindPeerConnectionState(NodeId nodeId, PeerConnectionState ** state); + CHECK_RETURN_VALUE + bool FindPeerConnectionState(NodeId nodeId, PeerConnectionState ** state) + { + *state = nullptr; + for (size_t i = 0; i < kMaxConnectionCount; i++) + { + if (!mStates[i].GetPeerAddress().IsInitialized()) + { + continue; + } + if (mStates[i].GetPeerNodeId() == nodeId) + { + *state = &mStates[i]; + break; + } + } + return *state != nullptr; + } /// Convenience method to mark a peer connection state as active - void MarkConnectionActive(PeerConnectionState * state) { state->SetLastActivityTimeMs(GetCurrentMonotonicTimeMs()); } + void MarkConnectionActive(PeerConnectionState * state) + { + state->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); + } /** * Iterates through all active connections and expires any connection with an idle time @@ -86,7 +140,35 @@ class PeerConnectionsBase * * Expiring a connection involves callback execution and then clearing the internal state. */ - void ExpireInactiveConnections(uint64_t maxIdleTimeMs); + void ExpireInactiveConnections(uint64_t maxIdleTimeMs) + { + const uint64_t currentTime = mTimeSource.GetCurrentMonotonicTimeMs(); + + for (size_t i = 0; i < kMaxConnectionCount; i++) + { + if (!mStates[i].GetPeerAddress().IsInitialized()) + { + continue; // not an active connection + } + + uint64_t connectionActiveTime = mStates[i].GetLastActivityTimeMs(); + if (connectionActiveTime + maxIdleTimeMs >= currentTime) + { + continue; // not expired + } + + if (OnConnectionExpired) + { + OnConnectionExpired(mStates[i], mConnectionExpiredArgument); + } + + // Connection is assumed expired, marking it as invalid + mStates[i] = PeerConnectionState(PeerAddress::Uninitialized()); + } + } + + /// Allows access to the underlying time source used for keeping track of connection active time + Time::TimeSource & GetTimeSource() { return mTimeSource; } /** * Sets the handler for expired connections @@ -102,36 +184,16 @@ class PeerConnectionsBase OnConnectionExpired = reinterpret_cast(handler); } -protected: - /// Get the current time from a Time::TimeSource or equivalent - virtual uint64_t GetCurrentMonotonicTimeMs() = 0; - private: - PeerConnectionState * mConnectionStateArray; - const size_t mConnectionStateArraySize; + Time::TimeSource mTimeSource; + PeerConnectionState mStates[kMaxConnectionCount]; typedef void (*ConnectionExpiredHandler)(const PeerConnectionState & state, void * param); - ConnectionExpiredHandler OnConnectionExpired = nullptr; ///< Callback for when a connection expires + ConnectionExpiredHandler OnConnectionExpired = nullptr; ///< Callback for connection expiry void * mConnectionExpiredArgument = nullptr; ///< Argument for callback }; -/** - * Concrete peer connections implementation based on system sizes and timers. - */ -class PeerConnections : public PeerConnectionsBase -{ -public: - PeerConnections() : PeerConnectionsBase(mState, sizeof(mState) / sizeof((mState)[0])) {} - -protected: - uint64_t GetCurrentMonotonicTimeMs() override { return mTimeSource.GetCurrentMonotonicTimeMs(); } - -private: - PeerConnectionState mState[CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE]; - Time::TimeSource mTimeSource; -}; - } // namespace Transport } // namespace chip diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index 11d9e355b6572f..c35fea086783e7 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -130,9 +130,9 @@ class DLL_EXPORT SecureSessionMgr : public ReferenceCounted // TODO: add support for multiple transports (TCP, BLE to be added) Transport::UDP mTransport; - NodeId mLocalNodeId; //< Id of the current node - Transport::PeerConnections mPeerConnections; //< Active connections to other peers - State mState; //< Initialization state of the object + NodeId mLocalNodeId; //< Id of the current node + Transport::PeerConnections mPeerConnections; //< Active connections to other peers + State mState; //< Initialization state of the object /** * This function is the application callback that is invoked when a message is received over a diff --git a/src/transport/TransportLayer.am b/src/transport/TransportLayer.am index 30dbea2a1d9134..325efb3d572c9a 100644 --- a/src/transport/TransportLayer.am +++ b/src/transport/TransportLayer.am @@ -27,7 +27,6 @@ CHIP_BUILD_TRANSPORT_LAYER_SOURCE_FILES = \ @top_builddir@/src/transport/SecureSession.cpp \ @top_builddir@/src/transport/MessageHeader.cpp \ - @top_builddir@/src/transport/PeerConnections.cpp \ @top_builddir@/src/transport/SecureSessionMgr.cpp \ @top_builddir@/src/transport/UDP.cpp \ $(NULL) diff --git a/src/transport/tests/TestPeerConnections.cpp b/src/transport/tests/TestPeerConnections.cpp index 5d4d711443baf4..6a49031e94a919 100644 --- a/src/transport/tests/TestPeerConnections.cpp +++ b/src/transport/tests/TestPeerConnections.cpp @@ -53,26 +53,11 @@ const NodeId kPeer1NodeId = 123; const NodeId kPeer2NodeId = 6; const NodeId kPeer3NodeId = 81; -/// A Peer connections that supports exactly 2 connections and a test time source. -class TestPeerConnections : public PeerConnectionsBase -{ -public: - TestPeerConnections() : PeerConnectionsBase(mState, ArraySize(mState)) {} - Time::TimeSource & GetTimeSource() { return mTimeSource; } - -protected: - uint64_t GetCurrentMonotonicTimeMs() override { return mTimeSource.GetCurrentMonotonicTimeMs(); } - -private: - PeerConnectionState mState[2]; - Time::TimeSource mTimeSource; -}; - void TestBasicFunctionality(nlTestSuite * inSuite, void * inContext) { CHIP_ERROR err; PeerConnectionState * statePtr; - TestPeerConnections connections; + PeerConnections<2, Time::Source::kTest> connections; connections.GetTimeSource().SetCurrentMonotonicTimeMs(100); err = connections.CreateNewPeerConnectionState(kPeer1Addr, nullptr); @@ -93,7 +78,7 @@ void TestFindByAddress(nlTestSuite * inSuite, void * inContext) { CHIP_ERROR err; PeerConnectionState * statePtr; - TestPeerConnections connections; + PeerConnections<2, Time::Source::kTest> connections; err = connections.CreateNewPeerConnectionState(kPeer1Addr, nullptr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); @@ -114,7 +99,7 @@ void TestFindByNodeId(nlTestSuite * inSuite, void * inContext) { CHIP_ERROR err; PeerConnectionState * statePtr; - TestPeerConnections connections; + PeerConnections<2, Time::Source::kTest> connections; err = connections.CreateNewPeerConnectionState(kPeer1Addr, &statePtr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); @@ -155,7 +140,7 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) CHIP_ERROR err; ExpiredCallInfo callInfo; PeerConnectionState * statePtr; - TestPeerConnections connections; + PeerConnections<2, Time::Source::kTest> connections; connections.SetConnectionExpiredHandler(OnConnectionExpired, &callInfo);