diff --git a/communication/inc/dtls_message_channel.h b/communication/inc/dtls_message_channel.h index db64bed5a3..0278a9a222 100644 --- a/communication/inc/dtls_message_channel.h +++ b/communication/inc/dtls_message_channel.h @@ -109,8 +109,8 @@ class DTLSMessageChannel: public BufferMessageChannel /** * C function to call the send/recv methods on a DTLSMessageChannel instance. */ - static int send_(void* ctx, const uint8_t* data, size_t len); - static int recv_(void* ctx, uint8_t* data, size_t len); + static int sendCallback(void* ctx, const uint8_t* data, size_t len); + static int recvCallback(void* ctx, uint8_t* data, size_t len); int send(const uint8_t* data, size_t len); int recv(uint8_t* data, size_t len); diff --git a/communication/inc/protocol_defs.h b/communication/inc/protocol_defs.h index 438b6600c3..b690c5d38e 100644 --- a/communication/inc/protocol_defs.h +++ b/communication/inc/protocol_defs.h @@ -57,6 +57,9 @@ enum ProtocolError NO_MEMORY = 30, INTERNAL = 31, OTA_UPDATE_ERROR = 32, // Generic OTA update error + IO_ERROR_SOCKET_SEND_FAILED = 33, + IO_ERROR_SOCKET_RECV_FAILED = 34, + IO_ERROR_REMOTE_END_CLOSED = 35, // NOTE: when adding more ProtocolError codes, be sure to update toSystemError() in protocol_defs.cpp UNKNOWN = 0x7FFFF }; diff --git a/communication/src/dtls_message_channel.cpp b/communication/src/dtls_message_channel.cpp index 3674cfd854..b2c3af81c9 100644 --- a/communication/src/dtls_message_channel.cpp +++ b/communication/src/dtls_message_channel.cpp @@ -39,6 +39,7 @@ void mbedtls_ssl_update_out_pointers(mbedtls_ssl_context *ssl, mbedtls_ssl_trans #include "protocol.h" #include "rng_hal.h" +#include "mbedtls/net_sockets.h" #include "mbedtls/error.h" #include "mbedtls/ssl_internal.h" #include "mbedtls_util.h" @@ -282,32 +283,35 @@ inline int DTLSMessageChannel::recv(uint8_t* data, size_t len) { int size = callbacks.receive(data, len, callbacks.tx_context); // ignore 0 and 1 byte UDP packets which are used to keep alive the connection. - if (size>=0 && size <=1) + if (size == 1) { size = 0; + } return size; } -int DTLSMessageChannel::send_(void *ctx, const unsigned char *buf, size_t len ) { +int DTLSMessageChannel::sendCallback(void *ctx, const unsigned char *buf, size_t len ) { DTLSMessageChannel* channel = (DTLSMessageChannel*)ctx; int count = channel->send(buf, len); - if (count == 0) + if (count == 0) { return MBEDTLS_ERR_SSL_WANT_WRITE; - + } else if (count < 0) { + return MBEDTLS_ERR_NET_SEND_FAILED; + } return count; } -int DTLSMessageChannel::recv_( void *ctx, unsigned char *buf, size_t len ) { +int DTLSMessageChannel::recvCallback( void *ctx, unsigned char *buf, size_t len ) { DTLSMessageChannel* channel = (DTLSMessageChannel*)ctx; - int count = channel->recv(buf, len); if (count == 0) { // 0 means no more data available yet return MBEDTLS_ERR_SSL_WANT_READ; + } else if (count < 0) { + return MBEDTLS_ERR_NET_RECV_FAILED; } return count; } - void DTLSMessageChannel::init() { server_public = nullptr; @@ -344,7 +348,7 @@ ProtocolError DTLSMessageChannel::setup_context() EXIT_ERROR(ret, "unable to setup SSL context"); mbedtls_ssl_set_timer_cb(&ssl_context, &timer, mbedtls_timing_set_delay, mbedtls_timing_get_delay); - mbedtls_ssl_set_bio(&ssl_context, this, &DTLSMessageChannel::send_, &DTLSMessageChannel::recv_, NULL); + mbedtls_ssl_set_bio(&ssl_context, this, &DTLSMessageChannel::sendCallback, &DTLSMessageChannel::recvCallback, NULL); if ((ssl_context.session_negotiate->peer_cert = (mbedtls_x509_crt*)calloc(1, sizeof(mbedtls_x509_crt))) == NULL) { @@ -458,22 +462,30 @@ ProtocolError DTLSMessageChannel::receive(Message& message) conf.read_timeout = 0; int ret = mbedtls_ssl_read(&ssl_context, buf, len); if (ret < 0) { - switch (ret) { - case MBEDTLS_ERR_SSL_WANT_READ: - case MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE: + if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE) { ret = 0; - break; - case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY: - command(CLOSE); - break; - default: - reset_session(); - return IO_ERROR_GENERIC_RECEIVE; + } else { + LOG(ERROR, "mbedtls_ssl_read() failed: -0x%x", -ret); + switch (ret) { + // mbedtls_ssl_read() may need to flush the output before attempting to read from the socket + // so we need to handle both MBEDTLS_ERR_NET_SEND_FAILED and MBEDTLS_ERR_NET_RECV_FAILED here + case MBEDTLS_ERR_NET_SEND_FAILED: + case MBEDTLS_ERR_NET_RECV_FAILED: + // Do not invalidate the session on network errors + return IO_ERROR_SOCKET_RECV_FAILED; + case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY: + command(CLOSE); + return IO_ERROR_REMOTE_END_CLOSED; + default: + reset_session(); + return IO_ERROR_GENERIC_RECEIVE; + } } } message.set_length(ret); if (ret > 0) { cancel_move_session(); + sessionPersist.update(&ssl_context, callbacks.save, coap_state ? *coap_state : 0); #if defined(DEBUG_BUILD) && 0 if (LOG_ENABLED(TRACE)) { LOG(TRACE, "msg len %u", (unsigned)message.length()); @@ -482,7 +494,6 @@ ProtocolError DTLSMessageChannel::receive(Message& message) } #endif } - sessionPersist.update(&ssl_context, callbacks.save, coap_state ? *coap_state : 0); return NO_ERROR; } @@ -520,7 +531,11 @@ ProtocolError DTLSMessageChannel::send(Message& message) int ret = mbedtls_ssl_write(&ssl_context, message.buf(), message.length()); if (ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { - LOG(WARN, "mbedtls_ssl_write returned -0x%x", -ret); + LOG(ERROR, "mbedtls_ssl_write() failed: -0x%x", -ret); + if (ret == MBEDTLS_ERR_NET_SEND_FAILED) { + // Do not invalidate the session on network errors + return IO_ERROR_SOCKET_SEND_FAILED; + } reset_session(); return IO_ERROR_GENERIC_MBEDTLS_SSL_WRITE; } diff --git a/communication/src/protocol_defs.cpp b/communication/src/protocol_defs.cpp index e268547c8e..31ca21c88c 100644 --- a/communication/src/protocol_defs.cpp +++ b/communication/src/protocol_defs.cpp @@ -48,6 +48,8 @@ system_error_t toSystemError(ProtocolError error) { case IO_ERROR_LIGHTSSL_RECEIVE: case IO_ERROR_LIGHTSSL_HANDSHAKE_NONCE: case IO_ERROR_LIGHTSSL_HANDSHAKE_RECV_KEY: + case IO_ERROR_SOCKET_SEND_FAILED: + case IO_ERROR_SOCKET_RECV_FAILED: return SYSTEM_ERROR_IO; case INVALID_STATE: return SYSTEM_ERROR_INVALID_STATE; @@ -67,6 +69,8 @@ system_error_t toSystemError(ProtocolError error) { return SYSTEM_ERROR_INTERNAL; case OTA_UPDATE_ERROR: return SYSTEM_ERROR_OTA; + case IO_ERROR_REMOTE_END_CLOSED: + return SYSTEM_ERROR_END_OF_STREAM; default: return SYSTEM_ERROR_PROTOCOL; // Generic protocol error } diff --git a/system/inc/active_object.h b/system/inc/active_object.h index 821fb1cd42..63d705e380 100644 --- a/system/inc/active_object.h +++ b/system/inc/active_object.h @@ -490,6 +490,13 @@ class ISRTaskQueue { struct Task { TaskFunc func; Task* next; // Next element in the queue + + explicit Task(TaskFunc func = nullptr) : + func(func), + next(nullptr) { + } + + virtual ~Task() = default; }; ISRTaskQueue() : diff --git a/system/inc/system_threading.h b/system/inc/system_threading.h index 2cd2692df4..340681002f 100644 --- a/system/inc/system_threading.h +++ b/system/inc/system_threading.h @@ -20,6 +20,10 @@ #define SYSTEM_THREADING_H #include "active_object.h" +#include "system_error.h" + +#include +#include #if PLATFORM_THREADING @@ -130,8 +134,48 @@ os_mutex_recursive_t mutex_usb_serial(); namespace particle { +namespace detail { + +struct CallableTaskBase: ISRTaskQueue::Task { + virtual void call() = 0; +}; + +template +struct CallableTask: CallableTaskBase { + F fn; + + explicit CallableTask(F&& fn) : fn(std::move(fn)) { + } + + void call() override { + fn(); + } +}; + +} // namespace detail + extern ISRTaskQueue SystemISRTaskQueue; +/** + * Asynchronously invokes a function in the context of the system thread via the ISR task queue. + * + * @note This function allocates memory and thus it cannot be called from an ISR. + */ +template +int invokeAsync(F&& fn) { + // Not using std::function here as it's not exception-safe + auto task = new(std::nothrow) detail::CallableTask(std::move(fn)); + if (!task) { + return SYSTEM_ERROR_NO_MEMORY; + } + task->func = [](ISRTaskQueue::Task* task) { + static_cast(task)->call(); + delete task; + }; + SystemISRTaskQueue.enqueue(task); + return 0; +} + } // namespace particle #endif /* SYSTEM_THREADING_H */ diff --git a/system/src/system_listening_mode.cpp b/system/src/system_listening_mode.cpp index a3069afdda..3f153fa92b 100644 --- a/system/src/system_listening_mode.cpp +++ b/system/src/system_listening_mode.cpp @@ -171,7 +171,6 @@ int ListeningModeHandler::enqueueCommand(network_listen_command_t com, void* arg return SYSTEM_ERROR_NO_MEMORY; } - memset(task, 0, sizeof(Task)); task->command = com; task->arg = arg; task->func = reinterpret_cast(&executeEnqueuedCommand); diff --git a/system/src/system_network_compat.cpp b/system/src/system_network_compat.cpp index a5f570f461..29ea8de193 100644 --- a/system/src/system_network_compat.cpp +++ b/system/src/system_network_compat.cpp @@ -68,27 +68,37 @@ inline NetworkInterface& nif(network_interface_t _nif) { return cellular; } void HAL_WLAN_notify_simple_config_done() { - network.notify_listening_complete(); + invokeAsync([]() { + network.notify_listening_complete(); + }); } void HAL_NET_notify_connected() { - network.notify_connected(); + invokeAsync([]() { + network.notify_connected(); + }); } void HAL_NET_notify_disconnected() { - network.notify_disconnected(); + invokeAsync([]() { + network.notify_disconnected(); + }); } void HAL_NET_notify_error() { - network.notify_error(); + invokeAsync([]() { + network.notify_error(); + }); } void HAL_NET_notify_dhcp(bool dhcp) { - network.notify_dhcp(dhcp); + invokeAsync([dhcp]() { + network.notify_dhcp(dhcp); + }); } const void* network_config(network_handle_t network, uint32_t param, void* reserved)