diff --git a/include/lwmqtt.h b/include/lwmqtt.h index bad8263..b28b485 100644 --- a/include/lwmqtt.h +++ b/include/lwmqtt.h @@ -79,10 +79,10 @@ typedef struct { typedef struct lwmqtt_client_t lwmqtt_client_t; /** - * The callback used to read from a network object. It may set read to zero if no data is available. + * The callback used to read from a network object. * - * The callback is expected to read the exact amount of bytes requested. It should wait up to the specified - * timeout to read the requested data from the network. + * The callbacks is expected to read up to the amount of bytes in to the passed buffer. It should block the specified + * timeout and wait for more incoming data. It may set read to zero if no data is has been read. */ typedef lwmqtt_err_t (*lwmqtt_network_read_t)(lwmqtt_client_t *c, void *ref, unsigned char *buf, int len, int *read, unsigned int timeout); @@ -90,8 +90,8 @@ typedef lwmqtt_err_t (*lwmqtt_network_read_t)(lwmqtt_client_t *c, void *ref, uns /** * The callback used to write to a network object. * - * The callback is expected to write the exact amount of bytes requested. If should wait up to the specified - * timeout to read write the specified data to the network. + * The callback is expected to write up to the amount of bytes from the passed buffer. It should wait up to the + * specified timeout to write the specified data to the network. */ typedef lwmqtt_err_t (*lwmqtt_network_write_t)(lwmqtt_client_t *c, void *ref, unsigned char *buf, int len, int *sent, unsigned int timeout); diff --git a/src/client.c b/src/client.c index f178639..5db66a8 100644 --- a/src/client.c +++ b/src/client.c @@ -52,15 +52,75 @@ static unsigned short lwmqtt_get_next_packet_id(lwmqtt_client_t *c) { return c->next_packet_id = (unsigned short)((c->next_packet_id == 65535) ? 1 : c->next_packet_id + 1); } +static lwmqtt_err_t lwmqtt_read_from_network(lwmqtt_client_t *c, int offset, int len) { + // check read buffer capacity + if (c->read_buf_size < offset + len) { + return LWMQTT_BUFFER_TOO_SHORT; + } + + // prepare counter + int read = 0; + + // read while data is missing + while (read < len) { + // get remaining time + unsigned int remaining_time = c->timer_get(c, c->command_timer); + + // check timeout + if (remaining_time <= 0) { + return LWMQTT_NOT_ENOUGH_DATA; + } + + // read + int partial_read = 0; + lwmqtt_err_t err = + c->network_read(c, c->network, c->read_buf + offset + read, len - read, &partial_read, remaining_time); + if (err != LWMQTT_SUCCESS) { + return err; + } + + // increment counter + read += partial_read; + } + + return LWMQTT_SUCCESS; +} + +static lwmqtt_err_t lwmqtt_write_to_network(lwmqtt_client_t *c, int offset, int len) { + // prepare counter + int written = 0; + + // write while data is left + while (written < len) { + // get remaining time + unsigned int remaining_time = c->timer_get(c, c->command_timer); + + // check timeout + if (remaining_time <= 0) { + return LWMQTT_NOT_ENOUGH_DATA; + } + + // read + int partial_write = 0; + lwmqtt_err_t err = + c->network_write(c, c->network, c->write_buf + offset + written, len - written, &partial_write, remaining_time); + if (err != LWMQTT_SUCCESS) { + return err; + } + + // increment counter + written += partial_write; + } + + return LWMQTT_SUCCESS; +} + static lwmqtt_err_t lwmqtt_read_packet_in_buffer(lwmqtt_client_t *c, int *read, lwmqtt_packet_type_t *packet_type) { // read header byte - int partial_read = 0; - lwmqtt_err_t err = c->network_read(c, c->network, c->read_buf, 1, &partial_read, c->timer_get(c, c->command_timer)); + lwmqtt_err_t err = lwmqtt_read_from_network(c, 0, 1); if (err != LWMQTT_SUCCESS) { - return err; - } else if (partial_read == 0) { *packet_type = LWMQTT_NO_PACKET; - return LWMQTT_SUCCESS; + return err; } // detect packet type @@ -78,12 +138,9 @@ static lwmqtt_err_t lwmqtt_read_packet_in_buffer(lwmqtt_client_t *c, int *read, len++; // read next byte - partial_read = 0; - err = c->network_read(c, c->network, c->read_buf + len, 1, &partial_read, c->timer_get(c, c->command_timer)); + err = lwmqtt_read_from_network(c, len, 1); if (err != LWMQTT_SUCCESS) { return err; - } else if (partial_read != 1) { - return LWMQTT_NOT_ENOUGH_DATA; } // attempt to detect remaining length @@ -97,18 +154,9 @@ static lwmqtt_err_t lwmqtt_read_packet_in_buffer(lwmqtt_client_t *c, int *read, // read the rest of the buffer if needed if (rem_len > 0) { - // check read buffer capacity - if (c->read_buf_size < 1 + len + rem_len) { - return LWMQTT_BUFFER_TOO_SHORT; - } - - partial_read = 0; - err = c->network_read(c, c->network, c->read_buf + 1 + len, rem_len, &partial_read, - c->timer_get(c, c->command_timer)); + err = lwmqtt_read_from_network(c, 1 + len, rem_len); if (err != LWMQTT_SUCCESS) { return err; - } else if (partial_read != rem_len) { - return LWMQTT_NOT_ENOUGH_DATA; } } @@ -120,17 +168,11 @@ static lwmqtt_err_t lwmqtt_read_packet_in_buffer(lwmqtt_client_t *c, int *read, static lwmqtt_err_t lwmqtt_send_packet_in_buffer(lwmqtt_client_t *c, int length) { // write to network - int sent = 0; - lwmqtt_err_t err = c->network_write(c, c->network, c->write_buf, length, &sent, c->timer_get(c, c->command_timer)); + lwmqtt_err_t err = lwmqtt_write_to_network(c, 0, length); if (err != LWMQTT_SUCCESS) { return err; } - // check length - if (sent != length) { - return LWMQTT_NOT_ENOUGH_DATA; - } - // reset keep alive timer c->timer_set(c, c->keep_alive_timer, c->keep_alive_interval * 1000); diff --git a/src/os/unix.c b/src/os/unix.c index 6b50fc6..3b8386f 100644 --- a/src/os/unix.c +++ b/src/os/unix.c @@ -134,25 +134,15 @@ lwmqtt_err_t lwmqtt_unix_network_read(lwmqtt_client_t *client, void *ref, unsign return LWMQTT_NETWORK_READ_ERROR; } - // loop until buffer is full - while (*read < len) { - // read from socket - int bytes = (int)recv(n->socket, &buffer[*read], (size_t)(len - *read), 0); - if (bytes < 0) { - // finish current loop on timeout - if (errno == EAGAIN) { - break; - } - - return LWMQTT_NETWORK_READ_ERROR; - } else if (bytes == 0) { - // finish if no more data - break; - } else - // increment counter - *read += bytes; + // read from socket + int bytes = (int)recv(n->socket, &buffer[*read], (size_t)(len - *read), 0); + if (bytes < 0 && errno != EAGAIN) { + return LWMQTT_NETWORK_READ_ERROR; } + // increment counter + *read += bytes; + return LWMQTT_SUCCESS; } @@ -168,17 +158,14 @@ lwmqtt_err_t lwmqtt_unix_network_write(lwmqtt_client_t *client, void *ref, unsig return LWMQTT_NETWORK_WRITE_ERR; } - // loop until all bytes haven been writen - while (*sent < len) { - // write to socket - int bytes = (int)send(n->socket, buffer, (size_t)len, 0); - if (bytes < 0) { - return LWMQTT_NETWORK_WRITE_ERR; - } else { - // increment counter - *sent += bytes; - } + // write to socket + int bytes = (int)send(n->socket, buffer, (size_t)len, 0); + if (bytes < 0 && errno != EAGAIN) { + return LWMQTT_NETWORK_WRITE_ERR; } + // increment counter + *sent += bytes; + return LWMQTT_SUCCESS; }