Skip to content

Commit

Permalink
Merge pull request #2335 from particle-iot/socket_errors/ch82443
Browse files Browse the repository at this point in the history
Do not reset the session on socket errors
  • Loading branch information
avtolstoy authored Jul 8, 2021
2 parents b611090 + ec8c271 commit c4c5236
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 28 deletions.
4 changes: 2 additions & 2 deletions communication/inc/dtls_message_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ class DTLSMessageChannel: public BufferMessageChannel<PROTOCOL_BUFFER_SIZE>
/**
* C function to call the send/recv methods on a DTLSMessageChannel instance.
*/
static int send_(void* ctx, const uint8_t* data, size_t len);
static int recv_(void* ctx, uint8_t* data, size_t len);
static int sendCallback(void* ctx, const uint8_t* data, size_t len);
static int recvCallback(void* ctx, uint8_t* data, size_t len);

int send(const uint8_t* data, size_t len);
int recv(uint8_t* data, size_t len);
Expand Down
3 changes: 3 additions & 0 deletions communication/inc/protocol_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ enum ProtocolError
NO_MEMORY = 30,
INTERNAL = 31,
OTA_UPDATE_ERROR = 32, // Generic OTA update error
IO_ERROR_SOCKET_SEND_FAILED = 33,
IO_ERROR_SOCKET_RECV_FAILED = 34,
IO_ERROR_REMOTE_END_CLOSED = 35,
// NOTE: when adding more ProtocolError codes, be sure to update toSystemError() in protocol_defs.cpp
UNKNOWN = 0x7FFFF
};
Expand Down
55 changes: 35 additions & 20 deletions communication/src/dtls_message_channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void mbedtls_ssl_update_out_pointers(mbedtls_ssl_context *ssl, mbedtls_ssl_trans

#include "protocol.h"
#include "rng_hal.h"
#include "mbedtls/net_sockets.h"
#include "mbedtls/error.h"
#include "mbedtls/ssl_internal.h"
#include "mbedtls_util.h"
Expand Down Expand Up @@ -282,32 +283,35 @@ inline int DTLSMessageChannel::recv(uint8_t* data, size_t len)
{
int size = callbacks.receive(data, len, callbacks.tx_context);
// ignore 0 and 1 byte UDP packets which are used to keep alive the connection.
if (size>=0 && size <=1)
if (size == 1) {
size = 0;
}
return size;
}

int DTLSMessageChannel::send_(void *ctx, const unsigned char *buf, size_t len ) {
int DTLSMessageChannel::sendCallback(void *ctx, const unsigned char *buf, size_t len ) {
DTLSMessageChannel* channel = (DTLSMessageChannel*)ctx;
int count = channel->send(buf, len);
if (count == 0)
if (count == 0) {
return MBEDTLS_ERR_SSL_WANT_WRITE;

} else if (count < 0) {
return MBEDTLS_ERR_NET_SEND_FAILED;
}
return count;
}

int DTLSMessageChannel::recv_( void *ctx, unsigned char *buf, size_t len ) {
int DTLSMessageChannel::recvCallback( void *ctx, unsigned char *buf, size_t len ) {
DTLSMessageChannel* channel = (DTLSMessageChannel*)ctx;

int count = channel->recv(buf, len);
if (count == 0) {
// 0 means no more data available yet
return MBEDTLS_ERR_SSL_WANT_READ;
} else if (count < 0) {
return MBEDTLS_ERR_NET_RECV_FAILED;
}
return count;
}


void DTLSMessageChannel::init()
{
server_public = nullptr;
Expand Down Expand Up @@ -344,7 +348,7 @@ ProtocolError DTLSMessageChannel::setup_context()
EXIT_ERROR(ret, "unable to setup SSL context");

mbedtls_ssl_set_timer_cb(&ssl_context, &timer, mbedtls_timing_set_delay, mbedtls_timing_get_delay);
mbedtls_ssl_set_bio(&ssl_context, this, &DTLSMessageChannel::send_, &DTLSMessageChannel::recv_, NULL);
mbedtls_ssl_set_bio(&ssl_context, this, &DTLSMessageChannel::sendCallback, &DTLSMessageChannel::recvCallback, NULL);

if ((ssl_context.session_negotiate->peer_cert = (mbedtls_x509_crt*)calloc(1, sizeof(mbedtls_x509_crt))) == NULL)
{
Expand Down Expand Up @@ -458,22 +462,30 @@ ProtocolError DTLSMessageChannel::receive(Message& message)
conf.read_timeout = 0;
int ret = mbedtls_ssl_read(&ssl_context, buf, len);
if (ret < 0) {
switch (ret) {
case MBEDTLS_ERR_SSL_WANT_READ:
case MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE:
if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE) {
ret = 0;
break;
case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
command(CLOSE);
break;
default:
reset_session();
return IO_ERROR_GENERIC_RECEIVE;
} else {
LOG(ERROR, "mbedtls_ssl_read() failed: -0x%x", -ret);
switch (ret) {
// mbedtls_ssl_read() may need to flush the output before attempting to read from the socket
// so we need to handle both MBEDTLS_ERR_NET_SEND_FAILED and MBEDTLS_ERR_NET_RECV_FAILED here
case MBEDTLS_ERR_NET_SEND_FAILED:
case MBEDTLS_ERR_NET_RECV_FAILED:
// Do not invalidate the session on network errors
return IO_ERROR_SOCKET_RECV_FAILED;
case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
command(CLOSE);
return IO_ERROR_REMOTE_END_CLOSED;
default:
reset_session();
return IO_ERROR_GENERIC_RECEIVE;
}
}
}
message.set_length(ret);
if (ret > 0) {
cancel_move_session();
sessionPersist.update(&ssl_context, callbacks.save, coap_state ? *coap_state : 0);
#if defined(DEBUG_BUILD) && 0
if (LOG_ENABLED(TRACE)) {
LOG(TRACE, "msg len %u", (unsigned)message.length());
Expand All @@ -482,7 +494,6 @@ ProtocolError DTLSMessageChannel::receive(Message& message)
}
#endif
}
sessionPersist.update(&ssl_context, callbacks.save, coap_state ? *coap_state : 0);
return NO_ERROR;
}

Expand Down Expand Up @@ -520,7 +531,11 @@ ProtocolError DTLSMessageChannel::send(Message& message)

int ret = mbedtls_ssl_write(&ssl_context, message.buf(), message.length());
if (ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
LOG(WARN, "mbedtls_ssl_write returned -0x%x", -ret);
LOG(ERROR, "mbedtls_ssl_write() failed: -0x%x", -ret);
if (ret == MBEDTLS_ERR_NET_SEND_FAILED) {
// Do not invalidate the session on network errors
return IO_ERROR_SOCKET_SEND_FAILED;
}
reset_session();
return IO_ERROR_GENERIC_MBEDTLS_SSL_WRITE;
}
Expand Down
4 changes: 4 additions & 0 deletions communication/src/protocol_defs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ system_error_t toSystemError(ProtocolError error) {
case IO_ERROR_LIGHTSSL_RECEIVE:
case IO_ERROR_LIGHTSSL_HANDSHAKE_NONCE:
case IO_ERROR_LIGHTSSL_HANDSHAKE_RECV_KEY:
case IO_ERROR_SOCKET_SEND_FAILED:
case IO_ERROR_SOCKET_RECV_FAILED:
return SYSTEM_ERROR_IO;
case INVALID_STATE:
return SYSTEM_ERROR_INVALID_STATE;
Expand All @@ -67,6 +69,8 @@ system_error_t toSystemError(ProtocolError error) {
return SYSTEM_ERROR_INTERNAL;
case OTA_UPDATE_ERROR:
return SYSTEM_ERROR_OTA;
case IO_ERROR_REMOTE_END_CLOSED:
return SYSTEM_ERROR_END_OF_STREAM;
default:
return SYSTEM_ERROR_PROTOCOL; // Generic protocol error
}
Expand Down
7 changes: 7 additions & 0 deletions system/inc/active_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,13 @@ class ISRTaskQueue {
struct Task {
TaskFunc func;
Task* next; // Next element in the queue

explicit Task(TaskFunc func = nullptr) :
func(func),
next(nullptr) {
}

virtual ~Task() = default;
};

ISRTaskQueue() :
Expand Down
44 changes: 44 additions & 0 deletions system/inc/system_threading.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
#define SYSTEM_THREADING_H

#include "active_object.h"
#include "system_error.h"

#include <utility>
#include <new>

#if PLATFORM_THREADING

Expand Down Expand Up @@ -130,8 +134,48 @@ os_mutex_recursive_t mutex_usb_serial();

namespace particle {

namespace detail {

struct CallableTaskBase: ISRTaskQueue::Task {
virtual void call() = 0;
};

template<typename F>
struct CallableTask: CallableTaskBase {
F fn;

explicit CallableTask(F&& fn) : fn(std::move(fn)) {
}

void call() override {
fn();
}
};

} // namespace detail

extern ISRTaskQueue SystemISRTaskQueue;

/**
* Asynchronously invokes a function in the context of the system thread via the ISR task queue.
*
* @note This function allocates memory and thus it cannot be called from an ISR.
*/
template<typename F>
int invokeAsync(F&& fn) {
// Not using std::function here as it's not exception-safe
auto task = new(std::nothrow) detail::CallableTask<F>(std::move(fn));
if (!task) {
return SYSTEM_ERROR_NO_MEMORY;
}
task->func = [](ISRTaskQueue::Task* task) {
static_cast<detail::CallableTaskBase*>(task)->call();
delete task;
};
SystemISRTaskQueue.enqueue(task);
return 0;
}

} // namespace particle

#endif /* SYSTEM_THREADING_H */
1 change: 0 additions & 1 deletion system/src/system_listening_mode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ int ListeningModeHandler::enqueueCommand(network_listen_command_t com, void* arg
return SYSTEM_ERROR_NO_MEMORY;
}

memset(task, 0, sizeof(Task));
task->command = com;
task->arg = arg;
task->func = reinterpret_cast<ISRTaskQueue::TaskFunc>(&executeEnqueuedCommand);
Expand Down
20 changes: 15 additions & 5 deletions system/src/system_network_compat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,37 @@ inline NetworkInterface& nif(network_interface_t _nif) { return cellular; }

void HAL_WLAN_notify_simple_config_done()
{
network.notify_listening_complete();
invokeAsync([]() {
network.notify_listening_complete();
});
}

void HAL_NET_notify_connected()
{
network.notify_connected();
invokeAsync([]() {
network.notify_connected();
});
}

void HAL_NET_notify_disconnected()
{
network.notify_disconnected();
invokeAsync([]() {
network.notify_disconnected();
});
}

void HAL_NET_notify_error()
{
network.notify_error();
invokeAsync([]() {
network.notify_error();
});
}

void HAL_NET_notify_dhcp(bool dhcp)
{
network.notify_dhcp(dhcp);
invokeAsync([dhcp]() {
network.notify_dhcp(dhcp);
});
}

const void* network_config(network_handle_t network, uint32_t param, void* reserved)
Expand Down

0 comments on commit c4c5236

Please sign in to comment.