From 09ded2b5e5bd34bbcf0fd71b5482381cf7f08627 Mon Sep 17 00:00:00 2001 From: Steve Kim <86316075+sbSteveK@users.noreply.github.com> Date: Thu, 23 Feb 2023 14:23:05 -0800 Subject: [PATCH] Secure tunnel with Multiplexing (#78) * refactor and update of Secure Tunnel and its API * Support for V2 WebSocket Protocol (Multiplexing) --- include/aws/iotdevice/iotdevice.h | 17 +- .../iotdevice/private/iotdevice_internals.h | 43 - .../iotdevice/private/secure_tunneling_impl.h | 253 +- .../private/secure_tunneling_operations.h | 202 ++ include/aws/iotdevice/private/serializer.h | 66 +- include/aws/iotdevice/secure_tunneling.h | 274 +- source/iotdevice.c | 60 +- source/secure_tunneling.c | 2208 +++++++++++++---- source/secure_tunneling_operations.c | 722 ++++++ source/serializer.c | 453 +++- tests/CMakeLists.txt | 30 +- tests/aws_iot_secure_tunneling_client_test.c | 199 -- tests/secure_tunnel_tests.c | 1123 +++++++++ tests/secure_tunneling_tests.c | 717 ------ tests/tests_protobuf/aws_iot_st_pb_test.cpp | 4 +- 15 files changed, 4676 insertions(+), 1695 deletions(-) delete mode 100644 include/aws/iotdevice/private/iotdevice_internals.h create mode 100644 include/aws/iotdevice/private/secure_tunneling_operations.h create mode 100644 source/secure_tunneling_operations.c delete mode 100644 tests/aws_iot_secure_tunneling_client_test.c create mode 100644 tests/secure_tunnel_tests.c delete mode 100644 tests/secure_tunneling_tests.c diff --git a/include/aws/iotdevice/iotdevice.h b/include/aws/iotdevice/iotdevice.h index ffe1374b..1f18507c 100644 --- a/include/aws/iotdevice/iotdevice.h +++ b/include/aws/iotdevice/iotdevice.h @@ -21,8 +21,21 @@ enum aws_iotdevice_error { AWS_ERROR_IOTDEVICE_DEFENDER_PUBLISH_FAILURE, AWS_ERROR_IOTDEVICE_DEFENDER_UNKNOWN_TASK_STATUS, - AWS_ERROR_IOTDEVICE_SECUTRE_TUNNELING_INVALID_STREAM, - AWS_ERROR_IOTDEVICE_SECUTRE_TUNNELING_INCORRECT_MODE, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INCORRECT_MODE, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_BAD_SERVICE_ID, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DATA_OPTIONS_VALIDATION, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_STREAM_OPTIONS_VALIDATION, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_SECURE_TUNNEL_TERMINATED, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_WEBSOCKET_TIMEOUT, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_PING_RESPONSE_TIMEOUT, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_OPERATION_FAILED_DUE_TO_DISCONNECTION, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_OPERATION_PROCESSING_FAILURE, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_OPERATION_FAILED_DUE_TO_OFFLINE_QUEUE_POLICY, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_UNEXPECTED_HANGUP, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_USER_REQUESTED_STOP, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_TERMINATED, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DECODE_FAILURE, AWS_ERROR_END_IOTDEVICE_RANGE = AWS_ERROR_ENUM_END_RANGE(AWS_C_IOTDEVICE_PACKAGE_ID), }; diff --git a/include/aws/iotdevice/private/iotdevice_internals.h b/include/aws/iotdevice/private/iotdevice_internals.h deleted file mode 100644 index 0239108e..00000000 --- a/include/aws/iotdevice/private/iotdevice_internals.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ -#ifndef AWS_IOTDEVICE_IOTDEVICE_INTERNALS_H -#define AWS_IOTDEVICE_IOTDEVICE_INTERNALS_H - -#include - -struct aws_byte_buf; -struct aws_byte_cursor; -struct aws_secure_tunnel; -struct aws_websocket; -struct aws_websocket_client_connection_options; -struct aws_websocket_send_frame_options; - -AWS_EXTERN_C_BEGIN - -AWS_IOTDEVICE_API -int secure_tunneling_init_send_frame( - struct aws_websocket_send_frame_options *frame_options, - struct aws_secure_tunnel *secure_tunnel, - const struct aws_byte_cursor *data, - enum aws_iot_st_message_type type); - -AWS_IOTDEVICE_API -void init_websocket_client_connection_options( - struct aws_secure_tunnel *secure_tunnel, - struct aws_websocket_client_connection_options *websocket_options); - -AWS_IOTDEVICE_API -int secure_tunneling_init_send_frame( - struct aws_websocket_send_frame_options *frame_options, - struct aws_secure_tunnel *secure_tunnel, - const struct aws_byte_cursor *data, - enum aws_iot_st_message_type type); - -AWS_IOTDEVICE_API -bool secure_tunneling_send_data_call(struct aws_websocket *websocket, struct aws_byte_buf *out_buf, void *user_data); - -AWS_EXTERN_C_END - -#endif /* AWS_IOTDEVICE_IOTDEVICE_INTERNALS_H */ diff --git a/include/aws/iotdevice/private/secure_tunneling_impl.h b/include/aws/iotdevice/private/secure_tunneling_impl.h index 50fbf507..4cbfd8fb 100644 --- a/include/aws/iotdevice/private/secure_tunneling_impl.h +++ b/include/aws/iotdevice/private/secure_tunneling_impl.h @@ -8,60 +8,265 @@ #include #include +#include #include +#include +#include +#include +#include #include +/** + * The various states that the secure tunnel can be in. A secure tunnel has both a current state and a desired state. + * Desired state is only allowed to be one of {STOPPED, CONNECTED, TERMINATED}. The secure tunnel transitions states + * based on either + * (1) changes in desired state, or + * (2) external events. + * + * Most states are interruptible (in the sense of a change in desired state causing an immediate change in state) but + * CONNECTING cannot be interrupted due to waiting for an asynchronous callback (that has no + * cancel) to complete. + */ +enum aws_secure_tunnel_state { + /* + * The secure tunnel is not connected and not waiting for anything to happen. + * + * Next States: + * CONNECTING - if the user invokes Connect() on the secure tunnel + * TERMINATED - if the user releases the last ref count on the secure tunnel + */ + AWS_STS_STOPPED, + + /* + * The secure tunnel is attempting to connect to a remote endpoint and establish a WebSocket upgrade. This state is + * not interruptible by any means other than WebSocket setup completion. + * + * Next States: + * CONNECTED - if WebSocket handshake is successful and desired state is still CONNECTED + * WEBSOCKET_SHUTDOWN - if the WebSocket completes setup with no error but desired state is not CONNECTED + * PENDING_RECONNECT - if the WebSocket fails to complete setup and desired state is still CONNECTED + * STOPPED - if the WebSocket fails to complete setup and desired state is not CONNECTED + */ + AWS_STS_CONNECTING, + + /* + * The secure tunnel is ready to perform user-requested operations. + * + * Next States: + * WEBSOCKET_SHUTDOWN - desired state is no longer CONNECTED + * PENDING_RECONNECT - unexpected WebSocket shutdown completion and desired state still CONNECTED + * STOPPED - unexpected WebSocket shutdown completion and desired state no longer CONNECTED + */ + AWS_STS_CONNECTED, + + /* + * The secure tunnel is attempting to shut down a WebSocket connection cleanly by finishing the current operation + * and then transmitting a STREAM RESET message to all open streams. + * + * Next States: + * WEBSOCKET_SHUTDOWN - on sucessful (or unsuccessful) disconnection + * PENDING_RECONNECT - unexpected WebSocket shutdown completion and desired state still CONNECTED + * STOPPED - unexpected WebSocket shutdown completion and desired state no longer CONNECTED + */ + AWS_STS_CLEAN_DISCONNECT, + + /* + * The secure tunnel is waiting for the WebSocket to completely shut down. This state is not interruptible. + * + * Next States: + * PENDING_RECONNECT - the WebSocket has shut down and desired state is still CONNECTED + * STOPPED - the WebSocket has shut down and desired state is not CONNECTED + */ + AWS_STS_WEBSOCKET_SHUTDOWN, + + /* + * The secure tunnel is waiting for the reconnect timer to expire before attempting to connect again. + * + * Next States: + * CONNECTING - the reconnect timer has expired and desired state is still CONNECTED + * STOPPED - desired state is no longer CONNECTED + */ + AWS_STS_PENDING_RECONNECT, + + /* + * The secure tunnel is performing final shutdown and release of all resources. This state is only realized for + * a non-observable instant of time (transition out of STOPPED). + */ + AWS_STS_TERMINATED +}; + struct data_tunnel_pair { + struct aws_allocator *allocator; struct aws_byte_buf buf; struct aws_byte_cursor cur; const struct aws_secure_tunnel *secure_tunnel; bool length_prefix_written; }; -struct aws_secure_tunnel_vtable { - int (*connect)(struct aws_secure_tunnel *secure_tunnel); - int (*send_data)(struct aws_secure_tunnel *secure_tunnel, const struct aws_byte_cursor *data); - int (*send_stream_start)(struct aws_secure_tunnel *secure_tunnel); - int (*send_stream_reset)(struct aws_secure_tunnel *secure_tunnel); - int (*close)(struct aws_secure_tunnel *secure_tunnel); +/* + * Secure tunnel configuration + */ +struct aws_secure_tunnel_options_storage { + struct aws_allocator *allocator; + + /* backup */ + + struct aws_client_bootstrap *bootstrap; + struct aws_socket_options socket_options; + struct aws_http_proxy_options http_proxy_options; + struct aws_http_proxy_config *http_proxy_config; + struct aws_string *access_token; + struct aws_string *client_token; + + struct aws_string *endpoint_host; + + /* Stream related info */ + int32_t stream_id; + + struct aws_hash_table service_ids; + + /* Callbacks */ + aws_secure_tunnel_message_received_fn *on_message_received; + aws_secure_tunneling_on_connection_complete_fn *on_connection_complete; + aws_secure_tunneling_on_connection_shutdown_fn *on_connection_shutdown; + aws_secure_tunneling_on_stream_start_fn *on_stream_start; + aws_secure_tunneling_on_stream_reset_fn *on_stream_reset; + aws_secure_tunneling_on_session_reset_fn *on_session_reset; + aws_secure_tunneling_on_stopped_fn *on_stopped; + + aws_secure_tunneling_on_send_data_complete_fn *on_send_data_complete; + aws_secure_tunneling_on_termination_complete_fn *on_termination_complete; + void *secure_tunnel_on_termination_user_data; + + void *user_data; + enum aws_secure_tunneling_local_proxy_mode local_proxy_mode; }; -struct aws_websocket_client_connection_options; -struct aws_websocket_send_frame_options; +struct aws_secure_tunnel_vtable { + /* aws_high_res_clock_get_ticks */ + uint64_t (*get_current_time_fn)(void); + + /* For test verification */ + int (*aws_websocket_client_connect_fn)(const struct aws_websocket_client_connection_options *options); + + /* For test verification */ + int (*aws_websocket_send_frame_fn)( + struct aws_websocket *websocket, + const struct aws_websocket_send_frame_options *options); -struct aws_websocket_vtable { - int (*client_connect)(const struct aws_websocket_client_connection_options *options); - int (*send_frame)(struct aws_websocket *websocket, const struct aws_websocket_send_frame_options *options); - void (*close)(struct aws_websocket *websocket, bool free_scarce_resources_immediately); - void (*release)(struct aws_websocket *websocket); + /* For test verification */ + void (*aws_websocket_release_fn)(struct aws_websocket *websocket); + + /* For test verification */ + void (*aws_websocket_close_fn)(struct aws_websocket *websocket, bool free_scarce_resources_immediately); + + void *vtable_user_data; }; struct aws_secure_tunnel { /* Static settings */ - struct aws_allocator *alloc; - struct aws_secure_tunnel_options_storage *options_storage; - struct aws_secure_tunnel_options *options; + struct aws_allocator *allocator; + struct aws_ref_count ref_count; + + const struct aws_secure_tunnel_vtable *vtable; + + /* + * Secure tunnel configuration + */ + struct aws_secure_tunnel_options_storage *config; + struct aws_tls_ctx *tls_ctx; struct aws_tls_connection_options tls_con_opt; - struct aws_secure_tunnel_vtable vtable; - struct aws_websocket_vtable websocket_vtable; - struct aws_ref_count ref_count; + /* + * The recurrent task that runs all secure tunnel logic outside of external event callbacks. Bound to the secure + * tunnel's event loop. + */ + struct aws_task service_task; + + /* + * Tracks when the secure tunnel's service task is next schedule to run. Is zero if the task is not scheduled to + * run or we are in the middle of a service (so technically not scheduled too). + */ + uint64_t next_service_task_run_time; + + /* + * True if the secure tunnel's service task is running. Used to skip service task reevaluation due to state changes + * while running the service task. Reevaluation will occur at the very end of the service. + */ + bool in_service; + + /* + * Event loop all the secure tunnel's tasks will be pinned to, ensuring serialization and + * concurrency safety. + */ + struct aws_event_loop *loop; + + /* + * What state is the secure tunnel working towards? + */ + enum aws_secure_tunnel_state desired_state; + + /* + * What is the secure tunnel's current state? + */ + enum aws_secure_tunnel_state current_state; - /* Used only during initial websocket setup. Otherwise, should be NULL */ + /* + * handshake_request exists between the transform completion timepoint and the websocket setup callback. + */ struct aws_http_message *handshake_request; /* Dynamic data */ - int32_t stream_id; + struct aws_websocket *websocket; /* Stores what has been received but not processed */ struct aws_byte_buf received_data; - /* The secure tunneling endpoint ELB drops idle connect after 1 minute. We need to send a ping periodically to keep - * the connection */ + /* + * When should the secure tunnel next attempt to reconnect? Only used by PENDING_RECONNECT state. + */ + uint64_t next_reconnect_time_ns; + + /* + * How many consecutive reconnect failures have we experienced? + */ + uint64_t reconnect_count; - struct ping_task_context *ping_task_context; + struct aws_linked_list queued_operations; + struct aws_secure_tunnel_operation *current_operation; + + /* + * Is there a WebSocket message in transit (to the socket) that has not invoked its write completion callback yet? + * The secure tunnel implementation only allows one in-transit message at a time, and so if this is true, we don't + * send additional ones/ + */ + bool pending_write_completion; + + /* + * When should the next PINGREQ be sent? + * The secure tunneling endpoint ELB drops idle connect after 1 minute. we need to send a ping periodically to keep + * the connection alive. + */ + uint64_t next_ping_time; }; +AWS_EXTERN_C_BEGIN + +/* + * Override the vtable used by the secure tunnel; useful for mocking certain scenarios. + */ +AWS_IOTDEVICE_API void aws_secure_tunnel_set_vtable( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_vtable *vtable); + +/* + * Gets the default vtable used by the secure tunnel. In order to mock something, we start with the default and then + * mutate it selectively to achieve the scenario we're interested in. + */ +AWS_IOTDEVICE_API const struct aws_secure_tunnel_vtable *aws_secure_tunnel_get_default_vtable(void); + +AWS_EXTERN_C_END + #endif /* AWS_IOTDEVICE_SECURE_TUNNELING_IMPL_H */ diff --git a/include/aws/iotdevice/private/secure_tunneling_operations.h b/include/aws/iotdevice/private/secure_tunneling_operations.h new file mode 100644 index 00000000..6196351d --- /dev/null +++ b/include/aws/iotdevice/private/secure_tunneling_operations.h @@ -0,0 +1,202 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#ifndef AWS_IOTDEVICE_SECURE_TUNNELING_OPERATION_H +#define AWS_IOTDEVICE_SECURE_TUNNELING_OPERATION_H + +#include +#include +#include +#include +#include + +/********************************************************************************************************************* + * Operations + ********************************************************************************************************************/ + +struct aws_secure_tunnel_operation; + +enum aws_secure_tunnel_operation_type { + AWS_STOT_NONE, + AWS_STOT_PING, + AWS_STOT_MESSAGE, + AWS_STOT_STREAM_RESET, + AWS_STOT_STREAM_START +}; + +struct aws_service_id_element { + struct aws_allocator *allocator; + struct aws_byte_cursor service_id_cur; + struct aws_string *service_id_string; + int32_t stream_id; +}; + +struct aws_secure_tunnel_message_storage { + struct aws_allocator *allocator; + struct aws_secure_tunnel_message_view storage_view; + + bool ignorable; + int32_t stream_id; + struct aws_byte_cursor service_id; + struct aws_byte_cursor payload; + + struct aws_byte_buf storage; +}; + +/* Basic vtable for all secure tunnel operations. Implementations are per-message type */ +struct aws_secure_tunnel_operation_vtable { + void (*aws_secure_tunnel_operation_completion_fn)( + struct aws_secure_tunnel_operation *operation, + int error_code, + const void *completion_view); + + /* Set the stream id of outgoing st_msg based on current service id */ + int (*aws_secure_tunnel_operation_assign_stream_id_fn)( + struct aws_secure_tunnel_operation *operation, + struct aws_secure_tunnel *secure_tunnel); + + /* Set the stream id of outgoing st_msg to +1 of the currently set stream id */ + int (*aws_secure_tunnel_operation_set_next_stream_id_fn)( + struct aws_secure_tunnel_operation *operation, + struct aws_secure_tunnel *secure_tunnel); +}; + +/** + * This is the base structure for all secure tunnel operations. It includes the type, a ref count, and list management. + */ +struct aws_secure_tunnel_operation { + const struct aws_secure_tunnel_operation_vtable *vtable; + struct aws_ref_count ref_count; + struct aws_linked_list_node node; + + enum aws_secure_tunnel_operation_type operation_type; + const struct aws_secure_tunnel_message_view *message_view; + + /* Size of the secure tunnel message this operation represents */ + size_t message_size; + + void *impl; +}; + +struct aws_secure_tunnel_operation_message { + struct aws_secure_tunnel_operation base; + struct aws_allocator *allocator; + + struct aws_secure_tunnel_message_storage options_storage; +}; + +struct aws_secure_tunnel_operation_pingreq { + struct aws_secure_tunnel_operation base; + struct aws_allocator *allocator; +}; + +AWS_EXTERN_C_BEGIN + +/* Operation Base */ + +AWS_IOTDEVICE_API struct aws_secure_tunnel_operation *aws_secure_tunnel_operation_acquire( + struct aws_secure_tunnel_operation *operation); + +AWS_IOTDEVICE_API struct aws_secure_tunnel_operation *aws_secure_tunnel_operation_release( + struct aws_secure_tunnel_operation *operation); + +AWS_IOTDEVICE_API void aws_secure_tunnel_operation_complete( + struct aws_secure_tunnel_operation *operation, + int error_code, + const void *associated_view); + +AWS_IOTDEVICE_API void aws_secure_tunnel_operation_assign_stream_id( + struct aws_secure_tunnel_operation *operation, + struct aws_secure_tunnel *secure_tunnel); + +AWS_IOTDEVICE_API int32_t + aws_secure_tunnel_operation_get_stream_id(const struct aws_secure_tunnel_operation *operation); + +AWS_IOTDEVICE_API int32_t *aws_secure_tunnel_operation_get_stream_id_address( + const struct aws_secure_tunnel_operation *operation); + +/* Message */ + +AWS_IOTDEVICE_API +int aws_secure_tunnel_message_view_validate(const struct aws_secure_tunnel_message_view *message_view); + +AWS_IOTDEVICE_API +void aws_secure_tunnel_message_view_log( + const struct aws_secure_tunnel_message_view *message_view, + enum aws_log_level level); + +AWS_IOTDEVICE_API +int aws_secure_tunnel_message_storage_init( + struct aws_secure_tunnel_message_storage *message_storage, + struct aws_allocator *allocator, + const struct aws_secure_tunnel_message_view *message_options, + enum aws_secure_tunnel_operation_type type); + +AWS_IOTDEVICE_API +void aws_secure_tunnel_message_storage_clean_up(struct aws_secure_tunnel_message_storage *message_storage); + +AWS_IOTDEVICE_API +struct aws_secure_tunnel_operation_message *aws_secure_tunnel_operation_message_new( + struct aws_allocator *allocator, + const struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_options, + enum aws_secure_tunnel_operation_type type); + +/* Ping */ + +AWS_IOTDEVICE_API +struct aws_secure_tunnel_operation_pingreq *aws_secure_tunnel_operation_pingreq_new(struct aws_allocator *allocator); + +/* Secure Tunnel Storage Options */ + +/** + * Raises exception and returns AWS_OP_ERR if options are missing required parameters. + */ +AWS_IOTDEVICE_API +int aws_secure_tunnel_options_validate(const struct aws_secure_tunnel_options *options); + +/** + * Destroy options storage, and release any references held. + */ +AWS_IOTDEVICE_API +void aws_secure_tunnel_options_storage_destroy(struct aws_secure_tunnel_options_storage *storage); + +/** + * Create persistent storage for aws_secure_tunnel_options. + * Makes a deep copy of (or acquires reference to) any data referenced by options, + */ +AWS_IOTDEVICE_API +struct aws_secure_tunnel_options_storage *aws_secure_tunnel_options_storage_new( + struct aws_allocator *allocator, + const struct aws_secure_tunnel_options *options); + +AWS_IOTDEVICE_API +void aws_secure_tunnel_options_storage_log( + const struct aws_secure_tunnel_options_storage *options_storage, + enum aws_log_level level); + +AWS_IOTDEVICE_API +const char *aws_secure_tunnel_operation_type_to_c_string(enum aws_secure_tunnel_operation_type operation_type); + +/* Data Tunnel Pair */ + +AWS_IOTDEVICE_API +void aws_secure_tunnel_data_tunnel_pair_destroy(struct data_tunnel_pair *pair); + +AWS_IOTDEVICE_API +struct data_tunnel_pair *aws_secure_tunnel_data_tunnel_pair_new( + struct aws_allocator *allocator, + const struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_view); + +AWS_IOTDEVICE_API +struct aws_service_id_element *aws_service_id_element_new( + struct aws_allocator *allocator, + const struct aws_byte_cursor *service_id, + int32_t stream_id); + +AWS_EXTERN_C_END + +#endif /* AWS_IOTDEVICE_SECURE_TUNNELING_OPERATION_H */ diff --git a/include/aws/iotdevice/private/serializer.h b/include/aws/iotdevice/private/serializer.h index 94fe0681..306f800c 100644 --- a/include/aws/iotdevice/private/serializer.h +++ b/include/aws/iotdevice/private/serializer.h @@ -6,54 +6,58 @@ #define AWS_IOTDEVICE_SERIALIZER_H #include +#include +#include #include -#define AWS_IOT_ST_MESSAGE_TYPEFIELD 1 -#define AWS_IOT_ST_MESSAGE_STREAM_ID 2 -#define AWS_IOT_ST_MESSAGE_IGNORABLE 3 -#define AWS_IOT_ST_MESSAGE_PAYLOAD 4 -#define AWS_IOT_ST_VARINT_WIRE 0 -#define AWS_IOT_ST_VARINT_LENGTHDELIM_WIRE 2 - #define AWS_IOT_ST_FIELD_NUMBER_SHIFT 3 -#define AWS_IOT_ST_MESSAGE_DEFAULT_STREAM_ID 0 -#define AWS_IOT_ST_MESSAGE_DEFAULT_IGNORABLE 0 -#define AWS_IOT_ST_MESSAGE_DEFAULT_TYPE 0 -#define AWS_IOT_ST_MESSAGE_DEFAULT_PAYLOAD 0 - -#define AWS_IOT_ST_STREAM_ID_FIELD_NUMBER 2 -#define AWS_IOT_ST_IGNORABLE_FIELD_NUMBER 3 -#define AWS_IOT_ST_TYPE_FIELD_NUMBER 1 -#define AWS_IOT_ST_PAYLOAD_FIELD_NUMBER 4 +#define AWS_IOT_ST_MAXIMUM_VARINT 268435455 +#define AWS_IOT_ST_MAXIMUM_1_BYTE_VARINT_VALUE 128 +#define AWS_IOT_ST_MAXIMUM_2_BYTE_VARINT_VALUE 16384 +#define AWS_IOT_ST_MAXIMUM_3_BYTE_VARINT_VALUE 2097152 +#define AWS_IOT_ST_MAX_MESSAGE_SIZE (64 * 1024) -#define AWS_IOT_ST_DEFAULT_ALLO 60 -#define AWS_IOT_ST_MAX_MESSAGE_SIZE 64 * 1024 -#define AWS_IOT_ST_BLOCK_SIZE 1 - -enum aws_iot_st_message_type { UNKNOWN, DATA, STREAM_START, STREAM_RESET, SESSION_RESET }; +enum aws_secure_tunnel_field_number { + AWS_SECURE_TUNNEL_FN_TYPE = 1, + AWS_SECURE_TUNNEL_FN_STREAM_ID = 2, + AWS_SECURE_TUNNEL_FN_IGNORABLE = 3, + AWS_SECURE_TUNNEL_FN_PAYLOAD = 4, + AWS_SECURE_TUNNEL_FN_SERVICE_ID = 5, + AWS_SECURE_TUNNEL_FN_AVAILABLE_SERVICE_IDS = 6, + AWS_SECURE_TUNNEL_FN_CONNECTION_ID = 7, +}; -struct aws_iot_st_msg { - enum aws_iot_st_message_type type; - int32_t stream_id; - int ignorable; - struct aws_byte_buf payload; +enum aws_secure_tunnel_protocol_buffer_wire_type { + AWS_SECURE_TUNNEL_PBWT_VARINT = 0, /* int32, int64, uint32, uint64, sint32, sint64, bool, enum */ + AWS_SECURE_TUNNEL_PBWT_64_BIT = 1, /* fixed64, sfixed64, double */ + AWS_SECURE_TUNNEL_PBWT_LENGTH_DELIMITED = 2, /* string, bytes, embedded messages, packed repeated fields */ + AWS_SECURE_TUNNEL_PBWT_START_GROUP = 3, /* groups (deprecated) */ + AWS_SECURE_TUNNEL_PBWT_END_GROUP = 4, /* groups (deprecated) */ + AWS_SECURE_TUNNEL_PBWT_32_BIT = 5, /* fixed32, sfixed32, float */ }; +typedef void(aws_secure_tunnel_on_message_received_fn)( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_message_view *message_view); + AWS_EXTERN_C_BEGIN AWS_IOTDEVICE_API -int aws_iot_st_msg_serialize_from_struct( +int aws_iot_st_msg_serialize_from_view( struct aws_byte_buf *buffer, struct aws_allocator *allocator, - struct aws_iot_st_msg message); + const struct aws_secure_tunnel_message_view *message_view); AWS_IOTDEVICE_API -int aws_iot_st_msg_deserialize_from_cursor( - struct aws_iot_st_msg *message, +int aws_secure_tunnel_deserialize_message_from_cursor( + struct aws_secure_tunnel *secure_tunnel, struct aws_byte_cursor *cursor, - struct aws_allocator *allocator); + aws_secure_tunnel_on_message_received_fn *on_message_received); + +AWS_IOTDEVICE_API +const char *aws_secure_tunnel_message_type_to_c_string(enum aws_secure_tunnel_message_type message_type); AWS_EXTERN_C_END diff --git a/include/aws/iotdevice/secure_tunneling.h b/include/aws/iotdevice/secure_tunneling.h index cc3dcf7a..d6e8f932 100644 --- a/include/aws/iotdevice/secure_tunneling.h +++ b/include/aws/iotdevice/secure_tunneling.h @@ -12,114 +12,278 @@ #define AWS_IOT_ST_SPLIT_MESSAGE_SIZE 15000 -enum aws_secure_tunneling_local_proxy_mode { AWS_SECURE_TUNNELING_SOURCE_MODE, AWS_SECURE_TUNNELING_DESTINATION_MODE }; - struct aws_secure_tunnel; -struct aws_websocket; -struct aws_websocket_incoming_frame; struct aws_http_proxy_options; +enum aws_secure_tunneling_local_proxy_mode { + AWS_SECURE_TUNNELING_SOURCE_MODE, + AWS_SECURE_TUNNELING_DESTINATION_MODE, +}; + +/** + * Type of IoT Secure Tunnel message. + * Enum values match IoT Secure Tunneling Local Proxy V3 Websocket Protocol Guide values. + * + * https://github.com/aws-samples/aws-iot-securetunneling-localproxy/blob/main/V3WebSocketProtocolGuide.md + */ +enum aws_secure_tunnel_message_type { + AWS_SECURE_TUNNEL_MT_UNKNOWN = 0, + + /** + * Data messages carry a payload with a sequence of bytes to write to the the active data stream + */ + AWS_SECURE_TUNNEL_MT_DATA = 1, + + /** + * StreamStart is the first message sent to start and establish a new and active data stream. This should only be + * sent from a Source to a Destination. + */ + AWS_SECURE_TUNNEL_MT_STREAM_START = 2, + + /** + * StreamReset messages convey that the data stream has ended, either in error, or closed intentionally for the + * tunnel peer. It is also sent to the source tunnel peer if an attempt to establish a new data stream fails on the + * destination side. + */ + AWS_SECURE_TUNNEL_MT_STREAM_RESET = 3, + + /** + * SessionReset messages can only originate from Secure Tunneling service if an internal data transmission error is + * detected. This will result in all active streams being closed. + */ + AWS_SECURE_TUNNEL_MT_SESSION_RESET = 4, + + /** + * ServiceIDs messages can only originate from the Secure Tunneling service and carry a list of unique service IDs + * used when opening a tunnel with services. + */ + AWS_SECURE_TUNNEL_MT_SERVICE_IDS = 5, + + /** + * ConnectionStart is the message sent to start and establish a new and active connection when the stream has been + * established and there's one active connection in the stream. + */ + AWS_SECURE_TUNNEL_MT_CONNECTION_START = 6, + + /** + * ConnectionReset messages convey that the connection has ended, either in error, or closed intentionally for the + * tunnel peer. + */ + AWS_SECURE_TUNNEL_MT_CONNECTION_RESET = 7 +}; + +/** + * Read-only snapshot of a Secure Tunnel Message + */ +struct aws_secure_tunnel_message_view { + + enum aws_secure_tunnel_message_type type; + + /** + * If a message is received and its type is unrecognized, and this field is set to true, it is ok for the tunnel + * client to ignore the message safely. If this field is unset, it must be considered as false. + */ + bool ignorable; + + int32_t stream_id; + + /** + * Secure tunnel multiplexing identifier + */ + struct aws_byte_cursor *service_id; + struct aws_byte_cursor *service_id_2; + struct aws_byte_cursor *service_id_3; + + struct aws_byte_cursor *payload; +}; + +/** + * Read-only snapshot of a Secure Tunnel Connection Completion Data + */ +struct aws_secure_tunnel_connection_view { + struct aws_byte_cursor *service_id_1; + struct aws_byte_cursor *service_id_2; + struct aws_byte_cursor *service_id_3; +}; + /* Callbacks */ -typedef void(aws_secure_tunneling_on_connection_complete_fn)(void *user_data); -typedef void(aws_secure_tunneling_on_connection_shutdown_fn)(void *user_data); + +/** + * Signature of callback to invoke on received messages + */ +typedef void( + aws_secure_tunnel_message_received_fn)(const struct aws_secure_tunnel_message_view *message, void *user_data); + +typedef void(aws_secure_tunneling_on_connection_complete_fn)( + const struct aws_secure_tunnel_connection_view *connection_view, + int error_code, + void *user_data); +typedef void(aws_secure_tunneling_on_connection_shutdown_fn)(int error_code, void *user_data); typedef void(aws_secure_tunneling_on_send_data_complete_fn)(int error_code, void *user_data); -typedef void(aws_secure_tunneling_on_data_receive_fn)(const struct aws_byte_buf *data, void *user_data); -typedef void(aws_secure_tunneling_on_stream_start_fn)(void *user_data); -typedef void(aws_secure_tunneling_on_stream_reset_fn)(void *user_data); +typedef void(aws_secure_tunneling_on_stream_start_fn)( + const struct aws_secure_tunnel_message_view *message, + int error_code, + void *user_data); +typedef void(aws_secure_tunneling_on_stream_reset_fn)( + const struct aws_secure_tunnel_message_view *message, + int error_code, + void *user_data); typedef void(aws_secure_tunneling_on_session_reset_fn)(void *user_data); +typedef void(aws_secure_tunneling_on_stopped_fn)(void *user_data); typedef void(aws_secure_tunneling_on_termination_complete_fn)(void *user_data); +/** + * Basic Secure Tunnel configuration struct. + * + * Contains connection properties for the creation of a Secure Tunnel + */ struct aws_secure_tunnel_options { - struct aws_allocator *allocator; + /** + * Host to establish Secure Tunnel connection to + */ + struct aws_byte_cursor endpoint_host; + + /** + * Secure Tunnel bootstrap to use whenever Secure Tunnel establishes a connection + */ struct aws_client_bootstrap *bootstrap; + + /** + * Socket options to use whenever this Secure Tunnel establishes a connection + */ const struct aws_socket_options *socket_options; + + /** + * (Optional) Http proxy options to use whenever this Secure Tunnel establishes a connection + */ const struct aws_http_proxy_options *http_proxy_options; + /** + * Access Token used to establish a Secure Tunnel connection + */ struct aws_byte_cursor access_token; - enum aws_secure_tunneling_local_proxy_mode local_proxy_mode; - struct aws_byte_cursor endpoint_host; + + /** + * (Optional) Client Token used to re-establish a Secure Tunnel connection after the one-time use access token has + * been used. If one is not provided, it will automatically be generated and re-used on subsequent reconnects. + */ + struct aws_byte_cursor client_token; + const char *root_ca; + aws_secure_tunnel_message_received_fn *on_message_received; + + void *user_data; + + enum aws_secure_tunneling_local_proxy_mode local_proxy_mode; + aws_secure_tunneling_on_connection_complete_fn *on_connection_complete; aws_secure_tunneling_on_connection_shutdown_fn *on_connection_shutdown; aws_secure_tunneling_on_send_data_complete_fn *on_send_data_complete; - aws_secure_tunneling_on_data_receive_fn *on_data_receive; aws_secure_tunneling_on_stream_start_fn *on_stream_start; aws_secure_tunneling_on_stream_reset_fn *on_stream_reset; aws_secure_tunneling_on_session_reset_fn *on_session_reset; - aws_secure_tunneling_on_termination_complete_fn *on_termination_complete; + aws_secure_tunneling_on_stopped_fn *on_stopped; - void *user_data; + /** + * Callback for when the secure tunnel has completely destroyed itself. + */ + aws_secure_tunneling_on_termination_complete_fn *on_termination_complete; + void *secure_tunnel_on_termination_user_data; }; -/* deprecated: "_config" is renamed "_options" for consistency with similar code in the aws-c libraries */ -#define aws_secure_tunneling_connection_config aws_secure_tunnel_options +/** + * Signature of callback to invoke when secure tunnel enters a fully disconnected state + */ +typedef void(aws_secure_tunnel_disconnect_completion_fn)(int error_code, void *complete_ctx); /** - * Persistent storage for aws_secure_tunnel_options. + * Public completion callback options for the DISCONNECT operation */ -struct aws_secure_tunnel_options_storage; +struct aws_secure_tunnel_disconnect_completion_options { + aws_secure_tunnel_disconnect_completion_fn *completion_callback; + void *completion_user_data; +}; AWS_EXTERN_C_BEGIN +/** + * Creates a new secure tunnel + * + * @param options secure tunnel configuration + * @return a new secure tunnel or NULL + */ AWS_IOTDEVICE_API -struct aws_secure_tunnel *aws_secure_tunnel_new(const struct aws_secure_tunnel_options *options); +struct aws_secure_tunnel *aws_secure_tunnel_new( + struct aws_allocator *allocator, + const struct aws_secure_tunnel_options *options); +/** + * Acquires a reference to a secure tunnel + * + * @param secure_tunnel secure tunnel to acquire a reference to. May be NULL + * @return what was passed in as the secure tunnel (a client or NULL) + */ AWS_IOTDEVICE_API struct aws_secure_tunnel *aws_secure_tunnel_acquire(struct aws_secure_tunnel *secure_tunnel); -AWS_IOTDEVICE_API -void aws_secure_tunnel_release(struct aws_secure_tunnel *secure_tunnel); - -AWS_IOTDEVICE_API -int aws_secure_tunnel_connect(struct aws_secure_tunnel *secure_tunnel); - -AWS_IOTDEVICE_API -int aws_secure_tunnel_close(struct aws_secure_tunnel *secure_tunnel); - -AWS_IOTDEVICE_API -int aws_secure_tunnel_send_data(struct aws_secure_tunnel *secure_tunnel, const struct aws_byte_cursor *data); - -AWS_IOTDEVICE_API -int aws_secure_tunnel_stream_start(struct aws_secure_tunnel *secure_tunnel); - -AWS_IOTDEVICE_API -int aws_secure_tunnel_stream_reset(struct aws_secure_tunnel *secure_tunnel); - /** - * Raises exception and returns AWS_OP_ERR if options are missing required parameters. + * Release a reference to a secure tunnel. When the secure tunnel ref count drops to zero, the secure tunnel + * will automatically trigger a stop and once the stop completes, the secure tunnel will delete itself. + * + * @param secure_tunnel secure tunnel to release a reference to. May be NULL + * @return NULL */ AWS_IOTDEVICE_API -int aws_secure_tunnel_options_validate(const struct aws_secure_tunnel_options *options); +struct aws_secure_tunnel *aws_secure_tunnel_release(struct aws_secure_tunnel *secure_tunnel); /** - * Create persistent storage for aws_secure_tunnel_options. - * Makes a deep copy of (or acquires reference to) any data referenced by options, + * Asynchronous notify to the secure tunnel that you want it to attempt to connect. + * The secure tunnel will attempt to stay connected. + * + * @param secure_tunnel secure tunnel to start + * @return success/failure in the synchronous logic that kicks off the start process */ AWS_IOTDEVICE_API -struct aws_secure_tunnel_options_storage *aws_secure_tunnel_options_storage_new( - const struct aws_secure_tunnel_options *options); +int aws_secure_tunnel_start(struct aws_secure_tunnel *secure_tunnel); /** - * Destroy options storage, and release any references held. + * Asynchronous notify to the secure tunnel that you want it to transition to the stopped state. When the + * secure tunnel reaches the stopped state, all session state is erased. + * + * @param secure_tunnel secure tunnel to stop + * @return success/failure in the synchronous logic that kicks off the start process */ AWS_IOTDEVICE_API -void aws_secure_tunnel_options_storage_destroy(struct aws_secure_tunnel_options_storage *storage); +int aws_secure_tunnel_stop(struct aws_secure_tunnel *secure_tunnel); /** - * Return pointer to options struct stored within. + * Queues a message operation in a secure tunnel + * + * @param secure_tunnel secure tunnel to queue a message for + * @param message_options configuration options for the message operation + * @return success/failure in the synchronous logic that kicks off the message operation */ AWS_IOTDEVICE_API -const struct aws_secure_tunnel_options *aws_secure_tunnel_options_storage_get( - const struct aws_secure_tunnel_options_storage *storage); +int aws_secure_tunnel_send_message( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_options); -/* Making this exposed public to verify testing in the sdk layer */ +//*********************************************************************************************************************** +/* THIS API SHOULD ONLY BE USED FROM SOURCE MODE */ +//*********************************************************************************************************************** AWS_IOTDEVICE_API -bool on_websocket_incoming_frame_payload( - struct aws_websocket *websocket, - const struct aws_websocket_incoming_frame *frame, - struct aws_byte_cursor data, - void *user_data); +int aws_secure_tunnel_stream_start( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_options); + +//*********************************************************************************************************************** +/* THIS API SHOULD NOT BE USED BY THE CUSTOMER AND SHOULD BE DEPRECATED */ +//*********************************************************************************************************************** +AWS_IOTDEVICE_API +int aws_secure_tunnel_stream_reset( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_options); AWS_EXTERN_C_END diff --git a/source/iotdevice.c b/source/iotdevice.c index 1d3610c6..28e5475d 100644 --- a/source/iotdevice.c +++ b/source/iotdevice.c @@ -19,25 +19,71 @@ static struct aws_error_info s_errors[] = { "Bits marked as reserved were incorrectly set"), AWS_DEFINE_ERROR_INFO_IOTDEVICE( AWS_ERROR_IOTDEVICE_DEFENDER_INVALID_REPORT_INTERVAL, - "Invalid defender task reporting interval. Must be greater than 5 minutes"), + "Invalid defender task reporting interval. Must be greater than 5 minutes."), AWS_DEFINE_ERROR_INFO_IOTDEVICE( AWS_ERROR_IOTDEVICE_DEFENDER_UNSUPPORTED_REPORT_FORMAT, - "Unknown format value selected for defender reporting task"), + "Unknown format value selected for defender reporting task."), AWS_DEFINE_ERROR_INFO_IOTDEVICE( AWS_ERROR_IOTDEVICE_DEFENDER_REPORT_SERIALIZATION_FAILURE, - "Error serializing report for publishing"), + "Error serializing report for publishing."), AWS_DEFINE_ERROR_INFO_IOTDEVICE( AWS_ERROR_IOTDEVICE_DEFENDER_UNKNOWN_CUSTOM_METRIC_TYPE, - "Unknown custom metric type found in reporting task"), + "Unknown custom metric type found in reporting task."), AWS_DEFINE_ERROR_INFO_IOTDEVICE( AWS_ERROR_IOTDEVICE_DEFENDER_INVALID_TASK_CONFIG, - "Invalid configuration detected in defender reporting task config. Check prior errors"), + "Invalid configuration detected in defender reporting task config. Check prior errors."), AWS_DEFINE_ERROR_INFO_IOTDEVICE( AWS_ERROR_IOTDEVICE_DEFENDER_PUBLISH_FAILURE, - "Mqtt client error while attempting to publish defender report"), + "Mqtt client error while attempting to publish defender report."), AWS_DEFINE_ERROR_INFO_IOTDEVICE( AWS_ERROR_IOTDEVICE_DEFENDER_UNKNOWN_TASK_STATUS, - "Device defender task was invoked with an unknown task status"), + "Device defender task was invoked with an unknown task status."), + + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM, + "Secure Tunnel invalid stream id."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INCORRECT_MODE, + "Secure Tunnel stream cannot be started while in Destination Mode."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_BAD_SERVICE_ID, + "Secure Tunnel stream start request with bad service id."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DATA_OPTIONS_VALIDATION, + "Invalid Secure Tunnel data message options value."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_STREAM_OPTIONS_VALIDATION, + "Invalid Secure Tunnel stream options value."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_SECURE_TUNNEL_TERMINATED, + "Secure Tunnel terminated by user request."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_WEBSOCKET_TIMEOUT, + "Remote endpoint did not respond to connect request before timeout exceeded."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_PING_RESPONSE_TIMEOUT, + "Remote endpoint did not respond to a PINGREQ before timeout exceeded."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_OPERATION_FAILED_DUE_TO_DISCONNECTION, + "Secure Tunnel operation failed due to disconnected state."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_OPERATION_PROCESSING_FAILURE, + "Error while processing secure tunnel operational state."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_OPERATION_FAILED_DUE_TO_OFFLINE_QUEUE_POLICY, + "Error while processing secure tunnel operational state."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_UNEXPECTED_HANGUP, + "The connection was closed unexpectedly."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_USER_REQUESTED_STOP, + "Secure Tunnel connection interrupted by user request."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_TERMINATED, + "Secure Tunnel terminated by user request."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DECODE_FAILURE, + "Error occured while decoding an incoming message." ), }; /* clang-format on */ #undef AWS_DEFINE_ERROR_INFO_IOTDEVICE diff --git a/source/secure_tunneling.c b/source/secure_tunneling.c index b063c38d..e0a2eb11 100644 --- a/source/secure_tunneling.c +++ b/source/secure_tunneling.c @@ -4,7 +4,9 @@ */ #include +#include +#include #include #include #include @@ -12,296 +14,328 @@ #include #include #include -#include #include +#include #include +#ifdef _MSC_VER +# pragma warning(push) +# pragma warning(disable : 4232) /* function pointer to dll symbol */ +#endif + #define MAX_WEBSOCKET_PAYLOAD 131076 #define INVALID_STREAM_ID 0 #define PAYLOAD_BYTE_LENGTH_PREFIX 2 +#define MIN_RECONNECT_DELAY_MS 1000 +#define MAX_RECONNECT_DELAY_MS 120000 #define PING_TASK_INTERVAL ((uint64_t)20 * 1000000000) +#define WEBSOCKET_HEADER_NAME_ACCESS_TOKEN "access-token" +#define WEBSOCKET_HEADER_NAME_CLIENT_TOKEN "client-token" +#define WEBSOCKET_HEADER_NAME_PROTOCOL "Sec-WebSocket-Protocol" +#define WEBSOCKET_HEADER_PROTOCOL_VALUE "aws.iot.securetunneling-2.0" + +static void s_change_current_state(struct aws_secure_tunnel *secure_tunnel, enum aws_secure_tunnel_state next_state); +void aws_secure_tunnel_operational_state_clean_up(struct aws_secure_tunnel *secure_tunnel); +static int s_aws_secure_tunnel_change_desired_state( + struct aws_secure_tunnel *secure_tunnel, + enum aws_secure_tunnel_state desired_state); +static void s_complete_operation_list( + struct aws_secure_tunnel *secure_tunnel, + struct aws_linked_list *operation_list, + int error_code); -#define UNUSED(x) (void)(x) - -struct aws_secure_tunnel_options_storage { - struct aws_secure_tunnel_options options; - - /* backup */ - struct aws_socket_options socket_options; - struct aws_http_proxy_options http_proxy_options; - struct aws_http_proxy_config *http_proxy_config; - struct aws_byte_buf cursor_storage; - struct aws_string *root_ca; -}; +static int s_secure_tunneling_send( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_view); -int aws_secure_tunnel_options_validate(const struct aws_secure_tunnel_options *options) { - AWS_ASSERT(options && options->allocator); - if (options->bootstrap == NULL) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "bootstrap cannot be NULL"); - return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); - } - if (options->socket_options == NULL) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "socket options cannot be NULL"); - return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); - } - if (options->access_token.len == 0) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "access token is required"); - return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); - } - if (options->endpoint_host.len == 0) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "endpoint host is required"); - return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); - } +static void s_reevaluate_service_task(struct aws_secure_tunnel *secure_tunnel); - return AWS_OP_SUCCESS; -} +const char *aws_secure_tunnel_state_to_c_string(enum aws_secure_tunnel_state state) { + switch (state) { + case AWS_STS_STOPPED: + return "STOPPED"; -void aws_secure_tunnel_options_storage_destroy(struct aws_secure_tunnel_options_storage *storage) { - if (storage == NULL) { - return; - } + case AWS_STS_CONNECTING: + return "CONNECTING"; - aws_client_bootstrap_release(storage->options.bootstrap); - aws_http_proxy_config_destroy(storage->http_proxy_config); - aws_byte_buf_clean_up(&storage->cursor_storage); - aws_string_destroy(storage->root_ca); - aws_mem_release(storage->options.allocator, storage); -} + case AWS_STS_CONNECTED: + return "CONNECTED"; -struct aws_secure_tunnel_options_storage *aws_secure_tunnel_options_storage_new( - const struct aws_secure_tunnel_options *src) { + case AWS_STS_CLEAN_DISCONNECT: + return "CLEAN_DISCONNECT"; - if (aws_secure_tunnel_options_validate(src)) { - return NULL; - } + case AWS_STS_WEBSOCKET_SHUTDOWN: + return "WEBSOCKET_SHUTDOWN"; - struct aws_allocator *alloc = src->allocator; + case AWS_STS_PENDING_RECONNECT: + return "PENDING_RECONNECT"; - struct aws_secure_tunnel_options_storage *storage = - aws_mem_calloc(alloc, 1, sizeof(struct aws_secure_tunnel_options_storage)); + case AWS_STS_TERMINATED: + return "TERMINATED"; - /* shallow-copy everything that's shallow-copy-able */ - storage->options = *src; + default: + return "UNKNOWN"; + } +} - /* acquire reference to everything that's ref-counted */ - aws_client_bootstrap_acquire(storage->options.bootstrap); +static const char *s_get_proxy_mode_string(enum aws_secure_tunneling_local_proxy_mode local_proxy_mode) { + if (local_proxy_mode == AWS_SECURE_TUNNELING_SOURCE_MODE) { + return "source"; + } + return "destination"; +} - /* deep-copy anything that needs deep-copying */ - storage->socket_options = *src->socket_options; - storage->options.socket_options = &storage->socket_options; +static int s_reset_service_id(void *context, struct aws_hash_element *p_element) { + (void)context; + struct aws_service_id_element *service_id_elem = p_element->value; + service_id_elem->stream_id = INVALID_STREAM_ID; + return AWS_COMMON_HASH_TABLE_ITER_CONTINUE; +} - /* deep-copy the http-proxy-options to http_proxy_config */ - if (src->http_proxy_options != NULL) { - storage->http_proxy_config = - aws_http_proxy_config_new_tunneling_from_proxy_options(alloc, src->http_proxy_options); - if (storage->http_proxy_config == NULL) { - goto error; - } +/********************************************************************************************************************* + * Secure Tunnel Clean Up + ********************************************************************************************************************/ - /* Make a copy of http_proxy_options and point to it */ - aws_http_proxy_options_init_from_config(&storage->http_proxy_options, storage->http_proxy_config); - storage->options.http_proxy_options = &storage->http_proxy_options; +static void s_secure_tunnel_final_destroy(struct aws_secure_tunnel *secure_tunnel) { + if (secure_tunnel == NULL) { + AWS_LOGF_TRACE( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, "id=%p: secure_tunnel is NULL on final destroy", (void *)secure_tunnel); + return; } + AWS_LOGF_TRACE(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "id=%p: secure_tunnel final destroy", (void *)secure_tunnel); - /* Store contents of all cursors within single buffer (and update cursors to point into it) */ - aws_byte_buf_init_cache_and_update_cursors( - &storage->cursor_storage, alloc, &storage->options.access_token, &storage->options.endpoint_host, NULL); - - if (src->root_ca != NULL) { - storage->root_ca = aws_string_new_from_c_str(alloc, src->root_ca); - storage->options.root_ca = aws_string_c_str(storage->root_ca); + aws_secure_tunneling_on_termination_complete_fn *on_termination_complete = NULL; + void *termination_complete_user_data = NULL; + if (secure_tunnel->config != NULL) { + on_termination_complete = secure_tunnel->config->on_termination_complete; + termination_complete_user_data = secure_tunnel->config->secure_tunnel_on_termination_user_data; } - return storage; + aws_secure_tunnel_operational_state_clean_up(secure_tunnel); -error: - aws_secure_tunnel_options_storage_destroy(storage); - return NULL; -} - -typedef int( - websocket_send_frame)(struct aws_websocket *websocket, const struct aws_websocket_send_frame_options *options); + /* Clean up all memory */ + aws_secure_tunnel_options_storage_destroy(secure_tunnel->config); + aws_http_message_release(secure_tunnel->handshake_request); + aws_byte_buf_clean_up(&secure_tunnel->received_data); + aws_tls_connection_options_clean_up(&secure_tunnel->tls_con_opt); + aws_tls_ctx_release(secure_tunnel->tls_ctx); + aws_mem_release(secure_tunnel->allocator, secure_tunnel); -static void s_send_websocket_ping(struct aws_websocket *websocket, websocket_send_frame *send_frame) { - if (!websocket) { - return; + if (on_termination_complete != NULL) { + (*on_termination_complete)(termination_complete_user_data); } - - struct aws_websocket_send_frame_options frame_options; - AWS_ZERO_STRUCT(frame_options); - frame_options.opcode = AWS_WEBSOCKET_OPCODE_PING; - frame_options.fin = true; - send_frame(websocket, &frame_options); } -struct ping_task_context { - struct aws_allocator *allocator; - struct aws_event_loop *event_loop; - - struct aws_task ping_task; - struct aws_atomic_var task_cancelled; - struct aws_websocket *websocket; +static void s_on_secure_tunnel_zero_ref_count(void *user_data) { + struct aws_secure_tunnel *secure_tunnel = user_data; + s_aws_secure_tunnel_change_desired_state(secure_tunnel, AWS_STS_TERMINATED); +} - /* The ping_task shares the vtable function used by the secure tunnel to send frames over the websocket. */ - websocket_send_frame *send_frame; -}; +/***************************************************************************************************************** + * RECEIVE MESSAGE HANDLING + *****************************************************************************************************************/ -static void s_ping_task(struct aws_task *task, void *user_data, enum aws_task_status task_status) { - AWS_LOGF_TRACE(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "s_ping_task"); +/* + * Close and reset all stream ids + */ +static void s_reset_secure_tunnel(struct aws_secure_tunnel *secure_tunnel) { + AWS_LOGF_INFO(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "id=%p: Secure tunnel session reset.", (void *)secure_tunnel); - struct ping_task_context *ping_task_context = user_data; + secure_tunnel->config->stream_id = INVALID_STREAM_ID; + aws_hash_table_foreach(&secure_tunnel->config->service_ids, s_reset_service_id, NULL); + secure_tunnel->received_data.len = 0; /* Drop any incomplete secure tunnel frame */ +} - if (task_status == AWS_TASK_STATUS_CANCELED) { - AWS_LOGF_INFO( - AWS_LS_IOTDEVICE_SECURE_TUNNELING, "task_status is AWS_TASK_STATUS_CANCELED. Cleaning up ping task."); - aws_mem_release(ping_task_context->allocator, ping_task_context); - return; +static bool s_aws_secure_tunnel_stream_id_check_match( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_byte_cursor *service_id, + int32_t stream_id) { + /* No service id means V1 protocol is being used */ + if (service_id->len == 0) { + return (secure_tunnel->config->stream_id == stream_id); } - const size_t task_cancelled = aws_atomic_load_int(&ping_task_context->task_cancelled); - if (task_cancelled) { - AWS_LOGF_INFO(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "task_cancelled is true. Cleaning up ping task."); - aws_mem_release(ping_task_context->allocator, ping_task_context); - return; + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->config->service_ids, service_id, &elem); + if (elem == NULL) { + AWS_LOGF_WARN( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure tunnel stream id check request for unsupported service_id: " PRInSTR, + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*service_id)); + return false; } - s_send_websocket_ping(ping_task_context->websocket, ping_task_context->send_frame); - - /* Schedule the next task */ - uint64_t now; - aws_event_loop_current_clock_time(ping_task_context->event_loop, &now); - aws_event_loop_schedule_task_future(ping_task_context->event_loop, task, now + PING_TASK_INTERVAL); + struct aws_service_id_element *service_id_elem = elem->value; + return (stream_id == service_id_elem->stream_id); } -static void s_on_websocket_setup(const struct aws_websocket_on_connection_setup_data *setup, void *user_data) { - - /* TODO: Handle error - * https://github.com/aws-samples/aws-iot-securetunneling-localproxy/blob/master/WebsocketProtocolGuide.md#handshake-error-responses - */ +static int s_aws_secure_tunnel_set_stream_id( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_byte_cursor *service_id, + int32_t stream_id) { + /* No service id means V1 protocol is being used */ + if (service_id == NULL || service_id->len == 0) { + secure_tunnel->config->stream_id = stream_id; + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure tunnel stream_id set to %d", + (void *)secure_tunnel, + stream_id); + return AWS_OP_SUCCESS; + } - struct aws_secure_tunnel *secure_tunnel = user_data; - aws_http_message_release(secure_tunnel->handshake_request); - secure_tunnel->handshake_request = NULL; + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->config->service_ids, service_id, &elem); + if (elem == NULL) { + AWS_LOGF_WARN( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure tunnel request for unsupported service_id: " PRInSTR, + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*service_id)); + return AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_BAD_SERVICE_ID; + } - secure_tunnel->websocket = setup->websocket; - secure_tunnel->options->on_connection_complete(secure_tunnel->options->user_data); + struct aws_service_id_element *replacement_elem = + aws_service_id_element_new(secure_tunnel->allocator, service_id, stream_id); - struct ping_task_context *ping_task_context = - aws_mem_acquire(secure_tunnel->alloc, sizeof(struct ping_task_context)); - secure_tunnel->ping_task_context = ping_task_context; - AWS_ZERO_STRUCT(*ping_task_context); - ping_task_context->allocator = secure_tunnel->alloc; - ping_task_context->event_loop = - aws_event_loop_group_get_next_loop(secure_tunnel->options->bootstrap->event_loop_group); - aws_atomic_store_int(&ping_task_context->task_cancelled, 0); - ping_task_context->websocket = setup->websocket; - ping_task_context->send_frame = secure_tunnel->websocket_vtable.send_frame; + aws_hash_table_put(&secure_tunnel->config->service_ids, &replacement_elem->service_id_cur, replacement_elem, NULL); + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure tunnel service_id '" PRInSTR "' stream_id set to %d", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*service_id), + stream_id); - aws_task_init(&ping_task_context->ping_task, s_ping_task, ping_task_context, "SecureTunnelingPingTask"); - aws_event_loop_schedule_task_now(ping_task_context->event_loop, &ping_task_context->ping_task); + return AWS_OP_SUCCESS; } -static void s_on_websocket_shutdown(struct aws_websocket *websocket, int error_code, void *user_data) { - UNUSED(websocket); - UNUSED(error_code); - - struct aws_secure_tunnel *secure_tunnel = user_data; - aws_atomic_store_int(&secure_tunnel->ping_task_context->task_cancelled, 1); - secure_tunnel->ping_task_context->websocket = NULL; - secure_tunnel->options->on_connection_shutdown(secure_tunnel->options->user_data); +static void s_aws_secure_tunnel_on_stream_start_received( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_message_view *message_view) { + int result = s_aws_secure_tunnel_set_stream_id(secure_tunnel, message_view->service_id, message_view->stream_id); + if (secure_tunnel->config->on_stream_start) { + secure_tunnel->config->on_stream_start(message_view, result, secure_tunnel->config->user_data); + } } -static bool s_on_websocket_incoming_frame_begin( - struct aws_websocket *websocket, - const struct aws_websocket_incoming_frame *frame, - void *user_data) { - UNUSED(websocket); - UNUSED(frame); - UNUSED(user_data); - return true; +static void s_aws_secure_tunnel_on_stream_reset_received( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_message_view *message_view) { + int result = AWS_OP_SUCCESS; + if (s_aws_secure_tunnel_stream_id_check_match(secure_tunnel, message_view->service_id, message_view->stream_id)) { + result = s_aws_secure_tunnel_set_stream_id(secure_tunnel, message_view->service_id, INVALID_STREAM_ID); + } + if (secure_tunnel->config->on_stream_reset) { + secure_tunnel->config->on_stream_reset(message_view, result, secure_tunnel->config->user_data); + } } -static void s_handle_stream_start(struct aws_secure_tunnel *secure_tunnel, struct aws_iot_st_msg *st_msg) { - if (secure_tunnel->options->local_proxy_mode == AWS_SECURE_TUNNELING_SOURCE_MODE) { - /* Source mode tunnel clients SHOULD treat receiving StreamStart as an error and close the active data stream - * and WebSocket connection. */ - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Received StreamStart in source mode. Closing the tunnel."); - secure_tunnel->vtable.close(secure_tunnel); - } else { - AWS_LOGF_INFO( - AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "Received StreamStart in destination mode. stream_id=%d", - st_msg->stream_id); - secure_tunnel->stream_id = st_msg->stream_id; - secure_tunnel->options->on_stream_start(secure_tunnel->options->user_data); +static void s_aws_secure_tunnel_on_session_reset_received(struct aws_secure_tunnel *secure_tunnel) { + s_reset_secure_tunnel(secure_tunnel); + if (secure_tunnel->config->on_session_reset) { + secure_tunnel->config->on_session_reset(secure_tunnel->config->user_data); } } -static void s_reset_secure_tunnel(struct aws_secure_tunnel *secure_tunnel) { - secure_tunnel->stream_id = INVALID_STREAM_ID; - secure_tunnel->received_data.len = 0; /* Drop any incomplete secure tunnel frame */ -} +static void s_aws_secure_tunnel_on_service_ids_received( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_message_view *message_view) { -static void s_handle_stream_reset(struct aws_secure_tunnel *secure_tunnel, struct aws_iot_st_msg *st_msg) { - if (secure_tunnel->stream_id == INVALID_STREAM_ID || secure_tunnel->stream_id != st_msg->stream_id) { - AWS_LOGF_WARN( + aws_hash_table_clear(&secure_tunnel->config->service_ids); + + if (message_view->service_id != NULL) { + struct aws_service_id_element *service_id_1_elem = + aws_service_id_element_new(secure_tunnel->allocator, message_view->service_id, INVALID_STREAM_ID); + aws_hash_table_put( + &secure_tunnel->config->service_ids, &service_id_1_elem->service_id_cur, service_id_1_elem, NULL); + AWS_LOGF_INFO( AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "Received StreamReset with stream_id different than the active stream_id. Ignoring. st_msg->stream_id=%d " - "secure_tunnel->stream_id=%d", - st_msg->stream_id, - secure_tunnel->stream_id); - return; + "id=%p: secure tunnel service id 1 set to: " PRInSTR, + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*message_view->service_id)); + if (message_view->service_id_2 != NULL) { + struct aws_service_id_element *service_id_2_elem = + aws_service_id_element_new(secure_tunnel->allocator, message_view->service_id_2, INVALID_STREAM_ID); + aws_hash_table_put( + &secure_tunnel->config->service_ids, &service_id_2_elem->service_id_cur, service_id_2_elem, NULL); + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: secure tunnel service id 2 set to: " PRInSTR, + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*message_view->service_id_2)); + if (message_view->service_id_3 != NULL) { + struct aws_service_id_element *service_id_3_elem = + aws_service_id_element_new(secure_tunnel->allocator, message_view->service_id_3, INVALID_STREAM_ID); + aws_hash_table_put( + &secure_tunnel->config->service_ids, &service_id_3_elem->service_id_cur, service_id_3_elem, NULL); + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: secure tunnel service id 3 set to: " PRInSTR, + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*message_view->service_id_3)); + } + } } - secure_tunnel->options->on_stream_reset(secure_tunnel->options->user_data); - s_reset_secure_tunnel(secure_tunnel); -} + struct aws_secure_tunnel_connection_view connection_view; + AWS_ZERO_STRUCT(connection_view); + connection_view.service_id_1 = message_view->service_id; + connection_view.service_id_2 = message_view->service_id_2; + connection_view.service_id_3 = message_view->service_id_3; -static void s_handle_session_reset(struct aws_secure_tunnel *secure_tunnel) { - if (secure_tunnel->stream_id == INVALID_STREAM_ID) { /* Session reset does not need to check stream id */ - return; + /* A connection can only be used once available service ids are established with the secure tunnel. */ + if (secure_tunnel->config->on_connection_complete) { + secure_tunnel->config->on_connection_complete( + &connection_view, AWS_ERROR_SUCCESS, secure_tunnel->config->user_data); } - - secure_tunnel->options->on_session_reset(secure_tunnel->options->user_data); - s_reset_secure_tunnel(secure_tunnel); } -static void s_process_iot_st_msg(struct aws_secure_tunnel *secure_tunnel, struct aws_iot_st_msg *st_msg) { - /* TODO: Check stream_id, send reset? */ - - switch (st_msg->type) { - case DATA: - secure_tunnel->options->on_data_receive(&st_msg->payload, secure_tunnel->options->user_data); +static void s_aws_secure_tunnel_connected_on_message_received( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_message_view *message_view) { + aws_secure_tunnel_message_view_log(message_view, AWS_LL_DEBUG); + switch (message_view->type) { + case AWS_SECURE_TUNNEL_MT_DATA: + if (secure_tunnel->config->on_message_received) { + secure_tunnel->config->on_message_received(message_view, secure_tunnel->config->user_data); + } break; - case STREAM_START: - s_handle_stream_start(secure_tunnel, st_msg); + case AWS_SECURE_TUNNEL_MT_STREAM_START: + s_aws_secure_tunnel_on_stream_start_received(secure_tunnel, message_view); break; - case STREAM_RESET: - s_handle_stream_reset(secure_tunnel, st_msg); + case AWS_SECURE_TUNNEL_MT_STREAM_RESET: + s_aws_secure_tunnel_on_stream_reset_received(secure_tunnel, message_view); break; - case SESSION_RESET: - s_handle_session_reset(secure_tunnel); + case AWS_SECURE_TUNNEL_MT_SESSION_RESET: + s_aws_secure_tunnel_on_session_reset_received(secure_tunnel); break; - case UNKNOWN: + case AWS_SECURE_TUNNEL_MT_SERVICE_IDS: + s_aws_secure_tunnel_on_service_ids_received(secure_tunnel, message_view); + break; + case AWS_SECURE_TUNNEL_MT_CONNECTION_START: + case AWS_SECURE_TUNNEL_MT_CONNECTION_RESET: + case AWS_SECURE_TUNNEL_MT_UNKNOWN: default: - if (!st_msg->ignorable) { - AWS_LOGF_WARN( + if (!message_view->ignorable) { + AWS_LOGF_ERROR( AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "Encountered an unknown but un-ignorable message. type=%d", - st_msg->type); + "Encountered an unknown but un-ignorable message. type=%s", + aws_secure_tunnel_message_type_to_c_string(message_view->type)); } break; } } -static void s_process_received_data(struct aws_secure_tunnel *secure_tunnel) { +static int s_process_received_data(struct aws_secure_tunnel *secure_tunnel) { struct aws_byte_buf *received_data = &secure_tunnel->received_data; struct aws_byte_cursor cursor = aws_byte_cursor_from_buf(received_data); - uint16_t data_length = 0; - struct aws_byte_cursor tmp_cursor = - cursor; /* If there are at least two bytes for the data_length, but not enough */ - /* data for a complete secure tunnel frame, we don't want to move `cursor`. */ + /* + * If there are at least two bytes for the data_length, but not enough data for a complete secure tunnel frame, we + * don't want to move `cursor`. + */ + struct aws_byte_cursor tmp_cursor = cursor; while (aws_byte_cursor_read_be16(&tmp_cursor, &data_length) && tmp_cursor.len >= data_length) { cursor = tmp_cursor; @@ -309,162 +343,54 @@ static void s_process_received_data(struct aws_secure_tunnel *secure_tunnel) { aws_byte_cursor_advance(&cursor, data_length); tmp_cursor = cursor; - struct aws_iot_st_msg st_msg; - aws_iot_st_msg_deserialize_from_cursor(&st_msg, &st_frame, secure_tunnel->alloc); - s_process_iot_st_msg(secure_tunnel, &st_msg); - - if (st_msg.type == DATA) { - aws_byte_buf_clean_up(&st_msg.payload); + if (aws_secure_tunnel_deserialize_message_from_cursor( + secure_tunnel, &st_frame, &s_aws_secure_tunnel_connected_on_message_received)) { + int error_code = aws_last_error(); + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to deserialize message with error %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); + return error_code; } } if (cursor.ptr != received_data->buffer) { - /* TODO: Consider better data structure that doesn't require moving bytes */ - /* Move unprocessed data to the beginning */ received_data->len = 0; aws_byte_buf_append(received_data, &cursor); } -} - -bool on_websocket_incoming_frame_payload( - struct aws_websocket *websocket, - const struct aws_websocket_incoming_frame *frame, - struct aws_byte_cursor data, - void *user_data) { - - UNUSED(websocket); - UNUSED(frame); - - if (data.len > 0) { - struct aws_secure_tunnel *secure_tunnel = user_data; - aws_byte_buf_append(&secure_tunnel->received_data, &data); - s_process_received_data(secure_tunnel); - } - - return true; -} - -static bool s_on_websocket_incoming_frame_complete( - struct aws_websocket *websocket, - const struct aws_websocket_incoming_frame *frame, - int error_code, - void *user_data) { - UNUSED(websocket); - UNUSED(frame); - UNUSED(error_code); - UNUSED(user_data); - - /* TODO: Check error_code */ - - return true; -} - -static const char *s_get_proxy_mode_string(enum aws_secure_tunneling_local_proxy_mode local_proxy_mode) { - if (local_proxy_mode == AWS_SECURE_TUNNELING_SOURCE_MODE) { - return "source"; - } - - return "destination"; -} - -static struct aws_http_message *s_new_handshake_request(const struct aws_secure_tunnel *secure_tunnel) { - char path[50]; - snprintf( - path, - sizeof(path), - "/tunnel?local-proxy-mode=%s", - s_get_proxy_mode_string(secure_tunnel->options->local_proxy_mode)); - struct aws_http_message *handshake_request = aws_http_message_new_websocket_handshake_request( - secure_tunnel->alloc, aws_byte_cursor_from_c_str(path), secure_tunnel->options->endpoint_host); - - struct aws_http_header extra_headers[] = { - { - .name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("Sec-WebSocket-Protocol"), - .value = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("aws.iot.securetunneling-1.0"), - }, - { - .name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("access-token"), - .value = secure_tunnel->options->access_token, - }, - }; - for (size_t i = 0; i < AWS_ARRAY_SIZE(extra_headers); ++i) { - aws_http_message_add_header(handshake_request, extra_headers[i]); - } - - return handshake_request; -} - -void init_websocket_client_connection_options( - struct aws_secure_tunnel *secure_tunnel, - struct aws_websocket_client_connection_options *websocket_options) { - - AWS_ZERO_STRUCT(*websocket_options); - websocket_options->allocator = secure_tunnel->alloc; - websocket_options->bootstrap = secure_tunnel->options->bootstrap; - websocket_options->socket_options = secure_tunnel->options->socket_options; - websocket_options->tls_options = &secure_tunnel->tls_con_opt; - websocket_options->host = secure_tunnel->options->endpoint_host; - websocket_options->port = 443; - websocket_options->handshake_request = s_new_handshake_request(secure_tunnel); - websocket_options->initial_window_size = MAX_WEBSOCKET_PAYLOAD; /* TODO: followup */ - websocket_options->user_data = secure_tunnel; - websocket_options->proxy_options = secure_tunnel->options->http_proxy_options; - websocket_options->on_connection_setup = s_on_websocket_setup; - websocket_options->on_connection_shutdown = s_on_websocket_shutdown; - websocket_options->on_incoming_frame_begin = s_on_websocket_incoming_frame_begin; - websocket_options->on_incoming_frame_payload = on_websocket_incoming_frame_payload; - websocket_options->on_incoming_frame_complete = s_on_websocket_incoming_frame_complete; - websocket_options->manual_window_management = false; - - /* Save handshake_request to release it later */ - secure_tunnel->handshake_request = websocket_options->handshake_request; -} - -static int s_secure_tunneling_connect(struct aws_secure_tunnel *secure_tunnel) { - if (secure_tunnel == NULL || secure_tunnel->stream_id != INVALID_STREAM_ID) { - return AWS_OP_ERR; - } - - struct aws_websocket_client_connection_options websocket_options; - init_websocket_client_connection_options(secure_tunnel, &websocket_options); - if (secure_tunnel->websocket_vtable.client_connect(&websocket_options)) { - return AWS_OP_ERR; - } return AWS_OP_SUCCESS; } -static int s_secure_tunneling_close(struct aws_secure_tunnel *secure_tunnel) { - if (secure_tunnel == NULL) { - return AWS_OP_ERR; - } - - s_reset_secure_tunnel(secure_tunnel); - if (secure_tunnel->websocket != NULL) { - secure_tunnel->websocket_vtable.close(secure_tunnel->websocket, false); - secure_tunnel->websocket_vtable.release(secure_tunnel->websocket); - secure_tunnel->websocket = NULL; - } - return AWS_OP_SUCCESS; -} +/***************************************************************************************************************** + * SEND MESSAGE HANDLING + *****************************************************************************************************************/ -static void s_secure_tunneling_on_send_data_complete_callback( +static void s_secure_tunneling_websocket_on_send_data_complete_callback( struct aws_websocket *websocket, int error_code, void *user_data) { - UNUSED(websocket); + (void)websocket; struct data_tunnel_pair *pair = user_data; struct aws_secure_tunnel *secure_tunnel = (struct aws_secure_tunnel *)pair->secure_tunnel; - secure_tunnel->options->on_send_data_complete(error_code, pair->secure_tunnel->options->user_data); - aws_byte_buf_clean_up(&pair->buf); - aws_mem_release(secure_tunnel->alloc, pair); + if (secure_tunnel->config->on_send_data_complete) { + secure_tunnel->config->on_send_data_complete(error_code, pair->secure_tunnel->config->user_data); + } + aws_secure_tunnel_data_tunnel_pair_destroy(pair); + secure_tunnel->pending_write_completion = false; } -bool secure_tunneling_send_data_call(struct aws_websocket *websocket, struct aws_byte_buf *out_buf, void *user_data) { - UNUSED(websocket); +static bool secure_tunneling_websocket_stream_outgoing_payload( + struct aws_websocket *websocket, + struct aws_byte_buf *out_buf, + void *user_data) { + (void)websocket; struct data_tunnel_pair *pair = user_data; size_t space_available = out_buf->capacity - out_buf->len; + if ((pair->length_prefix_written == false) && (space_available >= PAYLOAD_BYTE_LENGTH_PREFIX)) { if (aws_byte_buf_write_be16(out_buf, (int16_t)pair->buf.len) == false) { AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Failure writing buffer length prefix to out_buf"); @@ -473,159 +399,1365 @@ bool secure_tunneling_send_data_call(struct aws_websocket *websocket, struct aws pair->length_prefix_written = true; space_available = out_buf->capacity - out_buf->len; } + if (pair->length_prefix_written == true) { - size_t bytes_max = pair->cur.len; - size_t amount_to_send = bytes_max < space_available ? bytes_max : space_available; - - struct aws_byte_cursor send_cursor = aws_byte_cursor_advance(&pair->cur, amount_to_send); - if (send_cursor.len) { - if (aws_byte_buf_write_from_whole_cursor(out_buf, send_cursor) == false) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Failure writing data to out_buf"); - return false; - } - } + pair->cur = aws_byte_buf_write_to_capacity(out_buf, &pair->cur); } + return true; } -static void s_init_websocket_send_frame_options( - struct aws_websocket_send_frame_options *frame_options, - struct data_tunnel_pair *pair) { - +static void s_init_websocket_frame_options( + struct data_tunnel_pair *pair, + struct aws_websocket_send_frame_options *frame_options) { AWS_ZERO_STRUCT(*frame_options); frame_options->payload_length = pair->buf.len + PAYLOAD_BYTE_LENGTH_PREFIX; frame_options->user_data = pair; - frame_options->stream_outgoing_payload = secure_tunneling_send_data_call; - frame_options->on_complete = s_secure_tunneling_on_send_data_complete_callback; + frame_options->stream_outgoing_payload = secure_tunneling_websocket_stream_outgoing_payload; + frame_options->on_complete = s_secure_tunneling_websocket_on_send_data_complete_callback; frame_options->opcode = AWS_WEBSOCKET_OPCODE_BINARY; frame_options->fin = true; } -static int s_init_data_tunnel_pair( - struct data_tunnel_pair *pair, - struct aws_secure_tunnel *secure_tunnel, - const struct aws_byte_cursor *data, - enum aws_iot_st_message_type type) { - struct aws_iot_st_msg message; - message.stream_id = secure_tunnel->stream_id; - message.ignorable = 0; - message.type = type; - if (data != NULL) { - message.payload.buffer = data->ptr; - message.payload.len = data->len; - } else { - message.payload.buffer = NULL; - message.payload.len = 0; - } - pair->secure_tunnel = secure_tunnel; - pair->length_prefix_written = false; - if (aws_iot_st_msg_serialize_from_struct(&pair->buf, secure_tunnel->alloc, message) != AWS_OP_SUCCESS) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Failure serializing message"); - goto cleanup; - } - if (pair->buf.len > AWS_IOT_ST_MAX_MESSAGE_SIZE) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Message size greater than AWS_IOT_ST_MAX_MESSAGE_SIZE"); - goto cleanup; - } - pair->cur = aws_byte_cursor_from_buf(&pair->buf); - return AWS_OP_SUCCESS; -cleanup: - aws_byte_buf_clean_up(&pair->buf); - aws_mem_release(pair->secure_tunnel->alloc, (void *)pair); - return AWS_OP_ERR; -} - int secure_tunneling_init_send_frame( - struct aws_websocket_send_frame_options *frame_options, struct aws_secure_tunnel *secure_tunnel, - const struct aws_byte_cursor *data, - enum aws_iot_st_message_type type) { + struct aws_websocket_send_frame_options *frame_options, + const struct aws_secure_tunnel_message_view *message_view) { + struct data_tunnel_pair *pair = - (struct data_tunnel_pair *)aws_mem_acquire(secure_tunnel->alloc, sizeof(struct data_tunnel_pair)); - if (s_init_data_tunnel_pair(pair, secure_tunnel, data, type) != AWS_OP_SUCCESS) { + aws_secure_tunnel_data_tunnel_pair_new(secure_tunnel->allocator, secure_tunnel, message_view); + + if (!pair) { return AWS_OP_ERR; } - s_init_websocket_send_frame_options(frame_options, pair); + + s_init_websocket_frame_options(pair, frame_options); return AWS_OP_SUCCESS; } static int s_secure_tunneling_send( struct aws_secure_tunnel *secure_tunnel, - const struct aws_byte_cursor *data, - enum aws_iot_st_message_type type) { - + const struct aws_secure_tunnel_message_view *message_view) { struct aws_websocket_send_frame_options frame_options; - if (secure_tunneling_init_send_frame(&frame_options, secure_tunnel, data, type) != AWS_OP_SUCCESS) { + if (secure_tunneling_init_send_frame(secure_tunnel, &frame_options, message_view)) { return AWS_OP_ERR; } - return secure_tunnel->websocket_vtable.send_frame(secure_tunnel->websocket, &frame_options); + + /* Prevent further operations that attempt to write to the WebSocket until current operation is completed */ + secure_tunnel->pending_write_completion = true; + return secure_tunnel->vtable->aws_websocket_send_frame_fn(secure_tunnel->websocket, &frame_options); } -static int s_secure_tunneling_send_data(struct aws_secure_tunnel *secure_tunnel, const struct aws_byte_cursor *data) { - if (secure_tunnel->stream_id == INVALID_STREAM_ID) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Invalid Stream Id"); - return AWS_ERROR_IOTDEVICE_SECUTRE_TUNNELING_INVALID_STREAM; - } - struct aws_byte_cursor new_data = *data; - while (new_data.len) { - size_t bytes_max = new_data.len; - size_t amount_to_send = bytes_max < AWS_IOT_ST_SPLIT_MESSAGE_SIZE ? bytes_max : AWS_IOT_ST_SPLIT_MESSAGE_SIZE; +/***************************************************************************************************************** + * Websocket + *****************************************************************************************************************/ +typedef int( + websocket_send_frame)(struct aws_websocket *websocket, const struct aws_websocket_send_frame_options *options); - struct aws_byte_cursor send_cursor = aws_byte_cursor_advance(&new_data, amount_to_send); - AWS_FATAL_ASSERT(send_cursor.len > 0); - if (send_cursor.len) { - if (s_secure_tunneling_send(secure_tunnel, &send_cursor, DATA) != AWS_OP_SUCCESS) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Failure writing data to out_buf"); - return AWS_OP_ERR; - } - } - } - return AWS_OP_SUCCESS; +static bool s_on_websocket_incoming_frame_begin( + struct aws_websocket *websocket, + const struct aws_websocket_incoming_frame *frame, + void *user_data) { + (void)websocket; + (void)frame; + (void)user_data; + return true; } -static int s_secure_tunneling_send_stream_start(struct aws_secure_tunnel *secure_tunnel) { - if (secure_tunnel->options->local_proxy_mode == AWS_SECURE_TUNNELING_DESTINATION_MODE) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Start can only be sent from src mode"); - return AWS_ERROR_IOTDEVICE_SECUTRE_TUNNELING_INCORRECT_MODE; - } - secure_tunnel->stream_id += 1; - if (secure_tunnel->stream_id == 0) { - secure_tunnel->stream_id += 1; - } - return s_secure_tunneling_send(secure_tunnel, NULL, STREAM_START); -} +static bool s_on_websocket_incoming_frame_payload( + struct aws_websocket *websocket, + const struct aws_websocket_incoming_frame *frame, + struct aws_byte_cursor data, + void *user_data) { + + (void)websocket; + (void)frame; -static int s_secure_tunneling_send_stream_reset(struct aws_secure_tunnel *secure_tunnel) { - if (secure_tunnel->stream_id == INVALID_STREAM_ID) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Invalid Stream Id"); - return AWS_ERROR_IOTDEVICE_SECUTRE_TUNNELING_INVALID_STREAM; + if (data.len > 0) { + struct aws_secure_tunnel *secure_tunnel = user_data; + aws_byte_buf_append(&secure_tunnel->received_data, &data); + if (s_process_received_data(secure_tunnel)) { + return false; + } } - int result = s_secure_tunneling_send(secure_tunnel, NULL, STREAM_RESET); - s_reset_secure_tunnel(secure_tunnel); - return result; + return true; } -static void s_secure_tunnel_destroy(void *user_data); +static bool s_on_websocket_incoming_frame_complete( + struct aws_websocket *websocket, + const struct aws_websocket_incoming_frame *frame, + int error_code, + void *user_data) { + (void)websocket; + (void)frame; -struct aws_secure_tunnel *aws_secure_tunnel_new(const struct aws_secure_tunnel_options *options) { + if (error_code) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Error on s_on_websocket_incoming_frame_complete() with error %d(%s).", + (void *)user_data, + error_code, + aws_error_debug_str(error_code)); + } - struct aws_tls_ctx_options tls_ctx_opt; - AWS_ZERO_STRUCT(tls_ctx_opt); + return true; +} - struct aws_secure_tunnel *secure_tunnel = aws_mem_calloc(options->allocator, 1, sizeof(struct aws_secure_tunnel)); - secure_tunnel->alloc = options->allocator; - aws_ref_count_init(&secure_tunnel->ref_count, secure_tunnel, s_secure_tunnel_destroy); +static void s_secure_tunnel_shutdown(struct aws_client_bootstrap *bootstrap, int error_code, void *user_data) { + (void)bootstrap; + struct aws_secure_tunnel *secure_tunnel = user_data; - /* store options */ - secure_tunnel->options_storage = aws_secure_tunnel_options_storage_new(options); - if (secure_tunnel->options_storage == NULL) { - goto error; + if (error_code == AWS_ERROR_SUCCESS) { + error_code = AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_UNEXPECTED_HANGUP; } - secure_tunnel->options = &secure_tunnel->options_storage->options; - /* tls_ctx */ - aws_tls_ctx_options_init_default_client(&tls_ctx_opt, options->allocator); + /* fail current and all pending operations */ + if (secure_tunnel->current_operation != NULL) { + aws_linked_list_push_front(&secure_tunnel->queued_operations, &secure_tunnel->current_operation->node); + secure_tunnel->current_operation = NULL; + } + + if (!aws_linked_list_empty(&secure_tunnel->queued_operations)) { + s_complete_operation_list( + secure_tunnel, + &secure_tunnel->queued_operations, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_OPERATION_FAILED_DUE_TO_OFFLINE_QUEUE_POLICY); + } +} + +/* Normal call to shutdown the websocket */ +static void s_secure_tunnel_shutdown_websocket(struct aws_secure_tunnel *secure_tunnel, int error_code) { + (void)error_code; + if (secure_tunnel->current_state != AWS_STS_CONNECTED && secure_tunnel->current_state != AWS_STS_CLEAN_DISCONNECT) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: secure tunnel websocket shutdown invoked from unexpected state %d(%s)", + (void *)secure_tunnel, + (int)secure_tunnel->current_state, + aws_secure_tunnel_state_to_c_string(secure_tunnel->current_state)); + return; + } + + s_change_current_state(secure_tunnel, AWS_STS_WEBSOCKET_SHUTDOWN); +} + +/* Called by websocket when it's destroyed or manually on failed websocket creation */ +static void s_on_websocket_shutdown(struct aws_websocket *websocket, int error_code, void *user_data) { + struct aws_secure_tunnel *secure_tunnel = user_data; + s_secure_tunnel_shutdown(secure_tunnel->config->bootstrap, error_code, secure_tunnel); + + secure_tunnel->vtable->aws_websocket_release_fn(websocket); + websocket = NULL; + + if (secure_tunnel->config->on_connection_shutdown) { + secure_tunnel->config->on_connection_shutdown(error_code, secure_tunnel->config->user_data); + } + + if (secure_tunnel->desired_state == AWS_STS_CONNECTED) { + s_change_current_state(secure_tunnel, AWS_STS_PENDING_RECONNECT); + } else { + s_change_current_state(secure_tunnel, AWS_STS_STOPPED); + } +} + +/* Called on successful or failed websocket setup attempt */ +static void s_on_websocket_setup(const struct aws_websocket_on_connection_setup_data *setup, void *user_data) { + struct aws_secure_tunnel *secure_tunnel = user_data; + secure_tunnel->handshake_request = aws_http_message_release(secure_tunnel->handshake_request); + + /* Setup callback contract is: if error_code is non-zero then websocket is NULL. */ + AWS_FATAL_ASSERT((setup->error_code != 0) == (setup->websocket == NULL)); + + secure_tunnel->websocket = setup->websocket; + + if (setup->error_code != AWS_OP_SUCCESS) { + /* Report a failed WebSocket Upgrade attempt */ + if (secure_tunnel->config->on_connection_complete) { + secure_tunnel->config->on_connection_complete(NULL, setup->error_code, secure_tunnel->config->user_data); + } + /* Failed/Successful websocket creation and associated errors logged by "websocket-setup" */ + s_on_websocket_shutdown(secure_tunnel->websocket, setup->error_code, secure_tunnel); + return; + } + + AWS_FATAL_ASSERT(secure_tunnel->current_state == AWS_STS_CONNECTING); + AWS_FATAL_ASSERT(aws_event_loop_thread_is_callers_thread(secure_tunnel->loop)); + + if (secure_tunnel->desired_state != AWS_STS_CONNECTED) { + aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_USER_REQUESTED_STOP); + goto error; + } + + s_change_current_state(secure_tunnel, AWS_STS_CONNECTED); + + return; +error: + s_on_websocket_shutdown(secure_tunnel->websocket, setup->error_code, secure_tunnel); +} + +struct aws_secure_tunnel_websocket_transform_complete_task { + struct aws_task task; + struct aws_allocator *allocator; + struct aws_secure_tunnel *secure_tunnel; + int error_code; + struct aws_http_message *handshake; +}; + +void s_websocket_transform_complete_task_fn(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)task; + + struct aws_secure_tunnel_websocket_transform_complete_task *websocket_transform_complete_task = arg; + if (status != AWS_TASK_STATUS_RUN_READY) { + goto done; + } + + struct aws_secure_tunnel *secure_tunnel = websocket_transform_complete_task->secure_tunnel; + + aws_http_message_release(secure_tunnel->handshake_request); + secure_tunnel->handshake_request = aws_http_message_acquire(websocket_transform_complete_task->handshake); + + int error_code = websocket_transform_complete_task->error_code; + if (error_code == 0 && secure_tunnel->desired_state == AWS_STS_CONNECTED) { + struct aws_websocket_client_connection_options websocket_options = { + .allocator = secure_tunnel->allocator, + .bootstrap = secure_tunnel->config->bootstrap, + .socket_options = &secure_tunnel->config->socket_options, + .tls_options = &secure_tunnel->tls_con_opt, + .host = aws_byte_cursor_from_string(secure_tunnel->config->endpoint_host), + .port = 443, + .handshake_request = secure_tunnel->handshake_request, + .manual_window_management = false, + .user_data = secure_tunnel, + .requested_event_loop = secure_tunnel->loop, + + .on_connection_setup = s_on_websocket_setup, + .on_connection_shutdown = s_on_websocket_shutdown, + .on_incoming_frame_begin = s_on_websocket_incoming_frame_begin, + .on_incoming_frame_payload = s_on_websocket_incoming_frame_payload, + .on_incoming_frame_complete = s_on_websocket_incoming_frame_complete, + }; + + if (secure_tunnel->config->http_proxy_config != NULL) { + websocket_options.proxy_options = &secure_tunnel->config->http_proxy_options; + } + + if (secure_tunnel->vtable->aws_websocket_client_connect_fn(&websocket_options)) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Failed to initiate websocket connection.", + (void *)secure_tunnel); + error_code = aws_last_error(); + goto error; + } + + goto done; + } else { + if (error_code == AWS_ERROR_SUCCESS) { + AWS_ASSERT(secure_tunnel->desired_state != AWS_STS_CONNECTED); + error_code = AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_USER_REQUESTED_STOP; + } + } + +error:; + struct aws_websocket_on_connection_setup_data websocket_setup = {.error_code = error_code}; + s_on_websocket_setup(&websocket_setup, secure_tunnel); + +done: + aws_http_message_release(websocket_transform_complete_task->handshake); + aws_secure_tunnel_release(websocket_transform_complete_task->secure_tunnel); + aws_mem_release(websocket_transform_complete_task->allocator, websocket_transform_complete_task); +} + +static int s_handshake_add_header( + const struct aws_secure_tunnel *secure_tunnel, + struct aws_http_message *handshake, + struct aws_http_header header) { + if (aws_http_message_add_header(handshake, header)) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Failed to add header to websocket handshake request", + (void *)secure_tunnel); + return AWS_OP_ERR; + } + AWS_LOGF_TRACE( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Added header " PRInSTR " " PRInSTR " to websocket request", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(header.name), + AWS_BYTE_CURSOR_PRI(header.value)); + return AWS_OP_SUCCESS; +} + +static struct aws_http_message *s_new_handshake_request(const struct aws_secure_tunnel *secure_tunnel) { + char path[50]; + snprintf( + path, + sizeof(path), + "/tunnel?local-proxy-mode=%s", + s_get_proxy_mode_string(secure_tunnel->config->local_proxy_mode)); + + struct aws_http_message *handshake = aws_http_message_new_websocket_handshake_request( + secure_tunnel->allocator, + aws_byte_cursor_from_c_str(path), + aws_byte_cursor_from_string(secure_tunnel->config->endpoint_host)); + + if (handshake == NULL) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, "id=%p: Failed to generate handshake request.", (void *)secure_tunnel); + goto error; + } + + /* Secure Tunnel specific headers */ + struct aws_http_header header_protocol = { + .name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL(WEBSOCKET_HEADER_NAME_PROTOCOL), + .value = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL(WEBSOCKET_HEADER_PROTOCOL_VALUE), + }; + if (s_handshake_add_header(secure_tunnel, handshake, header_protocol)) { + goto error; + } + + struct aws_http_header header_access_token = { + .name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL(WEBSOCKET_HEADER_NAME_ACCESS_TOKEN), + .value = aws_byte_cursor_from_string(secure_tunnel->config->access_token), + }; + if (s_handshake_add_header(secure_tunnel, handshake, header_access_token)) { + goto error; + } + + if (secure_tunnel->config->client_token) { + struct aws_http_header header_client_token = { + .name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL(WEBSOCKET_HEADER_NAME_CLIENT_TOKEN), + .value = aws_byte_cursor_from_string(secure_tunnel->config->client_token), + }; + if (s_handshake_add_header(secure_tunnel, handshake, header_client_token)) { + goto error; + } + } + + return handshake; + +error: + aws_http_message_release(handshake); + return NULL; +} + +static int s_websocket_connect(struct aws_secure_tunnel *secure_tunnel) { + AWS_ASSERT(secure_tunnel); + + struct aws_http_message *handshake = s_new_handshake_request(secure_tunnel); + if (handshake == NULL) { + goto error; + } + + AWS_LOGF_TRACE( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, "id=%p: Transforming websocket handshake request.", (void *)secure_tunnel); + + struct aws_secure_tunnel_websocket_transform_complete_task *task = + aws_mem_calloc(secure_tunnel->allocator, 1, sizeof(struct aws_secure_tunnel_websocket_transform_complete_task)); + + aws_task_init( + &task->task, s_websocket_transform_complete_task_fn, (void *)task, "WebsocketHandshakeTransformComplete"); + task->allocator = secure_tunnel->allocator; + task->secure_tunnel = aws_secure_tunnel_acquire(secure_tunnel); + task->error_code = AWS_OP_SUCCESS; + task->handshake = handshake; + + aws_event_loop_schedule_task_now(secure_tunnel->loop, &task->task); + + return AWS_OP_SUCCESS; + +error: + return AWS_OP_ERR; +} + +static void s_reset_ping(struct aws_secure_tunnel *secure_tunnel) { + uint64_t now = (*secure_tunnel->vtable->get_current_time_fn)(); + secure_tunnel->next_ping_time = aws_add_u64_saturating(now, PING_TASK_INTERVAL); + + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: next PING scheduled for time %" PRIu64, + (void *)secure_tunnel, + secure_tunnel->next_ping_time); +} + +/********************************************************************************************************************* + * State Related + ********************************************************************************************************************/ + +static void s_aws_secure_tunnel_operational_state_reset( + struct aws_secure_tunnel *secure_tunnel, + int completion_error_code) { + s_complete_operation_list(secure_tunnel, &secure_tunnel->queued_operations, completion_error_code); +} + +static void s_change_current_state_to_stopped(struct aws_secure_tunnel *secure_tunnel) { + secure_tunnel->current_state = AWS_STS_STOPPED; + + s_aws_secure_tunnel_operational_state_reset( + secure_tunnel, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_USER_REQUESTED_STOP); + + /* Stop works as a complete session wipe, and so the next time we connect, we want it to be clean */ + s_reset_secure_tunnel(secure_tunnel); + + if (secure_tunnel->config->on_stopped) { + secure_tunnel->config->on_stopped(secure_tunnel->config->user_data); + } +} + +static void s_change_current_state_to_connecting(struct aws_secure_tunnel *secure_tunnel) { + AWS_ASSERT( + secure_tunnel->current_state == AWS_STS_STOPPED || secure_tunnel->current_state == AWS_STS_PENDING_RECONNECT); + + secure_tunnel->current_state = AWS_STS_CONNECTING; + + int result = s_websocket_connect(secure_tunnel); + + if (result) { + int error_code = aws_last_error(); + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to kick off connection with error %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); + + s_change_current_state(secure_tunnel, AWS_STS_PENDING_RECONNECT); + } +} + +static void s_change_current_state_to_connected(struct aws_secure_tunnel *secure_tunnel) { + AWS_FATAL_ASSERT(secure_tunnel->current_state == AWS_STS_CONNECTING); + + secure_tunnel->current_state = AWS_STS_CONNECTED; + secure_tunnel->pending_write_completion = false; + secure_tunnel->reconnect_count = 0; + + /* + * TODO Any rejoin logic can be implemented here. Secure Tunnel does not handle any rejoin state. + * We may opt to send disconnects to existing non-zero stream IDs to notify that the server has reconnected. + */ + + s_reset_ping(secure_tunnel); +} + +static void s_change_current_state_to_clean_disconnect(struct aws_secure_tunnel *secure_tunnel) { + AWS_FATAL_ASSERT(secure_tunnel->current_state == AWS_STS_CONNECTED); + + secure_tunnel->current_state = AWS_STS_CLEAN_DISCONNECT; +} + +static void s_change_current_state_to_websocket_shutdown(struct aws_secure_tunnel *secure_tunnel) { + enum aws_secure_tunnel_state current_state = secure_tunnel->current_state; + AWS_FATAL_ASSERT( + current_state == AWS_STS_CONNECTING || current_state == AWS_STS_CONNECTED || + current_state == AWS_STS_CLEAN_DISCONNECT); + secure_tunnel->current_state = AWS_STS_WEBSOCKET_SHUTDOWN; + + if (secure_tunnel->websocket) { + secure_tunnel->vtable->aws_websocket_close_fn(secure_tunnel->websocket, false); + } else { + s_on_websocket_shutdown(secure_tunnel->websocket, AWS_ERROR_UNKNOWN, secure_tunnel); + } +} + +static void s_update_reconnect_delay_for_pending_reconnect(struct aws_secure_tunnel *secure_tunnel) { + + uint64_t delay_ms = MIN_RECONNECT_DELAY_MS; + delay_ms = delay_ms << (int)secure_tunnel->reconnect_count; + + delay_ms = aws_min_u64(delay_ms, MAX_RECONNECT_DELAY_MS); + uint64_t now = (*secure_tunnel->vtable->get_current_time_fn)(); + + secure_tunnel->next_reconnect_time_ns = + aws_add_u64_saturating(now, aws_timestamp_convert(delay_ms, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL)); + + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: next connection attempt in %" PRIu64 " milliseconds", + (void *)secure_tunnel, + delay_ms); + + secure_tunnel->reconnect_count++; +} + +static void s_change_current_state_to_pending_reconnect(struct aws_secure_tunnel *secure_tunnel) { + secure_tunnel->current_state = AWS_STS_PENDING_RECONNECT; + + s_update_reconnect_delay_for_pending_reconnect(secure_tunnel); +} + +static void s_change_current_state_to_terminated(struct aws_secure_tunnel *secure_tunnel) { + secure_tunnel->current_state = AWS_STS_TERMINATED; + + s_secure_tunnel_final_destroy(secure_tunnel); +} + +static void s_change_current_state(struct aws_secure_tunnel *secure_tunnel, enum aws_secure_tunnel_state next_state) { + AWS_ASSERT(next_state != secure_tunnel->current_state); + if (next_state == secure_tunnel->current_state) { + return; + } + + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: switching current state from %s to %s", + (void *)secure_tunnel, + aws_secure_tunnel_state_to_c_string(secure_tunnel->current_state), + aws_secure_tunnel_state_to_c_string(next_state)); + + switch (next_state) { + case AWS_STS_STOPPED: + s_change_current_state_to_stopped(secure_tunnel); + break; + case AWS_STS_CONNECTING: + s_change_current_state_to_connecting(secure_tunnel); + break; + case AWS_STS_CONNECTED: + s_change_current_state_to_connected(secure_tunnel); + break; + case AWS_STS_CLEAN_DISCONNECT: + s_change_current_state_to_clean_disconnect(secure_tunnel); + break; + case AWS_STS_WEBSOCKET_SHUTDOWN: + s_change_current_state_to_websocket_shutdown(secure_tunnel); + break; + case AWS_STS_PENDING_RECONNECT: + s_change_current_state_to_pending_reconnect(secure_tunnel); + break; + case AWS_STS_TERMINATED: + s_change_current_state_to_terminated(secure_tunnel); + return; + } + + s_reevaluate_service_task(secure_tunnel); +} + +static bool s_is_valid_desired_state(enum aws_secure_tunnel_state desired_state) { + switch (desired_state) { + case AWS_STS_STOPPED: + case AWS_STS_CONNECTED: + case AWS_STS_TERMINATED: + return true; + default: + return false; + } +} + +struct aws_secure_tunnel_change_desired_state_task { + struct aws_task task; + struct aws_allocator *allocator; + struct aws_secure_tunnel *secure_tunnel; + enum aws_secure_tunnel_state desired_state; +}; + +static void s_change_state_task_fn(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)task; + + struct aws_secure_tunnel_change_desired_state_task *change_state_task = arg; + struct aws_secure_tunnel *secure_tunnel = change_state_task->secure_tunnel; + enum aws_secure_tunnel_state desired_state = change_state_task->desired_state; + if (status != AWS_TASK_STATUS_RUN_READY) { + goto done; + } + + if (secure_tunnel->desired_state != desired_state) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: changing desired secure_tunnel state from %s to %s", + (void *)secure_tunnel, + aws_secure_tunnel_state_to_c_string(secure_tunnel->desired_state), + aws_secure_tunnel_state_to_c_string(desired_state)); + + secure_tunnel->desired_state = desired_state; + + s_reevaluate_service_task(secure_tunnel); + } + +done: + + if (desired_state != AWS_STS_TERMINATED) { + aws_secure_tunnel_release(secure_tunnel); + } + + aws_mem_release(change_state_task->allocator, change_state_task); +} + +static struct aws_secure_tunnel_change_desired_state_task *s_aws_secure_tunnel_change_desired_state_task_new( + struct aws_allocator *allocator, + struct aws_secure_tunnel *secure_tunnel, + enum aws_secure_tunnel_state desired_state) { + + struct aws_secure_tunnel_change_desired_state_task *change_state_task = + aws_mem_calloc(allocator, 1, sizeof(struct aws_secure_tunnel_change_desired_state_task)); + if (change_state_task == NULL) { + return NULL; + } + + aws_task_init(&change_state_task->task, s_change_state_task_fn, (void *)change_state_task, "ChangeStateTask"); + change_state_task->allocator = secure_tunnel->allocator; + change_state_task->secure_tunnel = + (desired_state == AWS_STS_TERMINATED) ? secure_tunnel : aws_secure_tunnel_acquire(secure_tunnel); + change_state_task->desired_state = desired_state; + + return change_state_task; +} + +static int s_aws_secure_tunnel_change_desired_state( + struct aws_secure_tunnel *secure_tunnel, + enum aws_secure_tunnel_state desired_state) { + AWS_FATAL_ASSERT(secure_tunnel != NULL); + AWS_FATAL_ASSERT(secure_tunnel->loop != NULL); + + if (!s_is_valid_desired_state(desired_state)) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: invalid desired state argument %d(%s)", + (void *)secure_tunnel, + (int)desired_state, + aws_secure_tunnel_state_to_c_string(desired_state)); + + return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + } + + struct aws_secure_tunnel_change_desired_state_task *task = + s_aws_secure_tunnel_change_desired_state_task_new(secure_tunnel->allocator, secure_tunnel, desired_state); + + if (task == NULL) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to create change desired state task", + (void *)secure_tunnel); + return AWS_OP_ERR; + } + + aws_event_loop_schedule_task_now(secure_tunnel->loop, &task->task); + + return AWS_OP_SUCCESS; +} + +/********************************************************************************************************************* + * vtable functions + ********************************************************************************************************************/ + +static uint64_t s_aws_high_res_clock_get_ticks_proxy(void) { + uint64_t current_time = 0; + AWS_FATAL_ASSERT(aws_high_res_clock_get_ticks(¤t_time) == AWS_OP_SUCCESS); + + return current_time; +} + +static struct aws_secure_tunnel_vtable s_default_secure_tunnel_vtable = { + .get_current_time_fn = s_aws_high_res_clock_get_ticks_proxy, + .aws_websocket_client_connect_fn = aws_websocket_client_connect, + .aws_websocket_send_frame_fn = aws_websocket_send_frame, + .aws_websocket_release_fn = aws_websocket_release, + .aws_websocket_close_fn = aws_websocket_close, + + .vtable_user_data = NULL, +}; + +void aws_secure_tunnel_set_vtable( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_vtable *vtable) { + secure_tunnel->vtable = vtable; +} + +const struct aws_secure_tunnel_vtable *aws_secure_tunnel_get_default_vtable(void) { + return &s_default_secure_tunnel_vtable; +} + +/********************************************************************************************************************* + * Operations + ********************************************************************************************************************/ + +static void s_complete_operation( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_operation *operation, + int error_code, + const void *view) { + (void)secure_tunnel; + + aws_secure_tunnel_operation_complete(operation, error_code, view); + aws_secure_tunnel_operation_release(operation); +} + +static void s_complete_operation_list( + struct aws_secure_tunnel *secure_tunnel, + struct aws_linked_list *operation_list, + int error_code) { + + struct aws_linked_list_node *node = aws_linked_list_begin(operation_list); + while (node != aws_linked_list_end(operation_list)) { + struct aws_secure_tunnel_operation *operation = + AWS_CONTAINER_OF(node, struct aws_secure_tunnel_operation, node); + + node = aws_linked_list_next(node); + + s_complete_operation(secure_tunnel, operation, error_code, NULL); + } + + /* we've released everything, so reset the list to empty */ + aws_linked_list_init(operation_list); +} + +/* + * Check whether secure tunnel currently has work left to do based on its current state + */ +static bool s_aws_secure_tunnel_has_pending_operational_work(const struct aws_secure_tunnel *secure_tunnel) { + if (aws_linked_list_empty(&secure_tunnel->queued_operations)) { + return false; + } + + struct aws_linked_list_node *next_operation_node = aws_linked_list_front(&secure_tunnel->queued_operations); + struct aws_secure_tunnel_operation *next_operation = + AWS_CONTAINER_OF(next_operation_node, struct aws_secure_tunnel_operation, node); + + switch (secure_tunnel->current_state) { + case AWS_STS_CLEAN_DISCONNECT: + /* Except for finishing the current operation, only allowed to send STREAM RESET messages in this state + */ + return next_operation->operation_type == AWS_STOT_STREAM_RESET; + + case AWS_STS_CONNECTED: + return true; + + default: + return false; + } +} + +static uint64_t s_aws_secure_tunnel_compute_operational_state_service_time( + struct aws_secure_tunnel *secure_tunnel, + uint64_t now) { + + /* If a message is in transit down the WebSocket, then wait for it to complete */ + if (secure_tunnel->pending_write_completion) { + return 0; + } + + /* If we're in the middle of something, keep going */ + if (secure_tunnel->current_operation != NULL) { + return now; + } + + /* If nothing is queued, there's nothing to do */ + if (!s_aws_secure_tunnel_has_pending_operational_work(secure_tunnel)) { + return 0; + } + + AWS_FATAL_ASSERT(!aws_linked_list_empty(&secure_tunnel->queued_operations)); + + struct aws_linked_list_node *next_operation_node = aws_linked_list_front(&secure_tunnel->queued_operations); + struct aws_secure_tunnel_operation *next_operation = + AWS_CONTAINER_OF(next_operation_node, struct aws_secure_tunnel_operation, node); + + AWS_FATAL_ASSERT(next_operation != NULL); + + /* now unless outside of allowed states */ + switch (secure_tunnel->current_state) { + case AWS_STS_CLEAN_DISCONNECT: + case AWS_STS_CONNECTED: + return now; + + default: + /* no outbound traffic is allowed outside of the above states */ + return 0; + } +} + +static bool s_aws_secure_tunnel_should_service_operational_state( + struct aws_secure_tunnel *secure_tunnel, + uint64_t now) { + return now == s_aws_secure_tunnel_compute_operational_state_service_time(secure_tunnel, now); +} + +int aws_secure_tunnel_service_operational_state(struct aws_secure_tunnel *secure_tunnel) { + const struct aws_secure_tunnel_vtable *vtable = secure_tunnel->vtable; + uint64_t now = (*vtable->get_current_time_fn)(); + + /* Should we write data? */ + bool should_service = s_aws_secure_tunnel_should_service_operational_state(secure_tunnel, now); + if (!should_service) { + return AWS_OP_SUCCESS; + } + + int operational_error_code = AWS_ERROR_SUCCESS; + + do { + /* if no current operation, pull one in and setup encode */ + if (secure_tunnel->current_operation == NULL) { + /* + * Loop through queued operations until we run out or find a good one. + */ + struct aws_secure_tunnel_operation *next_operation = NULL; + + if (!aws_linked_list_empty(&secure_tunnel->queued_operations)) { + struct aws_linked_list_node *next_operation_node = + aws_linked_list_pop_front(&secure_tunnel->queued_operations); + + next_operation = AWS_CONTAINER_OF(next_operation_node, struct aws_secure_tunnel_operation, node); + + secure_tunnel->current_operation = next_operation; + } + } + + struct aws_secure_tunnel_operation *current_operation = secure_tunnel->current_operation; + if (current_operation == NULL) { + break; + } + int error_code = AWS_OP_SUCCESS; + + switch (current_operation->operation_type) { + case AWS_STOT_PING:; + /* + * TODO Currently, pings are sent to keep the websocket alive but we do not receive responses from the + * secure tunnel service until a src is also connected. This is a known bug that is in their + * backlog. Once it is fixed, we should implement ping timeout checks to determine whether we are + * still connected to the secure tunnel through WebSocket. + */ + struct aws_websocket_send_frame_options frame_options; + AWS_ZERO_STRUCT(frame_options); + frame_options.opcode = AWS_WEBSOCKET_OPCODE_PING; + frame_options.fin = true; + secure_tunnel->vtable->aws_websocket_send_frame_fn(secure_tunnel->websocket, &frame_options); + + break; + case AWS_STOT_MESSAGE: + /* If a data message attempts to be sent on an unopen stream, discard it. */ + if ((*current_operation->vtable->aws_secure_tunnel_operation_assign_stream_id_fn)( + current_operation, secure_tunnel)) { + + error_code = aws_last_error(); + + if (current_operation->message_view->service_id) { + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to assign service id '" PRInSTR + "' DATA message a stream id with error %d(%s)", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*current_operation->message_view->service_id), + error_code, + aws_error_debug_str(error_code)); + } else { + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to assign V1 DATA message a stream id with error %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); + } + } else { + /* Send the Data message through the WebSocket */ + if (s_secure_tunneling_send(secure_tunnel, current_operation->message_view)) { + error_code = aws_last_error(); + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to send DATA message with error %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); + } + aws_secure_tunnel_message_view_log(current_operation->message_view, AWS_LL_DEBUG); + } + + break; + + case AWS_STOT_STREAM_START: + if ((*current_operation->vtable->aws_secure_tunnel_operation_set_next_stream_id_fn)( + current_operation, secure_tunnel)) { + error_code = aws_last_error(); + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to send STREAM START message with error %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); + } else { + /* Send the Stream Start message through the WebSocket */ + if (s_secure_tunneling_send(secure_tunnel, current_operation->message_view)) { + error_code = aws_last_error(); + } + aws_secure_tunnel_message_view_log(current_operation->message_view, AWS_LL_DEBUG); + } + break; + + case AWS_STOT_STREAM_RESET: + + if ((*current_operation->vtable->aws_secure_tunnel_operation_assign_stream_id_fn)( + current_operation, secure_tunnel)) { + error_code = aws_last_error(); + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to send STREAM RESET message with error %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); + } else { + /* Send the Stream Reset message through the WebSocket */ + if (s_secure_tunneling_send(secure_tunnel, current_operation->message_view)) { + error_code = aws_last_error(); + } else { + s_aws_secure_tunnel_set_stream_id( + secure_tunnel, current_operation->message_view->service_id, INVALID_STREAM_ID); + } + aws_secure_tunnel_message_view_log(current_operation->message_view, AWS_LL_DEBUG); + } + + break; + + case AWS_STOT_NONE: + break; + } + + s_complete_operation(secure_tunnel, current_operation, AWS_OP_SUCCESS, NULL); + secure_tunnel->current_operation = NULL; + + now = (*vtable->get_current_time_fn)(); + should_service = s_aws_secure_tunnel_should_service_operational_state(secure_tunnel, now); + } while (should_service); + + if (operational_error_code != AWS_ERROR_SUCCESS) { + return aws_raise_error(operational_error_code); + } + + return AWS_OP_SUCCESS; +} + +void aws_secure_tunnel_operational_state_clean_up(struct aws_secure_tunnel *secure_tunnel) { + AWS_ASSERT(secure_tunnel->current_operation == NULL); + + s_aws_secure_tunnel_operational_state_reset(secure_tunnel, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_TERMINATED); +} + +static void s_enqueue_operation_back( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_operation *operation) { + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: enqueuing %s operation to back", + (void *)secure_tunnel, + aws_secure_tunnel_operation_type_to_c_string(operation->operation_type)); + + aws_linked_list_push_back(&secure_tunnel->queued_operations, &operation->node); + + s_reevaluate_service_task(secure_tunnel); +} + +static void s_enqueue_operation_front( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_operation *operation) { + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: enqueuing %s operation to front", + (void *)secure_tunnel, + aws_secure_tunnel_operation_type_to_c_string(operation->operation_type)); + + aws_linked_list_push_front(&secure_tunnel->queued_operations, &operation->node); + + s_reevaluate_service_task(secure_tunnel); +} + +struct aws_secure_tunnel_submit_operation_task { + struct aws_task task; + struct aws_allocator *allocator; + struct aws_secure_tunnel *secure_tunnel; + struct aws_secure_tunnel_operation *operation; +}; + +static void s_secure_tunnel_submit_operation_task_fn(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)task; + + int completion_error_code = AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_SECURE_TUNNEL_TERMINATED; + struct aws_secure_tunnel_submit_operation_task *submit_operation_task = arg; + + /* + * Take a ref to the operation that represents the secure tunnel taking ownership + * If we subsequently reject it (task cancel), then the operation completion + * will undo this ref acquisition. + */ + aws_secure_tunnel_operation_acquire(submit_operation_task->operation); + + if (status != AWS_TASK_STATUS_RUN_READY) { + goto error; + } + + /* + * If we're offline fail it immediately. + */ + struct aws_secure_tunnel *secure_tunnel = submit_operation_task->secure_tunnel; + if (secure_tunnel->current_state != AWS_STS_CONNECTED) { + completion_error_code = AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_OPERATION_FAILED_DUE_TO_DISCONNECTION; + goto error; + } + + s_enqueue_operation_back(submit_operation_task->secure_tunnel, submit_operation_task->operation); + + goto done; + +error: + s_complete_operation(NULL, submit_operation_task->operation, completion_error_code, NULL); + +done: + aws_secure_tunnel_operation_release(submit_operation_task->operation); + aws_secure_tunnel_release(submit_operation_task->secure_tunnel); + + aws_mem_release(submit_operation_task->allocator, submit_operation_task); +} + +static int s_submit_operation(struct aws_secure_tunnel *secure_tunnel, struct aws_secure_tunnel_operation *operation) { + struct aws_secure_tunnel_submit_operation_task *submit_task = + aws_mem_calloc(secure_tunnel->allocator, 1, sizeof(struct aws_secure_tunnel_submit_operation_task)); + if (submit_task == NULL) { + return AWS_OP_ERR; + } + + aws_task_init( + &submit_task->task, s_secure_tunnel_submit_operation_task_fn, submit_task, "SecureTunnelSubmitOperation"); + submit_task->allocator = secure_tunnel->allocator; + submit_task->secure_tunnel = aws_secure_tunnel_acquire(secure_tunnel); + submit_task->operation = operation; + + aws_event_loop_schedule_task_now(secure_tunnel->loop, &submit_task->task); + + return AWS_OP_SUCCESS; +} + +/********************************************************************************************************************* + * Service Timing + ********************************************************************************************************************/ + +static uint64_t s_min_non_0_64(uint64_t a, uint64_t b) { + if (a == 0) { + return b; + } + + if (b == 0) { + return a; + } + + return aws_min_u64(a, b); +} + +/* + * next_service_time == 0 means to not service the secure tunnel, i.e. a state that only cares about external events + * + * This includes connecting and channel shutdown. Terminated is also included, but it's a state that only exists + * instantaneously before final destruction. + */ +static uint64_t s_compute_next_service_time_secure_tunnel_stopped( + struct aws_secure_tunnel *secure_tunnel, + uint64_t now) { + /* have we been told to connect or terminate? */ + if (secure_tunnel->desired_state != AWS_STS_STOPPED) { + return now; + } + + return 0; +} + +static uint64_t s_compute_next_service_time_secure_tunnel_connecting( + struct aws_secure_tunnel *secure_tunnel, + uint64_t now) { + (void)secure_tunnel; + (void)now; + + return 0; +} + +static uint64_t s_compute_next_service_time_secure_tunnel_connected( + struct aws_secure_tunnel *secure_tunnel, + uint64_t now) { + /* TODO check against ping timeout once pong is implemented by secure tunnel service */ + uint64_t next_service_time = secure_tunnel->next_ping_time; + + if (secure_tunnel->desired_state != AWS_STS_CONNECTED) { + next_service_time = now; + } + + uint64_t operation_processing_time = s_aws_secure_tunnel_compute_operational_state_service_time(secure_tunnel, now); + + next_service_time = s_min_non_0_64(operation_processing_time, next_service_time); + + return next_service_time; +} + +static uint64_t s_compute_next_service_time_secure_tunnel_clean_disconnect( + struct aws_secure_tunnel *secure_tunnel, + uint64_t now) { + return s_aws_secure_tunnel_compute_operational_state_service_time(secure_tunnel, now); +} + +static uint64_t s_compute_next_service_time_secure_tunnel_websocket_shutdown( + struct aws_secure_tunnel *secure_tunnel, + uint64_t now) { + (void)secure_tunnel; + (void)now; + + return 0; +} + +static uint64_t s_compute_next_service_time_secure_tunnel_pending_reconnect( + struct aws_secure_tunnel *secure_tunnel, + uint64_t now) { + if (secure_tunnel->desired_state != AWS_STS_CONNECTED) { + return now; + } + + return secure_tunnel->next_reconnect_time_ns; +} + +static uint64_t s_compute_next_service_time_by_current_state(struct aws_secure_tunnel *secure_tunnel, uint64_t now) { + + switch (secure_tunnel->current_state) { + case AWS_STS_STOPPED: + return s_compute_next_service_time_secure_tunnel_stopped(secure_tunnel, now); + case AWS_STS_CONNECTING: + return s_compute_next_service_time_secure_tunnel_connecting(secure_tunnel, now); + case AWS_STS_CONNECTED: + return s_compute_next_service_time_secure_tunnel_connected(secure_tunnel, now); + case AWS_STS_CLEAN_DISCONNECT: + return s_compute_next_service_time_secure_tunnel_clean_disconnect(secure_tunnel, now); + case AWS_STS_WEBSOCKET_SHUTDOWN: + return s_compute_next_service_time_secure_tunnel_websocket_shutdown(secure_tunnel, now); + case AWS_STS_PENDING_RECONNECT: + return s_compute_next_service_time_secure_tunnel_pending_reconnect(secure_tunnel, now); + case AWS_STS_TERMINATED: + return 0; + } + + return 0; +} + +static void s_reevaluate_service_task(struct aws_secure_tunnel *secure_tunnel) { + /* + * This causes the secure tunnel to only reevaluate service schedule time at the end of the service call or in + * a callback from an external event. + */ + if (secure_tunnel->in_service) { + return; + } + + uint64_t now = (*secure_tunnel->vtable->get_current_time_fn)(); + uint64_t next_service_time = s_compute_next_service_time_by_current_state(secure_tunnel, now); + + /* + * This catches both the case when there's an existing service schedule and we either want to not + * perform it (next_service_time == 0) or need to run service at a different time than the current scheduled + * time. + */ + if (next_service_time != secure_tunnel->next_service_task_run_time && + secure_tunnel->next_service_task_run_time > 0) { + aws_event_loop_cancel_task(secure_tunnel->loop, &secure_tunnel->service_task); + secure_tunnel->next_service_task_run_time = 0; + + AWS_LOGF_TRACE( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: cancelling previously scheduled service task", + (void *)secure_tunnel); + } + + if (next_service_time > 0 && (next_service_time < secure_tunnel->next_service_task_run_time || + secure_tunnel->next_service_task_run_time == 0)) { + aws_event_loop_schedule_task_future(secure_tunnel->loop, &secure_tunnel->service_task, next_service_time); + + AWS_LOGF_TRACE( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: scheduled service task for time %" PRIu64, + (void *)secure_tunnel, + next_service_time); + } + + secure_tunnel->next_service_task_run_time = next_service_time; +} + +/********************************************************************************************************************* + * Update Loop + ********************************************************************************************************************/ + +static int s_aws_secure_tunnel_queue_ping(struct aws_secure_tunnel *secure_tunnel) { + s_reset_ping(secure_tunnel); + + AWS_LOGF_DEBUG(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "id=%p: queuing PING", (void *)secure_tunnel); + + struct aws_secure_tunnel_operation_pingreq *pingreq_op = + aws_secure_tunnel_operation_pingreq_new(secure_tunnel->allocator); + s_enqueue_operation_front(secure_tunnel, &pingreq_op->base); + + return AWS_OP_SUCCESS; +} + +static bool s_service_state_stopped(struct aws_secure_tunnel *secure_tunnel) { + enum aws_secure_tunnel_state desired_state = secure_tunnel->desired_state; + if (desired_state == AWS_STS_CONNECTED) { + s_change_current_state(secure_tunnel, AWS_STS_CONNECTING); + } else if (desired_state == AWS_STS_TERMINATED) { + s_change_current_state(secure_tunnel, AWS_STS_TERMINATED); + return true; + } + return false; +} + +static void s_service_state_connecting(struct aws_secure_tunnel *secure_tunnel, uint64_t now) { + (void)secure_tunnel; + (void)now; +} + +static void s_service_state_connected(struct aws_secure_tunnel *secure_tunnel, uint64_t now) { + enum aws_secure_tunnel_state desired_state = secure_tunnel->desired_state; + if (desired_state != AWS_STS_CONNECTED) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: channel shutdown due to user Stop request", + (void *)secure_tunnel); + s_secure_tunnel_shutdown_websocket(secure_tunnel, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_USER_REQUESTED_STOP); + return; + } + + if (now >= secure_tunnel->next_ping_time) { + if (s_aws_secure_tunnel_queue_ping(secure_tunnel)) { + int error_code = aws_last_error(); + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to queue PINGREQ with error %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); + s_secure_tunnel_shutdown_websocket(secure_tunnel, error_code); + return; + } + } + + if (aws_secure_tunnel_service_operational_state(secure_tunnel)) { + int error_code = aws_last_error(); + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to service CONNECTED operation queue with error %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); + s_secure_tunnel_shutdown_websocket(secure_tunnel, error_code); + return; + } +} + +static void s_service_state_clean_disconnect(struct aws_secure_tunnel *secure_tunnel, uint64_t now) { + (void)now; + if (aws_secure_tunnel_service_operational_state(secure_tunnel)) { + int error_code = aws_last_error(); + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to service CLEAN_DISCONNECT operation queue with error %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); + s_secure_tunnel_shutdown_websocket(secure_tunnel, error_code); + return; + } +} + +static void s_service_state_pending_reconnect(struct aws_secure_tunnel *secure_tunnel, uint64_t now) { + if (secure_tunnel->desired_state != AWS_STS_CONNECTED) { + s_change_current_state(secure_tunnel, AWS_STS_STOPPED); + return; + } + + if (now >= secure_tunnel->next_reconnect_time_ns) { + s_change_current_state(secure_tunnel, AWS_STS_CONNECTING); + return; + } +} + +static void s_secure_tunnel_service_task_fn(struct aws_task *task, void *arg, enum aws_task_status status) { + (void)task; + if (status != AWS_TASK_STATUS_RUN_READY) { + return; + } + + struct aws_secure_tunnel *secure_tunnel = arg; + secure_tunnel->next_service_task_run_time = 0; + secure_tunnel->in_service = true; + + uint64_t now = (*secure_tunnel->vtable->get_current_time_fn)(); + bool terminated = false; + switch (secure_tunnel->current_state) { + case AWS_STS_STOPPED: + terminated = s_service_state_stopped(secure_tunnel); + break; + case AWS_STS_CONNECTING: + s_service_state_connecting(secure_tunnel, now); + break; + case AWS_STS_CONNECTED: + s_service_state_connected(secure_tunnel, now); + break; + case AWS_STS_CLEAN_DISCONNECT: + s_service_state_clean_disconnect(secure_tunnel, now); + break; + case AWS_STS_PENDING_RECONNECT: + s_service_state_pending_reconnect(secure_tunnel, now); + break; + default: + break; + } + + /* + * We can only enter the terminated state from stopped. If we do so, the secure tunnel memory is now freed and + * we will crash if we access anything anymore. + */ + if (terminated) { + return; + } + + /* we're not scheduled anymore, reschedule as needed */ + secure_tunnel->in_service = false; + s_reevaluate_service_task(secure_tunnel); +} + +/********************************************************************************************************************* + * API Calls + ********************************************************************************************************************/ + +struct aws_secure_tunnel *aws_secure_tunnel_new( + struct aws_allocator *allocator, + const struct aws_secure_tunnel_options *options) { + AWS_FATAL_ASSERT(options != NULL); + AWS_FATAL_ASSERT(allocator != NULL); + + struct aws_secure_tunnel *secure_tunnel = aws_mem_calloc(allocator, 1, sizeof(struct aws_secure_tunnel)); + if (secure_tunnel == NULL) { + return NULL; + } + + aws_task_init(&secure_tunnel->service_task, s_secure_tunnel_service_task_fn, secure_tunnel, "SecureTunnelService"); + + secure_tunnel->allocator = allocator; + secure_tunnel->vtable = &s_default_secure_tunnel_vtable; + + aws_ref_count_init(&secure_tunnel->ref_count, secure_tunnel, s_on_secure_tunnel_zero_ref_count); + + aws_linked_list_init(&secure_tunnel->queued_operations); + secure_tunnel->current_operation = NULL; + + /* store options */ + secure_tunnel->config = aws_secure_tunnel_options_storage_new(allocator, options); + if (secure_tunnel->config == NULL) { + goto error; + } + + /* all secure tunnel activity will take place on this event loop */ + secure_tunnel->loop = aws_event_loop_group_get_next_loop(secure_tunnel->config->bootstrap->event_loop_group); + if (secure_tunnel->loop == NULL) { + goto error; + } + + secure_tunnel->desired_state = AWS_STS_STOPPED; + secure_tunnel->current_state = AWS_STS_STOPPED; + + /* tls setup */ + struct aws_tls_ctx_options tls_ctx_opt; + AWS_ZERO_STRUCT(tls_ctx_opt); + aws_tls_ctx_options_init_default_client(&tls_ctx_opt, secure_tunnel->allocator); if (options->root_ca != NULL) { if (aws_tls_ctx_options_override_default_trust_store_from_path(&tls_ctx_opt, NULL, options->root_ca)) { @@ -633,7 +1765,7 @@ struct aws_secure_tunnel *aws_secure_tunnel_new(const struct aws_secure_tunnel_o } } - secure_tunnel->tls_ctx = aws_tls_client_ctx_new(options->allocator, &tls_ctx_opt); + secure_tunnel->tls_ctx = aws_tls_client_ctx_new(allocator, &tls_ctx_opt); if (secure_tunnel->tls_ctx == NULL) { goto error; } @@ -641,30 +1773,23 @@ struct aws_secure_tunnel *aws_secure_tunnel_new(const struct aws_secure_tunnel_o /* tls_connection_options */ aws_tls_connection_options_init_from_ctx(&secure_tunnel->tls_con_opt, secure_tunnel->tls_ctx); if (aws_tls_connection_options_set_server_name( - &secure_tunnel->tls_con_opt, options->allocator, (struct aws_byte_cursor *)&options->endpoint_host)) { + &secure_tunnel->tls_con_opt, allocator, (struct aws_byte_cursor *)&options->endpoint_host)) { goto error; } aws_tls_ctx_options_clean_up(&tls_ctx_opt); - /* Setup vtables here. */ - secure_tunnel->vtable.connect = s_secure_tunneling_connect; - secure_tunnel->vtable.close = s_secure_tunneling_close; - secure_tunnel->vtable.send_data = s_secure_tunneling_send_data; - secure_tunnel->vtable.send_stream_start = s_secure_tunneling_send_stream_start; - secure_tunnel->vtable.send_stream_reset = s_secure_tunneling_send_stream_reset; + /* Connection reset */ + secure_tunnel->config->stream_id = INVALID_STREAM_ID; - secure_tunnel->websocket_vtable.client_connect = aws_websocket_client_connect; - secure_tunnel->websocket_vtable.send_frame = aws_websocket_send_frame; - secure_tunnel->websocket_vtable.close = aws_websocket_close; - secure_tunnel->websocket_vtable.release = aws_websocket_release; + aws_hash_table_foreach(&secure_tunnel->config->service_ids, s_reset_service_id, NULL); secure_tunnel->handshake_request = NULL; - secure_tunnel->stream_id = INVALID_STREAM_ID; secure_tunnel->websocket = NULL; - /* TODO: Release this buffer when there is no data to hold */ - aws_byte_buf_init(&secure_tunnel->received_data, options->allocator, MAX_WEBSOCKET_PAYLOAD); + aws_byte_buf_init(&secure_tunnel->received_data, allocator, MAX_WEBSOCKET_PAYLOAD); + + aws_secure_tunnel_options_storage_log(secure_tunnel->config, AWS_LL_DEBUG); return secure_tunnel; @@ -675,54 +1800,121 @@ struct aws_secure_tunnel *aws_secure_tunnel_new(const struct aws_secure_tunnel_o } struct aws_secure_tunnel *aws_secure_tunnel_acquire(struct aws_secure_tunnel *secure_tunnel) { - aws_ref_count_acquire(&secure_tunnel->ref_count); + if (secure_tunnel != NULL) { + aws_ref_count_acquire(&secure_tunnel->ref_count); + } return secure_tunnel; } -void aws_secure_tunnel_release(struct aws_secure_tunnel *secure_tunnel) { - if (secure_tunnel == NULL) { - return; +struct aws_secure_tunnel *aws_secure_tunnel_release(struct aws_secure_tunnel *secure_tunnel) { + if (secure_tunnel != NULL) { + aws_ref_count_release(&secure_tunnel->ref_count); } - aws_ref_count_release(&secure_tunnel->ref_count); + + return NULL; } -static void s_secure_tunnel_destroy(void *user_data) { - struct aws_secure_tunnel *secure_tunnel = user_data; +int aws_secure_tunnel_start(struct aws_secure_tunnel *secure_tunnel) { + return s_aws_secure_tunnel_change_desired_state(secure_tunnel, AWS_STS_CONNECTED); +} - aws_secure_tunneling_on_termination_complete_fn *on_termination_complete = NULL; - void *termination_complete_user_data = NULL; - if (secure_tunnel->options != NULL) { - on_termination_complete = secure_tunnel->options->on_termination_complete; - termination_complete_user_data = secure_tunnel->options->user_data; +int aws_secure_tunnel_stop(struct aws_secure_tunnel *secure_tunnel) { + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, "id=%p: Stopping secure tunnel immediately", (void *)secure_tunnel); + return s_aws_secure_tunnel_change_desired_state(secure_tunnel, AWS_STS_STOPPED); +} + +int aws_secure_tunnel_send_message( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_options) { + AWS_PRECONDITION(secure_tunnel != NULL); + AWS_PRECONDITION(message_options != NULL); + + struct aws_secure_tunnel_operation_message *message_op = aws_secure_tunnel_operation_message_new( + secure_tunnel->allocator, secure_tunnel, message_options, AWS_STOT_MESSAGE); + + if (message_op == NULL) { + return AWS_OP_ERR; } - aws_secure_tunnel_options_storage_destroy(secure_tunnel->options_storage); - aws_byte_buf_clean_up(&secure_tunnel->received_data); - aws_tls_connection_options_clean_up(&secure_tunnel->tls_con_opt); - aws_tls_ctx_release(secure_tunnel->tls_ctx); - aws_mem_release(secure_tunnel->alloc, secure_tunnel); + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Submitting MESSAGE operation (%p)", + (void *)secure_tunnel, + (void *)message_op); - if (on_termination_complete != NULL) { - (*on_termination_complete)(termination_complete_user_data); + if (s_submit_operation(secure_tunnel, &message_op->base)) { + goto error; } -} -int aws_secure_tunnel_connect(struct aws_secure_tunnel *secure_tunnel) { - return secure_tunnel->vtable.connect(secure_tunnel); -} + return AWS_OP_SUCCESS; -int aws_secure_tunnel_close(struct aws_secure_tunnel *secure_tunnel) { - return secure_tunnel->vtable.close(secure_tunnel); +error: + aws_secure_tunnel_operation_release(&message_op->base); + return AWS_OP_ERR; } -int aws_secure_tunnel_send_data(struct aws_secure_tunnel *secure_tunnel, const struct aws_byte_cursor *data) { - return secure_tunnel->vtable.send_data(secure_tunnel, data); -} +int aws_secure_tunnel_stream_start( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_options) { + AWS_PRECONDITION(secure_tunnel != NULL); + AWS_PRECONDITION(message_options != NULL); + + if (secure_tunnel->config->local_proxy_mode == AWS_SECURE_TUNNELING_DESTINATION_MODE) { + AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Stream Start can only be sent from source mode"); + return AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INCORRECT_MODE; + } + + struct aws_secure_tunnel_operation_message *message_op = aws_secure_tunnel_operation_message_new( + secure_tunnel->allocator, secure_tunnel, message_options, AWS_STOT_STREAM_START); + + if (message_op == NULL) { + return AWS_OP_ERR; + } + + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Submitting STREAM START operation (%p)", + (void *)secure_tunnel, + (void *)message_op); + + if (s_submit_operation(secure_tunnel, &message_op->base)) { + goto error; + } + + return AWS_OP_SUCCESS; -int aws_secure_tunnel_stream_start(struct aws_secure_tunnel *secure_tunnel) { - return secure_tunnel->vtable.send_stream_start(secure_tunnel); +error: + aws_secure_tunnel_operation_release(&message_op->base); + return AWS_OP_ERR; } -int aws_secure_tunnel_stream_reset(struct aws_secure_tunnel *secure_tunnel) { - return secure_tunnel->vtable.send_stream_reset(secure_tunnel); +int aws_secure_tunnel_stream_reset( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_options) { + AWS_PRECONDITION(secure_tunnel != NULL); + AWS_PRECONDITION(message_options != NULL); + + struct aws_secure_tunnel_operation_message *message_op = aws_secure_tunnel_operation_message_new( + secure_tunnel->allocator, secure_tunnel, message_options, AWS_STOT_STREAM_RESET); + + if (message_op == NULL) { + return AWS_OP_ERR; + } + + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Submitting STREAM RESET operation (%p)", + (void *)secure_tunnel, + (void *)message_op); + + if (s_submit_operation(secure_tunnel, &message_op->base)) { + goto error; + } + + return AWS_OP_SUCCESS; + +error: + aws_secure_tunnel_operation_release(&message_op->base); + return AWS_OP_ERR; } diff --git a/source/secure_tunneling_operations.c b/source/secure_tunneling_operations.c new file mode 100644 index 00000000..9931805d --- /dev/null +++ b/source/secure_tunneling_operations.c @@ -0,0 +1,722 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define INVALID_STREAM_ID 0 + +/* for the hash table, to destroy elements */ +static void s_destroy_service_id(void *data) { + struct aws_service_id_element *elem = data; + aws_string_destroy(elem->service_id_string); + aws_mem_release(elem->allocator, elem); +} + +struct aws_service_id_element *aws_service_id_element_new( + struct aws_allocator *allocator, + const struct aws_byte_cursor *service_id, + int32_t stream_id) { + AWS_PRECONDITION(allocator != NULL); + AWS_PRECONDITION(service_id != NULL); + + struct aws_service_id_element *elem = aws_mem_calloc(allocator, 1, sizeof(struct aws_service_id_element)); + elem->allocator = allocator; + elem->service_id_string = aws_string_new_from_cursor(allocator, service_id); + if (elem->service_id_string == NULL) { + goto error; + } + elem->service_id_cur = aws_byte_cursor_from_string(elem->service_id_string); + elem->stream_id = stream_id; + + return elem; + +error: + s_destroy_service_id(elem); + return NULL; +} + +/********************************************************************************************************************* + * Operation base + ********************************************************************************************************************/ + +struct aws_secure_tunnel_operation *aws_secure_tunnel_operation_acquire(struct aws_secure_tunnel_operation *operation) { + if (operation == NULL) { + return NULL; + } + + aws_ref_count_acquire(&operation->ref_count); + + return operation; +} + +struct aws_secure_tunnel_operation *aws_secure_tunnel_operation_release(struct aws_secure_tunnel_operation *operation) { + if (operation != NULL) { + aws_ref_count_release(&operation->ref_count); + } + + return NULL; +} + +void aws_secure_tunnel_operation_complete( + struct aws_secure_tunnel_operation *operation, + int error_code, + const void *associated_view) { + AWS_FATAL_ASSERT(operation->vtable != NULL); + + if (operation->vtable->aws_secure_tunnel_operation_completion_fn != NULL) { + (*operation->vtable->aws_secure_tunnel_operation_completion_fn)(operation, error_code, associated_view); + } +} + +void aws_secure_tunnel_operation_assign_stream_id( + struct aws_secure_tunnel_operation *operation, + struct aws_secure_tunnel *secure_tunnel) { + AWS_FATAL_ASSERT(operation->vtable != NULL); + if (operation->vtable->aws_secure_tunnel_operation_assign_stream_id_fn != NULL) { + (*operation->vtable->aws_secure_tunnel_operation_assign_stream_id_fn)(operation, secure_tunnel); + } +} + +static struct aws_secure_tunnel_operation_vtable s_empty_operation_vtable = { + .aws_secure_tunnel_operation_completion_fn = NULL, + .aws_secure_tunnel_operation_assign_stream_id_fn = NULL, + .aws_secure_tunnel_operation_set_next_stream_id_fn = NULL, +}; + +/********************************************************************************************************************* + * Message + ********************************************************************************************************************/ + +int aws_secure_tunnel_message_view_validate(const struct aws_secure_tunnel_message_view *message_view) { + if (message_view == NULL) { + AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "null message options"); + return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DATA_OPTIONS_VALIDATION); + } + + if (message_view->type == AWS_SECURE_TUNNEL_MT_DATA && message_view->stream_id != 0) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view stream id for DATA MESSAGES must be 0", + (void *)message_view); + return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DATA_OPTIONS_VALIDATION); + } + + if (message_view->payload != NULL && message_view->payload->len > AWS_IOT_ST_MAX_MESSAGE_SIZE) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view - payload too long", + (void *)message_view); + return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DATA_OPTIONS_VALIDATION); + } + + return AWS_OP_SUCCESS; +} + +void aws_secure_tunnel_message_view_log( + const struct aws_secure_tunnel_message_view *message_view, + enum aws_log_level level) { + struct aws_logger *log_handle = aws_logger_get_conditional(AWS_LS_IOTDEVICE_SECURE_TUNNELING, level); + if (log_handle == NULL) { + return; + } + + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view type '%s'", + (void *)message_view, + aws_secure_tunnel_message_type_to_c_string(message_view->type)); + + if (message_view->service_id != NULL) { + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view service_id set to '" PRInSTR "'", + (void *)message_view, + AWS_BYTE_CURSOR_PRI(*message_view->service_id)); + } else { + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view service_id not set", + (void *)message_view); + } + + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view stream_id set to %d", + (void *)message_view, + (int)message_view->stream_id); + + if (message_view->payload != NULL) { + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view payload set containing %zu bytes", + (void *)message_view, + message_view->payload->len); + } +} + +static size_t s_aws_secure_tunnel_message_compute_storage_size( + const struct aws_secure_tunnel_message_view *message_view) { + size_t storage_size = message_view->payload == NULL ? 0 : message_view->payload->len; + storage_size += message_view->service_id == NULL ? 0 : message_view->service_id->len; + + return storage_size; +} + +int aws_secure_tunnel_message_storage_init( + struct aws_secure_tunnel_message_storage *message_storage, + struct aws_allocator *allocator, + const struct aws_secure_tunnel_message_view *message_options, + enum aws_secure_tunnel_operation_type type) { + + AWS_ZERO_STRUCT(*message_storage); + size_t storage_capacity = s_aws_secure_tunnel_message_compute_storage_size(message_options); + if (aws_byte_buf_init(&message_storage->storage, allocator, storage_capacity)) { + return AWS_OP_ERR; + } + + struct aws_secure_tunnel_message_view *storage_view = &message_storage->storage_view; + + storage_view->type = message_options->type; + storage_view->ignorable = message_options->ignorable; + storage_view->stream_id = message_options->stream_id; + + switch (type) { + case AWS_STOT_MESSAGE: + storage_view->type = AWS_SECURE_TUNNEL_MT_DATA; + break; + case AWS_STOT_STREAM_START: + storage_view->type = AWS_SECURE_TUNNEL_MT_STREAM_START; + break; + case AWS_STOT_STREAM_RESET: + storage_view->type = AWS_SECURE_TUNNEL_MT_STREAM_RESET; + break; + default: + storage_view->type = AWS_SECURE_TUNNEL_MT_UNKNOWN; + break; + } + + if (message_options->service_id != NULL) { + message_storage->service_id = *message_options->service_id; + if (aws_byte_buf_append_and_update(&message_storage->storage, &message_storage->service_id)) { + return AWS_OP_ERR; + } + storage_view->service_id = &message_storage->service_id; + } + + if (message_options->payload != NULL) { + message_storage->payload = *message_options->payload; + if (aws_byte_buf_append_and_update(&message_storage->storage, &message_storage->payload)) { + return AWS_OP_ERR; + } + storage_view->payload = &message_storage->payload; + } + + return AWS_OP_SUCCESS; +} + +void aws_secure_tunnel_message_storage_clean_up(struct aws_secure_tunnel_message_storage *message_storage) { + aws_byte_buf_clean_up(&message_storage->storage); +} + +/* Sets the stream id on outbound message based on the service id (or lack of for V1) to the current one being used. */ +static int s_aws_secure_tunnel_operation_message_assign_stream_id( + struct aws_secure_tunnel_operation *operation, + struct aws_secure_tunnel *secure_tunnel) { + + struct aws_secure_tunnel_operation_message *message_op = operation->impl; + int32_t stream_id = INVALID_STREAM_ID; + + struct aws_secure_tunnel_message_view *message_view = &message_op->options_storage.storage_view; + + if (message_view->service_id != NULL) { + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->config->service_ids, message_view->service_id, &elem); + if (elem == NULL) { + AWS_LOGF_WARN( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: invalid service_id:'" PRInSTR "' attempted to be used with an outbound message", + (void *)message_view, + AWS_BYTE_CURSOR_PRI(*message_view->service_id)); + stream_id = INVALID_STREAM_ID; + } else { + struct aws_service_id_element *service_id_elem = elem->value; + stream_id = service_id_elem->stream_id; + } + } else { + stream_id = secure_tunnel->config->stream_id; + } + + if (stream_id == INVALID_STREAM_ID) { + return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM); + } + + message_op->options_storage.storage_view.stream_id = stream_id; + return AWS_OP_SUCCESS; +} + +/* + * Check the outbound stream start service id (or lack of one for V1) and set the secure tunnel and stream start + * message's stream id to the next value. + */ +static int s_aws_secure_tunnel_operation_message_set_next_stream_id( + struct aws_secure_tunnel_operation *operation, + struct aws_secure_tunnel *secure_tunnel) { + + struct aws_secure_tunnel_operation_message *message_op = operation->impl; + int32_t stream_id = INVALID_STREAM_ID; + + struct aws_secure_tunnel_message_view *message_view = &message_op->options_storage.storage_view; + + if (message_view->service_id != NULL && message_view->service_id->len > 0) { + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->config->service_ids, message_view->service_id, &elem); + if (elem == NULL) { + AWS_LOGF_WARN( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: invalid service_id:'" PRInSTR + "' attempted to be used to set next stream id on an outbound message", + (void *)message_view, + AWS_BYTE_CURSOR_PRI(*message_view->service_id)); + stream_id = INVALID_STREAM_ID; + } else { + struct aws_service_id_element *service_id_elem = elem->value; + stream_id = service_id_elem->stream_id + 1; + + struct aws_service_id_element *replacement_elem = + aws_service_id_element_new(secure_tunnel->allocator, message_view->service_id, stream_id); + aws_hash_table_put( + &secure_tunnel->config->service_ids, &replacement_elem->service_id_cur, replacement_elem, NULL); + } + } else { + stream_id = secure_tunnel->config->stream_id + 1; + secure_tunnel->config->stream_id = stream_id; + } + + if (stream_id == INVALID_STREAM_ID) { + return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM); + } + + message_op->options_storage.storage_view.stream_id = stream_id; + + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure tunnel service_id '" PRInSTR "' stream_id set to %d", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*message_view->service_id), + stream_id); + + return AWS_OP_SUCCESS; +} + +static struct aws_secure_tunnel_operation_vtable s_message_operation_vtable = { + .aws_secure_tunnel_operation_assign_stream_id_fn = s_aws_secure_tunnel_operation_message_assign_stream_id, + .aws_secure_tunnel_operation_set_next_stream_id_fn = s_aws_secure_tunnel_operation_message_set_next_stream_id, +}; + +static void s_destroy_operation_message(void *object) { + if (object == NULL) { + return; + } + + struct aws_secure_tunnel_operation_message *message_op = object; + + aws_secure_tunnel_message_storage_clean_up(&message_op->options_storage); + + aws_mem_release(message_op->allocator, message_op); +} + +struct aws_secure_tunnel_operation_message *aws_secure_tunnel_operation_message_new( + struct aws_allocator *allocator, + const struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_options, + enum aws_secure_tunnel_operation_type type) { + (void)secure_tunnel; + AWS_PRECONDITION(allocator != NULL); + AWS_PRECONDITION(message_options != NULL); + + if (aws_secure_tunnel_message_view_validate(message_options)) { + return NULL; + } + + struct aws_secure_tunnel_operation_message *message_op = + aws_mem_calloc(allocator, 1, sizeof(struct aws_secure_tunnel_operation_message)); + if (message_op == NULL) { + return NULL; + } + + message_op->allocator = allocator; + message_op->base.vtable = &s_message_operation_vtable; + message_op->base.operation_type = type; + aws_ref_count_init(&message_op->base.ref_count, message_op, s_destroy_operation_message); + message_op->base.impl = message_op; + + if (aws_secure_tunnel_message_storage_init(&message_op->options_storage, allocator, message_options, type)) { + goto error; + } + + message_op->base.message_view = &message_op->options_storage.storage_view; + + return message_op; + +error: + + aws_secure_tunnel_operation_release(&message_op->base); + + return NULL; +} + +/********************************************************************************************************************* + * Pingreq + ********************************************************************************************************************/ + +static void s_destroy_operation_pingreq(void *object) { + if (object == NULL) { + return; + } + + struct aws_secure_tunnel_operation_pingreq *pingreq_op = object; + aws_mem_release(pingreq_op->allocator, pingreq_op); +} + +struct aws_secure_tunnel_operation_pingreq *aws_secure_tunnel_operation_pingreq_new(struct aws_allocator *allocator) { + AWS_PRECONDITION(allocator != NULL); + + struct aws_secure_tunnel_operation_pingreq *pingreq_op = + aws_mem_calloc(allocator, 1, sizeof(struct aws_secure_tunnel_operation_pingreq)); + if (pingreq_op == NULL) { + return NULL; + } + + pingreq_op->allocator = allocator; + pingreq_op->base.vtable = &s_empty_operation_vtable; + pingreq_op->base.operation_type = AWS_STOT_PING; + aws_ref_count_init(&pingreq_op->base.ref_count, pingreq_op, s_destroy_operation_pingreq); + pingreq_op->base.impl = pingreq_op; + + return pingreq_op; +} + +/********************************************************************************************************************* + * Secure Tunnel Storage Options + ********************************************************************************************************************/ + +/* + * Validation of options on creation of a new secure tunnel + */ +int aws_secure_tunnel_options_validate(const struct aws_secure_tunnel_options *options) { + AWS_ASSERT(options); + + if (options->bootstrap == NULL) { + AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "bootstrap cannot be NULL"); + return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + } + + if (options->socket_options == NULL) { + AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "socket options cannot be NULL"); + return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + } + + if (options->access_token.len == 0) { + AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "access token is required"); + return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + } + + if (options->endpoint_host.len == 0) { + AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "endpoint host is required"); + return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + } + + return AWS_OP_SUCCESS; +} + +void aws_secure_tunnel_options_storage_log( + const struct aws_secure_tunnel_options_storage *options_storage, + enum aws_log_level level) { + struct aws_logger *log_handle = aws_logger_get_conditional(AWS_LS_IOTDEVICE_SECURE_TUNNELING, level); + if (log_handle == NULL) { + return; + } + + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_options_storage host name set to %s", + (void *)options_storage, + aws_string_c_str(options_storage->endpoint_host)); + + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_options_storage bootstrap set to (%p)", + (void *)options_storage, + (void *)options_storage->bootstrap); + + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_options_storage socket options set to: type = %d, domain = %d, connect_timeout_ms = " + "%" PRIu32, + (void *)options_storage, + (int)options_storage->socket_options.type, + (int)options_storage->socket_options.domain, + options_storage->socket_options.connect_timeout_ms); + + if (options_storage->socket_options.keepalive) { + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_options_storage socket keepalive options set to: keep_alive_interval_sec = " + "%" PRIu16 ", " + "keep_alive_timeout_sec = %" PRIu16 ", keep_alive_max_failed_probes = %" PRIu16, + (void *)options_storage, + options_storage->socket_options.keep_alive_interval_sec, + options_storage->socket_options.keep_alive_timeout_sec, + options_storage->socket_options.keep_alive_max_failed_probes); + } + + if (options_storage->http_proxy_config != NULL) { + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_options_storage using http proxy:", + (void *)options_storage); + + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_options_storage http proxy host name set to " PRInSTR, + (void *)options_storage, + AWS_BYTE_CURSOR_PRI(options_storage->http_proxy_options.host)); + + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_options_storage http proxy port set to %" PRIu16, + (void *)options_storage, + options_storage->http_proxy_options.port); + + if (options_storage->http_proxy_options.proxy_strategy != NULL) { + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_options_storage http proxy strategy set to (%p)", + (void *)options_storage, + (void *)options_storage->http_proxy_options.proxy_strategy); + } + } +} + +/* + * Clean up stored secure tunnel config + */ +void aws_secure_tunnel_options_storage_destroy(struct aws_secure_tunnel_options_storage *storage) { + if (storage == NULL) { + return; + } + + aws_client_bootstrap_release(storage->bootstrap); + aws_http_proxy_config_destroy(storage->http_proxy_config); + aws_string_destroy(storage->endpoint_host); + aws_string_destroy(storage->access_token); + aws_string_destroy(storage->client_token); + aws_hash_table_clean_up(&storage->service_ids); + aws_mem_release(storage->allocator, storage); +} + +/* + * Copy and store secure tunnel options + */ +struct aws_secure_tunnel_options_storage *aws_secure_tunnel_options_storage_new( + struct aws_allocator *allocator, + const struct aws_secure_tunnel_options *options) { + AWS_PRECONDITION(allocator != NULL); + AWS_PRECONDITION(options != NULL); + + if (aws_secure_tunnel_options_validate(options)) { + return NULL; + } + + struct aws_secure_tunnel_options_storage *storage = + aws_mem_calloc(allocator, 1, sizeof(struct aws_secure_tunnel_options_storage)); + + storage->allocator = allocator; + + storage->socket_options = *options->socket_options; + storage->endpoint_host = aws_string_new_from_cursor(allocator, &options->endpoint_host); + if (storage->endpoint_host == NULL) { + goto error; + } + + storage->access_token = aws_string_new_from_cursor(allocator, &options->access_token); + if (storage->access_token == NULL) { + goto error; + } + + /* + * Client token is provided to the secure tunnel service alongside the access token. + * The access token is one-time use unless coupled with a client token. The pair can be used together + * for reconnects. If the user provides one, we will use that. If one is not provided, we will generate + * one for use with this access token to handle reconnecting on disconnections. + */ + if (options->client_token.len > 0) { + storage->client_token = aws_string_new_from_cursor(allocator, &options->client_token); + if (storage->client_token == NULL) { + goto error; + } + } else { + struct aws_uuid uuid; + if (aws_uuid_init(&uuid)) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "Failed to initiate an uuid struct: %s", + aws_error_str(aws_last_error())); + goto error; + } + char uuid_str[AWS_UUID_STR_LEN] = {0}; + struct aws_byte_buf uuid_buf = aws_byte_buf_from_array(uuid_str, sizeof(uuid_str)); + uuid_buf.len = 0; + if (aws_uuid_to_str(&uuid, &uuid_buf)) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Failed to stringify uuid: %s", aws_error_str(aws_last_error())); + goto error; + } + storage->client_token = aws_string_new_from_buf(allocator, &uuid_buf); + } + + storage->local_proxy_mode = options->local_proxy_mode; + + /* acquire reference to everything that's ref-counted */ + storage->bootstrap = aws_client_bootstrap_acquire(options->bootstrap); + + if (options->http_proxy_options != NULL) { + storage->http_proxy_config = + aws_http_proxy_config_new_from_proxy_options(allocator, options->http_proxy_options); + if (storage->http_proxy_config == NULL) { + goto error; + } + + aws_http_proxy_options_init_from_config(&storage->http_proxy_options, storage->http_proxy_config); + } + + if (aws_hash_table_init( + &storage->service_ids, + allocator, + 3, + aws_hash_byte_cursor_ptr, + (aws_hash_callback_eq_fn *)aws_byte_cursor_eq, + NULL, + s_destroy_service_id)) { + goto error; + } + + storage->on_message_received = options->on_message_received; + storage->user_data = options->user_data; + + storage->local_proxy_mode = options->local_proxy_mode; + storage->on_connection_complete = options->on_connection_complete; + storage->on_connection_shutdown = options->on_connection_shutdown; + storage->on_send_data_complete = options->on_send_data_complete; + storage->on_stream_start = options->on_stream_start; + storage->on_stream_reset = options->on_stream_reset; + storage->on_session_reset = options->on_session_reset; + storage->on_stopped = options->on_stopped; + storage->on_termination_complete = options->on_termination_complete; + storage->secure_tunnel_on_termination_user_data = options->secure_tunnel_on_termination_user_data; + + return storage; + +error: + aws_secure_tunnel_options_storage_destroy(storage); + return NULL; +} + +/********************************************************************************************************************* + * Data Tunnel Pair + ********************************************************************************************************************/ + +/* + * Clean up data tunnel pair + */ +void aws_secure_tunnel_data_tunnel_pair_destroy(struct data_tunnel_pair *pair) { + aws_byte_buf_clean_up(&pair->buf); + aws_mem_release(pair->allocator, (void *)pair); +} + +/* + * Create a new data tunnel pair + */ +struct data_tunnel_pair *aws_secure_tunnel_data_tunnel_pair_new( + struct aws_allocator *allocator, + const struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_view) { + AWS_PRECONDITION(allocator != NULL); + AWS_PRECONDITION(secure_tunnel != NULL); + AWS_PRECONDITION(message_view != NULL); + + struct data_tunnel_pair *pair = aws_mem_calloc(allocator, 1, sizeof(struct data_tunnel_pair)); + pair->allocator = allocator; + pair->secure_tunnel = secure_tunnel; + pair->length_prefix_written = false; + if (aws_iot_st_msg_serialize_from_view(&pair->buf, allocator, message_view)) { + AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Failure serializing message"); + goto error; + } + if (pair->buf.len > AWS_IOT_ST_MAX_MESSAGE_SIZE) { + AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Message size greater than AWS_IOT_ST_MAX_MESSAGE_SIZE"); + goto error; + } + + pair->cur = aws_byte_cursor_from_buf(&pair->buf); + + return pair; + +error: + + aws_secure_tunnel_data_tunnel_pair_destroy(pair); + return NULL; +} + +const char *aws_secure_tunnel_operation_type_to_c_string(enum aws_secure_tunnel_operation_type operation_type) { + switch (operation_type) { + case AWS_STOT_NONE: + return "NONE"; + case AWS_STOT_PING: + return "PING"; + case AWS_STOT_MESSAGE: + return "DATA"; + case AWS_STOT_STREAM_RESET: + return "STREAM RESET"; + case AWS_STOT_STREAM_START: + return "STREAM START"; + default: + return "UNKNOWN"; + } +} diff --git a/source/serializer.c b/source/serializer.c index 56518708..565178a3 100644 --- a/source/serializer.c +++ b/source/serializer.c @@ -2,8 +2,14 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0. */ +#include +#include #include +/***************************************************************************************************************** + * ENCODING + *****************************************************************************************************************/ + static int s_iot_st_encode_varint_uint32_t(struct aws_byte_buf *buffer, uint32_t n) { // & 2's comp lement // ~0x7F == b-10000000 @@ -59,27 +65,6 @@ static int s_iot_st_encode_varint_pos(struct aws_byte_buf *buffer, int32_t n) { return AWS_OP_ERR; } -static int s_iot_st_decode_varint_uint32_t(struct aws_byte_cursor *cursor, uint32_t *result) { - int bits = 0; - // Continue while the first bit is one - // 0x80 == b10000000 - uint32_t castPtrValue; - while ((*cursor->ptr & 0x80)) { - castPtrValue = *cursor->ptr; - // Zero out the first bit - // 0x7F == b01111111 - *result += ((castPtrValue & 0x7F) << bits); - AWS_RETURN_ERROR_IF2(aws_byte_cursor_advance(cursor, 1).ptr != NULL, AWS_OP_ERR); - bits += 7; - } - castPtrValue = *cursor->ptr; - AWS_RETURN_ERROR_IF2(aws_byte_cursor_advance(cursor, 1).ptr != NULL, AWS_OP_ERR); - // Zero out the first bit - // 0x7F == b01111111 - *result += ((castPtrValue & 0x7F) << bits); - return AWS_OP_SUCCESS; -} - static int s_iot_st_encode_varint( const uint8_t field_number, const uint8_t wire_type, @@ -91,63 +76,225 @@ static int s_iot_st_encode_varint( return s_iot_st_encode_varint_pos(buffer, value); } -static int s_iot_st_encode_lengthdelim( +static int s_iot_st_encode_byte_range( const uint8_t field_number, const uint8_t wire_type, - struct aws_byte_buf *payload, + const struct aws_byte_cursor *payload, struct aws_byte_buf *buffer) { const uint8_t field_and_wire_type = (field_number << AWS_IOT_ST_FIELD_NUMBER_SHIFT) + wire_type; aws_byte_buf_append_byte_dynamic_secure(buffer, field_and_wire_type); s_iot_st_encode_varint_uint32_t(buffer, (uint32_t)payload->len); - struct aws_byte_cursor temp = aws_byte_cursor_from_array(payload->buffer, payload->len); + struct aws_byte_cursor temp = aws_byte_cursor_from_array(payload->ptr, payload->len); return aws_byte_buf_append_dynamic_secure(buffer, &temp); } static int s_iot_st_encode_stream_id(int32_t data, struct aws_byte_buf *buffer) { - return s_iot_st_encode_varint(AWS_IOT_ST_MESSAGE_STREAM_ID, AWS_IOT_ST_VARINT_WIRE, data, buffer); + return s_iot_st_encode_varint(AWS_SECURE_TUNNEL_FN_STREAM_ID, AWS_SECURE_TUNNEL_PBWT_VARINT, data, buffer); } static int s_iot_st_encode_ignorable(int32_t data, struct aws_byte_buf *buffer) { - return s_iot_st_encode_varint(AWS_IOT_ST_MESSAGE_IGNORABLE, AWS_IOT_ST_VARINT_WIRE, data, buffer); + return s_iot_st_encode_varint(AWS_SECURE_TUNNEL_FN_IGNORABLE, AWS_SECURE_TUNNEL_PBWT_VARINT, data, buffer); } static int s_iot_st_encode_type(int32_t data, struct aws_byte_buf *buffer) { - return s_iot_st_encode_varint(AWS_IOT_ST_MESSAGE_TYPEFIELD, AWS_IOT_ST_VARINT_WIRE, data, buffer); + return s_iot_st_encode_varint(AWS_SECURE_TUNNEL_FN_TYPE, AWS_SECURE_TUNNEL_PBWT_VARINT, data, buffer); } -static int s_iot_st_encode_payload(struct aws_byte_buf *payload, struct aws_byte_buf *buffer) { - return s_iot_st_encode_lengthdelim(AWS_IOT_ST_MESSAGE_PAYLOAD, AWS_IOT_ST_VARINT_LENGTHDELIM_WIRE, payload, buffer); +static int s_iot_st_encode_payload(const struct aws_byte_cursor *payload, struct aws_byte_buf *buffer) { + return s_iot_st_encode_byte_range( + AWS_SECURE_TUNNEL_FN_PAYLOAD, AWS_SECURE_TUNNEL_PBWT_LENGTH_DELIMITED, payload, buffer); +} + +static int s_iot_st_encode_service_id(const struct aws_byte_cursor *service_id, struct aws_byte_buf *buffer) { + return s_iot_st_encode_byte_range( + AWS_SECURE_TUNNEL_FN_SERVICE_ID, AWS_SECURE_TUNNEL_PBWT_LENGTH_DELIMITED, service_id, buffer); +} + +static int s_iot_st_encode_service_ids(const struct aws_byte_cursor *service_id, struct aws_byte_buf *buffer) { + return s_iot_st_encode_byte_range( + AWS_SECURE_TUNNEL_FN_AVAILABLE_SERVICE_IDS, AWS_SECURE_TUNNEL_PBWT_LENGTH_DELIMITED, service_id, buffer); +} + +static int s_iot_st_get_varint_size(size_t value, size_t *encode_size) { + if (value > AWS_IOT_ST_MAXIMUM_VARINT) { + return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + } + + if (value < AWS_IOT_ST_MAXIMUM_1_BYTE_VARINT_VALUE) { + *encode_size = 1; + } else if (value < AWS_IOT_ST_MAXIMUM_2_BYTE_VARINT_VALUE) { + *encode_size = 2; + } else if (value < AWS_IOT_ST_MAXIMUM_3_BYTE_VARINT_VALUE) { + *encode_size = 3; + } else { + *encode_size = 4; + } + + return AWS_OP_SUCCESS; } -int aws_iot_st_msg_serialize_from_struct( +static int s_iot_st_compute_message_length( + const struct aws_secure_tunnel_message_view *message, + size_t *message_length) { + size_t local_length = 0; + + /* + * 1 byte type key + * 1 byte type varint + */ + local_length += 2; + + if (message->stream_id != 0) { + /* + * 1 byte stream_id key + * 1-4 byte stream_id varint + */ + size_t stream_id_length = 0; + + if (s_iot_st_get_varint_size((uint32_t)message->stream_id, &stream_id_length)) { + return AWS_OP_ERR; + } + + local_length += (1 + stream_id_length); + } + + if (message->ignorable != 0) { + /* + * 1 byte ignorable key + * 1 byte ignorable varint + */ + local_length += 2; + } + + if (message->payload != NULL && message->payload->len != 0) { + /* + * 1 byte key + * 1-4 byte payload length varint + * n bytes payload.len + */ + size_t payload_length = 0; + if (s_iot_st_get_varint_size((uint32_t)message->payload->len, &payload_length)) { + return AWS_OP_ERR; + } + local_length += (1 + message->payload->len + payload_length); + } + + if (message->service_id != NULL && message->service_id->len != 0) { + /* + * 1 byte key + * 1-4 byte payload length varint + * n bytes service_id.len + */ + size_t service_id_length = 0; + if (s_iot_st_get_varint_size((uint32_t)message->service_id->len, &service_id_length)) { + return AWS_OP_ERR; + } + local_length += (1 + message->service_id->len + service_id_length); + } + + if (message->service_id_2 != NULL && message->service_id_2->len != 0) { + /* + * 1 byte key + * 1-4 byte payload length varint + * n bytes service_id.len + */ + size_t service_id_length_2 = 0; + if (s_iot_st_get_varint_size((uint32_t)message->service_id_2->len, &service_id_length_2)) { + return AWS_OP_ERR; + } + local_length += (1 + message->service_id_2->len + service_id_length_2); + } + + if (message->service_id_3 != NULL && message->service_id_3->len != 0) { + /* + * 1 byte key + * 1-4 byte payload length varint + * n bytes service_id.len + */ + size_t service_id_length_3 = 0; + if (s_iot_st_get_varint_size((uint32_t)message->service_id_3->len, &service_id_length_3)) { + return AWS_OP_ERR; + } + local_length += (1 + message->service_id_3->len + service_id_length_3); + } + + *message_length = local_length; + return AWS_OP_SUCCESS; +} + +int aws_iot_st_msg_serialize_from_view( struct aws_byte_buf *buffer, struct aws_allocator *allocator, - struct aws_iot_st_msg message) { - if (aws_byte_buf_init(buffer, allocator, AWS_IOT_ST_DEFAULT_ALLO + message.payload.len) != AWS_OP_SUCCESS) { + const struct aws_secure_tunnel_message_view *message_view) { + size_t message_total_length = 0; + if (s_iot_st_compute_message_length(message_view, &message_total_length)) { + return AWS_OP_ERR; + } + + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: serializing message from view of size %zu.", + (void *)message_view, + message_total_length); + + if (aws_byte_buf_init(buffer, allocator, message_total_length) != AWS_OP_SUCCESS) { return AWS_OP_ERR; } - if (message.type != AWS_IOT_ST_MESSAGE_DEFAULT_TYPE) { - if (s_iot_st_encode_type(message.type, buffer) != AWS_OP_SUCCESS) { + if (message_view->type != AWS_SECURE_TUNNEL_MT_UNKNOWN) { + if (s_iot_st_encode_type(message_view->type, buffer)) { goto cleanup; } + } else { + AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Message missing type during encoding"); + goto cleanup; } - if (message.stream_id != AWS_IOT_ST_MESSAGE_DEFAULT_STREAM_ID) { - if (s_iot_st_encode_stream_id(message.stream_id, buffer) != AWS_OP_SUCCESS) { + + if (message_view->stream_id != 0) { + if (s_iot_st_encode_stream_id(message_view->stream_id, buffer)) { goto cleanup; } } - if (message.ignorable != AWS_IOT_ST_MESSAGE_DEFAULT_IGNORABLE) { - if (s_iot_st_encode_ignorable(message.ignorable, buffer) != AWS_OP_SUCCESS) { + + if (message_view->ignorable != 0) { + if (s_iot_st_encode_ignorable(message_view->ignorable, buffer)) { goto cleanup; } } - if (message.payload.len != AWS_IOT_ST_MESSAGE_DEFAULT_PAYLOAD) { - if (s_iot_st_encode_payload(&message.payload, buffer) != AWS_OP_SUCCESS) { + + if (message_view->payload != NULL) { + if (s_iot_st_encode_payload(message_view->payload, buffer)) { + goto cleanup; + } + } + + if (message_view->type == AWS_SECURE_TUNNEL_MT_SERVICE_IDS) { + if (message_view->service_id != 0) { + if (s_iot_st_encode_service_ids(message_view->service_id, buffer)) { + goto cleanup; + } + } + if (message_view->service_id_2 != 0) { + if (s_iot_st_encode_service_ids(message_view->service_id_2, buffer)) { + goto cleanup; + } + } + if (message_view->service_id_3 != 0) { + if (s_iot_st_encode_service_ids(message_view->service_id_3, buffer)) { + goto cleanup; + } + } + } else if (message_view->service_id != NULL) { + if (s_iot_st_encode_service_id(message_view->service_id, buffer)) { goto cleanup; } } - AWS_RETURN_ERROR_IF2(buffer->capacity < AWS_IOT_ST_MAX_MESSAGE_SIZE, AWS_ERROR_INVALID_BUFFER_SIZE); + + if (buffer->capacity > AWS_IOT_ST_MAX_MESSAGE_SIZE) { + aws_raise_error(AWS_ERROR_INVALID_BUFFER_SIZE); + goto cleanup; + } + return AWS_OP_SUCCESS; cleanup: @@ -155,21 +302,93 @@ int aws_iot_st_msg_serialize_from_struct( return AWS_OP_ERR; } -static int s_aws_st_decode_lengthdelim(struct aws_byte_cursor *cursor, struct aws_byte_buf *buffer, int length) { - struct aws_byte_cursor temp = aws_byte_cursor_from_array(cursor->ptr, length); - AWS_RETURN_ERROR_IF2(aws_byte_buf_append_dynamic_secure(buffer, &temp) == 0, AWS_OP_ERR); +/***************************************************************************************************************** + * DECODING + *****************************************************************************************************************/ + +static int s_iot_st_decode_varint_uint32_t(struct aws_byte_cursor *cursor, uint32_t *result) { + int bits = 0; + // Continue while the first bit is one + // 0x80 == b10000000 + uint32_t castPtrValue; + while ((*cursor->ptr & 0x80)) { + castPtrValue = *cursor->ptr; + // Zero out the first bit + // 0x7F == b01111111 + *result += ((castPtrValue & 0x7F) << bits); + AWS_RETURN_ERROR_IF2( + aws_byte_cursor_advance(cursor, 1).ptr != NULL, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DECODE_FAILURE); + bits += 7; + } + castPtrValue = *cursor->ptr; + AWS_RETURN_ERROR_IF2( + aws_byte_cursor_advance(cursor, 1).ptr != NULL, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DECODE_FAILURE); + // Zero out the first bit + // 0x7F == b01111111 + *result += ((castPtrValue & 0x7F) << bits); + return AWS_OP_SUCCESS; +} + +int aws_secure_tunnel_deserialize_varint_from_cursor_to_message( + struct aws_byte_cursor *cursor, + uint8_t field_number, + struct aws_secure_tunnel_message_view *message) { + uint32_t result = 0; + + if (s_iot_st_decode_varint_uint32_t(cursor, &result)) { + return AWS_OP_ERR; + } + + switch (field_number) { + case AWS_SECURE_TUNNEL_FN_TYPE: + message->type = result; + break; + case AWS_SECURE_TUNNEL_FN_STREAM_ID: + message->stream_id = result; + break; + case AWS_SECURE_TUNNEL_FN_IGNORABLE: + message->ignorable = result; + break; + default: + AWS_LOGF_WARN( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Unexpected field number in message encountered.", + (void *)message); + /* Unexpected field_number */ + break; + } + return AWS_OP_SUCCESS; } -int aws_iot_st_msg_deserialize_from_cursor( - struct aws_iot_st_msg *message, +int aws_secure_tunnel_deserialize_message_from_cursor( + struct aws_secure_tunnel *secure_tunnel, struct aws_byte_cursor *cursor, - struct aws_allocator *allocator) { + aws_secure_tunnel_on_message_received_fn *on_message_received) { + + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: deserializing message from cursor of size %zu.", + (void *)secure_tunnel, + cursor->len); + AWS_RETURN_ERROR_IF2(cursor->len < AWS_IOT_ST_MAX_MESSAGE_SIZE, AWS_ERROR_INVALID_BUFFER_SIZE); uint8_t wire_type; uint8_t field_number; - int length; - int payload_check = 0; + struct aws_secure_tunnel_message_view message_view; + AWS_ZERO_STRUCT(message_view); + + struct aws_byte_cursor payload_cur; + AWS_ZERO_STRUCT(payload_cur); + + int service_ids_set = 0; + struct aws_byte_cursor service_id_1_cur; + struct aws_byte_cursor service_id_2_cur; + struct aws_byte_cursor service_id_3_cur; + AWS_ZERO_STRUCT(service_id_1_cur); + AWS_ZERO_STRUCT(service_id_2_cur); + AWS_ZERO_STRUCT(service_id_3_cur); + while ((aws_byte_cursor_is_valid(cursor)) && (cursor->len > 0)) { // wire_type is only the first 3 bits, Zeroing out the first 5 // 0x07 == 00000111 @@ -177,46 +396,106 @@ int aws_iot_st_msg_deserialize_from_cursor( field_number = (*cursor->ptr) >> 3; aws_byte_cursor_advance(cursor, 1); - if (field_number == AWS_IOT_ST_STREAM_ID_FIELD_NUMBER && wire_type == AWS_IOT_ST_VARINT_WIRE) { - uint32_t res = 0; - if (s_iot_st_decode_varint_uint32_t(cursor, &res) != AWS_OP_SUCCESS) { - return AWS_OP_ERR; - } - message->stream_id = res; - } else if (field_number == AWS_IOT_ST_IGNORABLE_FIELD_NUMBER && wire_type == AWS_IOT_ST_VARINT_WIRE) { - uint32_t res = 0; - if (s_iot_st_decode_varint_uint32_t(cursor, &res) != AWS_OP_SUCCESS) { - return AWS_OP_ERR; - } - message->ignorable = res; - } else if (field_number == AWS_IOT_ST_TYPE_FIELD_NUMBER && wire_type == AWS_IOT_ST_VARINT_WIRE) { - uint32_t res = 0; - if (s_iot_st_decode_varint_uint32_t(cursor, &res) != AWS_OP_SUCCESS) { - return AWS_OP_ERR; - } - message->type = res; - } else if (field_number == AWS_IOT_ST_PAYLOAD_FIELD_NUMBER && wire_type == AWS_IOT_ST_VARINT_LENGTHDELIM_WIRE) { - uint32_t res = 0; - if (s_iot_st_decode_varint_uint32_t(cursor, &res) != AWS_OP_SUCCESS) { - return AWS_OP_ERR; - } - length = res; - if (aws_byte_buf_init(&message->payload, allocator, length) != AWS_OP_SUCCESS) { - return AWS_OP_ERR; - } + /* ignorable defaults to false unless set to true in the incoming message*/ + message_view.ignorable = false; - if (s_aws_st_decode_lengthdelim(cursor, &message->payload, length) != AWS_OP_SUCCESS) { - goto cleanup; - } - aws_byte_cursor_advance(cursor, length); - payload_check = 1; + switch (wire_type) { + case AWS_SECURE_TUNNEL_PBWT_VARINT: + if (aws_secure_tunnel_deserialize_varint_from_cursor_to_message(cursor, field_number, &message_view)) { + goto error; + } + break; + + case AWS_SECURE_TUNNEL_PBWT_LENGTH_DELIMITED: { + + uint32_t length = 0; + if (s_iot_st_decode_varint_uint32_t(cursor, &length)) { + goto error; + } + + switch (field_number) { + case AWS_SECURE_TUNNEL_FN_PAYLOAD: + payload_cur = aws_byte_cursor_advance(cursor, length); + message_view.payload = &payload_cur; + break; + + case AWS_SECURE_TUNNEL_FN_SERVICE_ID: + service_id_1_cur = aws_byte_cursor_advance(cursor, length); + message_view.service_id = &service_id_1_cur; + break; + + case AWS_SECURE_TUNNEL_FN_AVAILABLE_SERVICE_IDS: + switch (service_ids_set) { + case 0: + service_id_1_cur = aws_byte_cursor_advance(cursor, length); + message_view.service_id = &service_id_1_cur; + break; + case 1: + service_id_2_cur = aws_byte_cursor_advance(cursor, length); + message_view.service_id_2 = &service_id_2_cur; + break; + case 2: + service_id_3_cur = aws_byte_cursor_advance(cursor, length); + message_view.service_id_3 = &service_id_3_cur; + break; + default: + goto error; + break; + } + service_ids_set++; + break; + } + } break; + + /* These wire types are unexpected and should result in an error log */ + case AWS_SECURE_TUNNEL_PBWT_64_BIT: + case AWS_SECURE_TUNNEL_PBWT_START_GROUP: + case AWS_SECURE_TUNNEL_PBWT_END_GROUP: + case AWS_SECURE_TUNNEL_PBWT_32_BIT: + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Unexpected wire type in message encountered.", + (void *)secure_tunnel); + goto error; + break; } } - if (payload_check == 0) { - AWS_ZERO_STRUCT(message->payload); - } + + on_message_received(secure_tunnel, &message_view); + return AWS_OP_SUCCESS; -cleanup: - aws_byte_buf_clean_up(&message->payload); - return AWS_OP_ERR; + +error: + return AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DECODE_FAILURE; +} + +const char *aws_secure_tunnel_message_type_to_c_string(enum aws_secure_tunnel_message_type message_type) { + switch (message_type) { + case AWS_SECURE_TUNNEL_MT_UNKNOWN: + return "ST_MT_UNKNOWN"; + + case AWS_SECURE_TUNNEL_MT_DATA: + return "DATA"; + + case AWS_SECURE_TUNNEL_MT_STREAM_START: + return "STREAM START"; + + case AWS_SECURE_TUNNEL_MT_STREAM_RESET: + return "STREAM RESET"; + + case AWS_SECURE_TUNNEL_MT_SESSION_RESET: + return "SESSION RESET"; + + case AWS_SECURE_TUNNEL_MT_SERVICE_IDS: + return "SERVICE IDS"; + + case AWS_SECURE_TUNNEL_MT_CONNECTION_START: + return "CONNECTION START"; + + case AWS_SECURE_TUNNEL_MT_CONNECTION_RESET: + return "CONNECTION RESET"; + + default: + return "UNKNOWN"; + } } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 5d1053c3..19bf90aa 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -3,7 +3,7 @@ include(AwsTestHarness) enable_testing() file(GLOB TEST_HDRS "mqtt_mock_structs.h") -set(TEST_SRC iotdevice_tests.c metrics_tests.c secure_tunneling_tests.c) +set(TEST_SRC iotdevice_tests.c metrics_tests.c secure_tunneling_tests.c secure_tunnel_tests.c) file(GLOB TESTS ${TEST_HDRS} ${TEST_SRC}) add_test_case(library_init) @@ -19,16 +19,15 @@ if (UNIX AND NOT APPLE) add_test_case(devicedefender_publish_failure_callback_invoked) endif() -add_net_test_case(secure_tunneling_handle_stream_start_test) -add_net_test_case(secure_tunneling_handle_data_receive_test) -add_net_test_case(secure_tunneling_handle_stream_reset_test) -add_net_test_case(secure_tunneling_handle_session_reset_test) -add_net_test_case(secure_tunneling_handle_session_reset_no_stream_test) -add_net_test_case(secure_tunneling_init_websocket_options_test) -add_net_test_case(secure_tunneling_handle_send_data) -add_net_test_case(secure_tunneling_handle_send_data_stream_start) -add_net_test_case(secure_tunneling_handle_send_data_stream_reset) -add_net_test_case(secure_tunneling_handle_send_data_public) +add_test_case(secure_tunneling_functionality_connect_test) +add_test_case(secure_tunneling_functionality_client_token_test) +add_test_case(secure_tunneling_fail_and_retry_connection_test) +add_test_case(secure_tunneling_store_service_ids_test) +add_test_case(secure_tunneling_receive_stream_start_test) +add_test_case(secure_tunneling_rejected_service_id_stream_start_test) +add_test_case(secure_tunneling_close_stream_on_stream_reset_test) +add_test_case(secure_tunneling_session_reset_test) +add_test_case(secure_tunneling_serializer_data_message_test) generate_test_driver(${PROJECT_NAME}-tests) @@ -41,15 +40,6 @@ aws_add_sanitizers(${TEST_DD_CLIENT_BINARY_NAME} ${${PROJECT_NAME}_SANITIZERS}) target_compile_definitions(${TEST_DD_CLIENT_BINARY_NAME} PRIVATE AWS_UNSTABLE_TESTING_API=1) target_include_directories(${TEST_DD_CLIENT_BINARY_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -# Secure Tunneling test client -set(TEST_ST_CLIENT_BINARY_NAME ${PROJECT_NAME}-secure_tunneling-client) -add_executable(${TEST_ST_CLIENT_BINARY_NAME} "aws_iot_secure_tunneling_client_test.c") -target_link_libraries(${TEST_ST_CLIENT_BINARY_NAME} PRIVATE ${PROJECT_NAME}) -aws_set_common_properties(${TEST_ST_CLIENT_BINARY_NAME} NO_WEXTRA NO_PEDANTIC) -aws_add_sanitizers(${TEST_ST_CLIENT_BINARY_NAME} ${${PROJECT_NAME}_SANITIZERS}) -target_compile_definitions(${TEST_ST_CLIENT_BINARY_NAME} PRIVATE AWS_UNSTABLE_TESTING_API=1) -target_include_directories(${TEST_ST_CLIENT_BINARY_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) - if ($ENV{PROTOBUF_TEST}) add_subdirectory(tests_protobuf) endif () diff --git a/tests/aws_iot_secure_tunneling_client_test.c b/tests/aws_iot_secure_tunneling_client_test.c deleted file mode 100644 index 0e607d9b..00000000 --- a/tests/aws_iot_secure_tunneling_client_test.c +++ /dev/null @@ -1,199 +0,0 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#define UNUSED(x) (void)(x) - -static struct aws_mutex mutex = AWS_MUTEX_INIT; -static struct aws_condition_variable condition_variable = AWS_CONDITION_VARIABLE_INIT; - -static int on_send_data_complete_error_code = 0; - -static void s_on_send_data_complete(int error_code, void *user_data) { - UNUSED(user_data); - on_send_data_complete_error_code = error_code; -} - -static void s_on_connection_complete(void *user_data) { - UNUSED(user_data); - aws_mutex_lock(&mutex); - aws_condition_variable_notify_one(&condition_variable); - aws_mutex_unlock(&mutex); -} - -static void s_on_connection_shutdown(void *user_data) { - UNUSED(user_data); -} - -static void s_on_data_receive(const struct aws_byte_buf *data, void *user_data) { - AWS_LOGF_INFO(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Client received data:"); - - struct aws_allocator *allocator = (struct aws_allocator *)user_data; - - struct aws_byte_cursor data_cursor = aws_byte_cursor_from_buf(data); - struct aws_byte_buf data_to_print; - aws_byte_buf_init(&data_to_print, allocator, data->len + 1); /* +1 for null terminator */ - aws_byte_buf_append(&data_to_print, &data_cursor); - aws_byte_buf_append_null_terminator(&data_to_print); - AWS_LOGF_INFO(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "%s", (char *)data_to_print.buffer); - - aws_byte_buf_clean_up(&data_to_print); -} - -static void s_on_stream_start(void *user_data) { - UNUSED(user_data); - AWS_LOGF_INFO(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Client received StreamStart."); -} - -static void s_on_stream_reset(void *user_data) { - UNUSED(user_data); - AWS_LOGF_INFO(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Client received StreamReset."); -} - -static void s_on_session_reset(void *user_data) { - UNUSED(user_data); - AWS_LOGF_INFO(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Client received SessionReset."); -} - -enum aws_secure_tunneling_local_proxy_mode s_local_proxy_mode_from_c_str(const char *local_proxy_mode) { - if (strcmp(local_proxy_mode, "src") == 0) { - return AWS_SECURE_TUNNELING_SOURCE_MODE; - } - return AWS_SECURE_TUNNELING_DESTINATION_MODE; -} - -static void s_init_secure_tunneling_connection_config( - struct aws_allocator *allocator, - struct aws_client_bootstrap *bootstrap, - struct aws_socket_options *socket_options, - const char *access_token, - enum aws_secure_tunneling_local_proxy_mode local_proxy_mode, - const char *endpoint, - const char *root_ca, - struct aws_secure_tunnel_options *config) { - - AWS_ZERO_STRUCT(*config); - config->allocator = allocator; - config->bootstrap = bootstrap; - config->socket_options = socket_options; - - config->access_token = aws_byte_cursor_from_c_str(access_token); - config->local_proxy_mode = local_proxy_mode; - config->endpoint_host = aws_byte_cursor_from_c_str(endpoint); - config->root_ca = root_ca; - - config->on_connection_complete = s_on_connection_complete; - config->on_connection_shutdown = s_on_connection_shutdown; - config->on_send_data_complete = s_on_send_data_complete; - config->on_data_receive = s_on_data_receive; - config->on_stream_start = s_on_stream_start; - config->on_stream_reset = s_on_stream_reset; - config->on_session_reset = s_on_session_reset; - - config->user_data = allocator; -} - -int main(int argc, char **argv) { - if (argc < 5) { - printf( - "3 args required, only %d passed. Usage:\n" - "aws-c-iot-secure_tunneling-client [endpoint] [src|dest] [root_ca] [access_token]\n", - argc - 1); - return 1; - } - const char *endpoint = argv[1]; - enum aws_secure_tunneling_local_proxy_mode local_proxy_mode = s_local_proxy_mode_from_c_str(argv[2]); - const char *root_ca = argv[3]; - const char *access_token = argv[4]; - - struct aws_allocator *allocator = aws_mem_tracer_new(aws_default_allocator(), NULL, AWS_MEMTRACE_BYTES, 0); - - aws_iotdevice_library_init(allocator); - - struct aws_logger_standard_options logger_options = { - .level = AWS_LL_TRACE, - .file = stdout, - }; - struct aws_logger logger; - aws_logger_init_standard(&logger, allocator, &logger_options); - aws_logger_set(&logger); - - struct aws_event_loop_group *elg = aws_event_loop_group_new_default(allocator, 1, NULL); - struct aws_host_resolver_default_options host_resolver_default_options; - AWS_ZERO_STRUCT(host_resolver_default_options); - host_resolver_default_options.max_entries = 8; - host_resolver_default_options.el_group = elg; - host_resolver_default_options.shutdown_options = NULL; - host_resolver_default_options.system_clock_override_fn = NULL; - struct aws_host_resolver *resolver = aws_host_resolver_new_default(allocator, &host_resolver_default_options); - struct aws_client_bootstrap_options bootstrap_options = { - .event_loop_group = elg, - .host_resolver = resolver, - }; - struct aws_client_bootstrap *bootstrap = aws_client_bootstrap_new(allocator, &bootstrap_options); - - struct aws_socket_options socket_options; - AWS_ZERO_STRUCT(socket_options); - socket_options.connect_timeout_ms = 3000; - socket_options.type = AWS_SOCKET_STREAM; - socket_options.domain = AWS_SOCKET_IPV4; - - /* setup secure tunneling connection config */ - struct aws_secure_tunnel_options config; - s_init_secure_tunneling_connection_config( - allocator, bootstrap, &socket_options, access_token, local_proxy_mode, endpoint, root_ca, &config); - - /* Create a secure tunnel object and connect */ - struct aws_secure_tunnel *secure_tunnel = aws_secure_tunnel_new(&config); - aws_secure_tunnel_connect(secure_tunnel); - - /* wait here until the connection is done */ - aws_mutex_lock(&mutex); - ASSERT_SUCCESS(aws_condition_variable_wait(&condition_variable, &mutex)); - aws_mutex_unlock(&mutex); - - if (local_proxy_mode == AWS_SECURE_TUNNELING_SOURCE_MODE) { - AWS_RETURN_ERROR_IF2(aws_secure_tunnel_stream_start(secure_tunnel) == AWS_OP_SUCCESS, AWS_OP_ERR); - - int cLen = 500000; - char *payload = malloc(cLen + 1); - memset(payload, 'a', cLen); - payload[cLen] = 0; - struct aws_byte_cursor cur = aws_byte_cursor_from_c_str(payload); - AWS_RETURN_ERROR_IF2(aws_secure_tunnel_send_data(secure_tunnel, &cur) == AWS_OP_SUCCESS, AWS_OP_ERR); - - AWS_RETURN_ERROR_IF2(aws_secure_tunnel_stream_reset(secure_tunnel) == AWS_OP_SUCCESS, AWS_OP_ERR); - ASSERT_SUCCESS(aws_condition_variable_wait(&condition_variable, &mutex)); - } else if (local_proxy_mode == AWS_SECURE_TUNNELING_DESTINATION_MODE) { - /* Wait a little for data to show up */ - aws_thread_current_sleep((uint64_t)60 * 60 * 1000000000); - } - aws_thread_current_sleep((uint64_t)60 * 60 * 1000000000); - - /* clean up */ - aws_secure_tunnel_close(secure_tunnel); - aws_secure_tunnel_release(secure_tunnel); - - aws_client_bootstrap_release(bootstrap); - aws_host_resolver_release(resolver); - aws_event_loop_group_release(elg); - aws_logger_clean_up(&logger); - aws_iotdevice_library_clean_up(); - - ASSERT_UINT_EQUALS(0, aws_mem_tracer_count(allocator)); - allocator = aws_mem_tracer_destroy(allocator); - ASSERT_NOT_NULL(allocator); - - return AWS_OP_SUCCESS; -} diff --git a/tests/secure_tunnel_tests.c b/tests/secure_tunnel_tests.c new file mode 100644 index 00000000..c5678ae3 --- /dev/null +++ b/tests/secure_tunnel_tests.c @@ -0,0 +1,1123 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define PAYLOAD_BYTE_LENGTH_PREFIX 2 +AWS_STATIC_STRING_FROM_LITERAL(s_access_token, "IAmAnAccessToken"); +AWS_STATIC_STRING_FROM_LITERAL(s_client_token, "IAmAClientToken"); +AWS_STATIC_STRING_FROM_LITERAL(s_endpoint_host, "IAmAnEndpointHost"); +AWS_STATIC_STRING_FROM_LITERAL(s_service_id_1, "ServiceId1"); +AWS_STATIC_STRING_FROM_LITERAL(s_service_id_2, "ServiceId2"); +AWS_STATIC_STRING_FROM_LITERAL(s_service_id_3, "ServiceId3"); +AWS_STATIC_STRING_FROM_LITERAL(s_service_id_wrong, "ServiceIdWrong"); +AWS_STATIC_STRING_FROM_LITERAL(s_payload_text, "IAmABunchOfPayloadText"); + +#ifdef _WIN32 +# define LOCAL_SOCK_TEST_PATTERN "\\\\.\\pipe\\testsock%llu" +#else +# define LOCAL_SOCK_TEST_PATTERN "testsock%llu.sock" +#endif + +struct aws_secure_tunnel_mock_websocket_vtable { + aws_websocket_on_connection_setup_fn *on_connection_setup_fn; + aws_websocket_on_connection_shutdown_fn *on_connection_shutdown_fn; + aws_websocket_on_incoming_frame_begin_fn *on_incoming_frame_begin_fn; + aws_websocket_on_incoming_frame_payload_fn *on_incoming_frame_payload_fn; + aws_websocket_on_incoming_frame_complete_fn *on_incoming_frame_complete_fn; +}; + +struct aws_secure_tunnel_mock_test_fixture_options { + struct aws_secure_tunnel_options *secure_tunnel_options; + struct aws_secure_tunnel_mock_websocket_vtable *websocket_function_table; + + void *mock_server_user_data; +}; + +struct secure_tunnel_test_options { + struct aws_secure_tunnel_options secure_tunnel_options; + struct aws_secure_tunnel_mock_websocket_vtable websocket_function_table; +}; + +static void s_secure_tunnel_test_init_default_options(struct secure_tunnel_test_options *test_options) { + struct aws_secure_tunnel_options local_secure_tunnel_options = { + .endpoint_host = aws_byte_cursor_from_string(s_endpoint_host), + .access_token = aws_byte_cursor_from_string(s_access_token), + .local_proxy_mode = AWS_SECURE_TUNNELING_DESTINATION_MODE, + }; + test_options->secure_tunnel_options = local_secure_tunnel_options; +} + +typedef int(aws_secure_tunnel_mock_test_fixture_header_check_fn)( + const struct aws_http_headers *request_headers, + void *user_data); + +struct aws_secure_tunnel_mock_test_fixture { + struct aws_allocator *allocator; + + struct aws_event_loop_group *secure_tunnel_elg; + struct aws_host_resolver *host_resolver; + struct aws_client_bootstrap *secure_tunnel_bootstrap; + struct aws_socket_endpoint endpoint; + struct aws_socket_options socket_options; + + struct aws_secure_tunnel_mock_websocket_vtable *websocket_function_table; + void *mock_server_user_data; + + struct aws_secure_tunnel *secure_tunnel; + struct aws_secure_tunnel_vtable secure_tunnel_vtable; + + aws_secure_tunnel_mock_test_fixture_header_check_fn *header_check; + + struct aws_mutex lock; + struct aws_condition_variable signal; + bool listener_destroyed; + bool secure_tunnel_terminated; + bool secure_tunnel_connected_succesfully; + bool secure_tunnel_connection_shutdown; + bool secure_tunnel_connection_failed; + bool secure_tunnel_stream_started; + bool secure_tunnel_bad_stream_request; + bool secure_tunnel_stream_reset_received; + bool secure_tunnel_session_reset_received; + + struct aws_byte_buf last_message_payload_buf; + + int secure_tunnel_message_received_count; + int secure_tunnel_stream_started_count; + int secure_tunnel_stream_started_count_target; + int secure_tunnel_message_count_target; +}; + +/***************************************************************************************************************** + * SECURE TUNNEL CALLBACKS + *****************************************************************************************************************/ + +static void s_on_test_secure_tunnel_connection_complete( + const struct aws_secure_tunnel_connection_view *connection_view, + int error_code, + void *user_data) { + (void)connection_view; + struct aws_secure_tunnel_mock_test_fixture *test_fixture = user_data; + + aws_mutex_lock(&test_fixture->lock); + if (error_code == 0) { + test_fixture->secure_tunnel_connected_succesfully = true; + } else { + test_fixture->secure_tunnel_connection_failed = true; + } + aws_mutex_unlock(&test_fixture->lock); + aws_condition_variable_notify_all(&test_fixture->signal); +} + +static void s_on_test_secure_tunnel_connection_shutdown(int error_code, void *user_data) { + (void)error_code; + struct aws_secure_tunnel_mock_test_fixture *test_fixture = user_data; + + aws_mutex_lock(&test_fixture->lock); + test_fixture->secure_tunnel_connection_shutdown = true; + aws_mutex_unlock(&test_fixture->lock); + aws_condition_variable_notify_all(&test_fixture->signal); +} + +static void s_on_test_secure_tunnel_message_received( + const struct aws_secure_tunnel_message_view *message, + void *user_data) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = user_data; + aws_mutex_lock(&test_fixture->lock); + test_fixture->secure_tunnel_message_received_count++; + aws_byte_buf_clean_up(&test_fixture->last_message_payload_buf); + aws_byte_buf_init(&test_fixture->last_message_payload_buf, test_fixture->allocator, message->payload->len); + struct aws_byte_cursor payload_cur = { + .ptr = message->payload->ptr, + .len = message->payload->len, + }; + aws_byte_buf_write_from_whole_cursor(&test_fixture->last_message_payload_buf, payload_cur); + aws_mutex_unlock(&test_fixture->lock); + aws_condition_variable_notify_all(&test_fixture->signal); +} + +static void s_on_test_secure_tunnel_send_data_complete(int error_code, void *user_data) { + (void)error_code; + (void)user_data; +} + +static void s_on_test_secure_tunnel_on_session_reset(void *user_data) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = user_data; + + aws_mutex_lock(&test_fixture->lock); + test_fixture->secure_tunnel_session_reset_received = true; + aws_mutex_unlock(&test_fixture->lock); + aws_condition_variable_notify_all(&test_fixture->signal); +} + +static void s_on_test_secure_tunnel_on_stopped(void *user_data) { + (void)user_data; +} + +static void s_on_test_secure_tunnel_termination(void *user_data) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = user_data; + + aws_mutex_lock(&test_fixture->lock); + test_fixture->secure_tunnel_terminated = true; + aws_mutex_unlock(&test_fixture->lock); + aws_condition_variable_notify_all(&test_fixture->signal); +} + +static void s_on_test_secure_tunnel_on_stream_reset( + const struct aws_secure_tunnel_message_view *message, + int error_code, + void *user_data) { + (void)message; + (void)error_code; + + struct aws_secure_tunnel_mock_test_fixture *test_fixture = user_data; + + aws_mutex_lock(&test_fixture->lock); + test_fixture->secure_tunnel_stream_reset_received = true; + aws_mutex_unlock(&test_fixture->lock); + aws_condition_variable_notify_all(&test_fixture->signal); +} + +static void s_on_test_secure_tunnel_on_stream_start( + const struct aws_secure_tunnel_message_view *message, + int error_code, + void *user_data) { + (void)message; + + struct aws_secure_tunnel_mock_test_fixture *test_fixture = user_data; + + aws_mutex_lock(&test_fixture->lock); + if (error_code == AWS_OP_SUCCESS) { + test_fixture->secure_tunnel_stream_started = true; + test_fixture->secure_tunnel_stream_started_count++; + } else { + test_fixture->secure_tunnel_bad_stream_request = true; + } + aws_mutex_unlock(&test_fixture->lock); + aws_condition_variable_notify_all(&test_fixture->signal); +} + +/***************************************************************************************************************** + * SECURE TUNNEL STATUS CHECKS + *****************************************************************************************************************/ + +static bool s_has_secure_tunnel_terminated(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_terminated; +} + +static void s_wait_for_secure_tunnel_terminated(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_terminated, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + +static bool s_has_secure_tunnel_connected_succesfully(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_connected_succesfully; +} + +static void s_wait_for_connected_successfully(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_connected_succesfully, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + +static bool s_has_secure_tunnel_connection_shutdown(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_connection_shutdown; +} + +static void s_wait_for_connection_shutdown(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_connection_shutdown, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + +static bool s_has_secure_tunnel_stream_started(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_stream_started; +} + +static void s_wait_for_stream_started(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_stream_started, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + +static bool s_has_secure_tunnel_bad_stream_request(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_bad_stream_request; +} + +static void s_wait_for_bad_stream_request(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_bad_stream_request, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + +static bool s_has_secure_tunnel_stream_reset_received(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_stream_reset_received; +} + +static void s_wait_for_stream_reset_received(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_stream_reset_received, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + +static bool s_has_secure_tunnel_n_stream_started(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_stream_started_count == test_fixture->secure_tunnel_stream_started_count_target; +} + +static void s_wait_for_n_stream_started(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_n_stream_started, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + +static bool s_has_secure_tunnel_session_reset_received(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_session_reset_received; +} + +static void s_wait_for_session_reset_received(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_session_reset_received, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + +static bool s_has_secure_tunnel_n_messages_received(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_stream_started_count == test_fixture->secure_tunnel_message_count_target; +} + +static void s_wait_for_n_messages_received(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_n_messages_received, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + +/***************************************************************************************************************** + * WEBSOCKET MOCK FUNCTIONS + *****************************************************************************************************************/ + +/* Serializes message view and sends as Websocket */ +void aws_secure_tunnel_send_mock_message( + struct aws_secure_tunnel_mock_test_fixture *test_fixture, + const struct aws_secure_tunnel_message_view *message_view) { + + struct aws_byte_buf data_buf; + struct aws_byte_cursor data_cur; + struct aws_byte_buf out_buf; + aws_iot_st_msg_serialize_from_view(&data_buf, test_fixture->allocator, message_view); + data_cur = aws_byte_cursor_from_buf(&data_buf); + aws_byte_buf_init(&out_buf, test_fixture->allocator, data_cur.len + PAYLOAD_BYTE_LENGTH_PREFIX); + aws_byte_buf_write_be16(&out_buf, (int16_t)data_buf.len); + aws_byte_buf_write_to_capacity(&out_buf, &data_cur); + data_cur = aws_byte_cursor_from_buf(&out_buf); + test_fixture->websocket_function_table->on_incoming_frame_payload_fn( + NULL, NULL, data_cur, test_fixture->secure_tunnel); + + aws_byte_buf_clean_up(&out_buf); + aws_byte_buf_clean_up(&data_buf); +} + +int aws_websocket_client_connect_mock_fn(const struct aws_websocket_client_connection_options *options) { + struct aws_secure_tunnel *secure_tunnel = options->user_data; + struct aws_secure_tunnel_mock_test_fixture *test_fixture = secure_tunnel->config->user_data; + + if (!options->handshake_request) { + AWS_LOGF_ERROR( + AWS_LS_HTTP_WEBSOCKET_SETUP, + "id=static: Invalid connection options, missing required request for websocket client handshake."); + return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + } + + const struct aws_http_headers *request_headers = aws_http_message_get_headers(options->handshake_request); + if (test_fixture->header_check) { + ASSERT_SUCCESS(test_fixture->header_check(request_headers, test_fixture)); + } + + test_fixture->websocket_function_table->on_connection_setup_fn = options->on_connection_setup; + test_fixture->websocket_function_table->on_connection_shutdown_fn = options->on_connection_shutdown; + test_fixture->websocket_function_table->on_incoming_frame_begin_fn = options->on_incoming_frame_begin; + test_fixture->websocket_function_table->on_incoming_frame_payload_fn = options->on_incoming_frame_payload; + test_fixture->websocket_function_table->on_incoming_frame_complete_fn = options->on_incoming_frame_complete; + + void *pointer = test_fixture; + struct aws_websocket_on_connection_setup_data websocket_setup = {.error_code = AWS_ERROR_SUCCESS, + .websocket = pointer}; + + (test_fixture->websocket_function_table->on_connection_setup_fn)(&websocket_setup, secure_tunnel); + + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_byte_cursor service_2 = aws_byte_cursor_from_string(s_service_id_2); + struct aws_byte_cursor service_3 = aws_byte_cursor_from_string(s_service_id_3); + + struct aws_secure_tunnel_message_view service_ids_message = { + .type = AWS_SECURE_TUNNEL_MT_SERVICE_IDS, + .service_id = &service_1, + .service_id_2 = &service_2, + .service_id_3 = &service_3, + }; + + aws_secure_tunnel_send_mock_message(test_fixture, &service_ids_message); + + return AWS_OP_SUCCESS; +} + +int aws_websocket_send_frame_mock_fn( + struct aws_websocket *websocket, + const struct aws_websocket_send_frame_options *options) { + (void)websocket; + (void)options; + return AWS_OP_SUCCESS; +} + +void aws_websocket_release_mock_fn(struct aws_websocket *websocket) { + (void)websocket; +} + +void aws_websocket_close_mock_fn(struct aws_websocket *websocket, bool free_scarce_resources_immediately) { + (void)free_scarce_resources_immediately; + void *pointer = websocket; + struct aws_secure_tunnel_mock_test_fixture *test_fixture = pointer; + test_fixture->websocket_function_table->on_connection_shutdown_fn(websocket, 0, test_fixture->secure_tunnel); +} + +/***************************************************************************************************************** + * TEST FIXTURE + *****************************************************************************************************************/ + +int aws_secure_tunnel_mock_test_fixture_init( + struct aws_secure_tunnel_mock_test_fixture *test_fixture, + struct aws_allocator *allocator, + struct aws_secure_tunnel_mock_test_fixture_options *options) { + + AWS_ZERO_STRUCT(*test_fixture); + test_fixture->allocator = allocator; + + aws_mutex_init(&test_fixture->lock); + aws_condition_variable_init(&test_fixture->signal); + + test_fixture->websocket_function_table = options->websocket_function_table; + test_fixture->mock_server_user_data = options->mock_server_user_data; + + struct aws_socket_options socket_options = { + .connect_timeout_ms = 1000, + .domain = AWS_SOCKET_LOCAL, + }; + + test_fixture->socket_options = socket_options; + + test_fixture->secure_tunnel_elg = aws_event_loop_group_new_default(allocator, 4, NULL); + struct aws_host_resolver_default_options resolver_options = { + .el_group = test_fixture->secure_tunnel_elg, + .max_entries = 1, + }; + test_fixture->host_resolver = aws_host_resolver_new_default(allocator, &resolver_options); + + struct aws_client_bootstrap_options bootstrap_options = { + .event_loop_group = test_fixture->secure_tunnel_elg, + .user_data = test_fixture, + .host_resolver = test_fixture->host_resolver, + }; + + test_fixture->secure_tunnel_bootstrap = aws_client_bootstrap_new(allocator, &bootstrap_options); + + uint64_t timestamp = 0; + ASSERT_SUCCESS(aws_sys_clock_get_ticks(×tamp)); + + snprintf( + test_fixture->endpoint.address, + sizeof(test_fixture->endpoint.address), + LOCAL_SOCK_TEST_PATTERN, + (long long unsigned)timestamp); + + options->secure_tunnel_options->endpoint_host = aws_byte_cursor_from_c_str(test_fixture->endpoint.address); + options->secure_tunnel_options->bootstrap = test_fixture->secure_tunnel_bootstrap; + options->secure_tunnel_options->socket_options = &test_fixture->socket_options; + options->secure_tunnel_options->access_token = aws_byte_cursor_from_string(s_access_token); + options->secure_tunnel_options->user_data = test_fixture; + + /* Secure Tunnel Callbacks */ + options->secure_tunnel_options->on_connection_complete = s_on_test_secure_tunnel_connection_complete; + options->secure_tunnel_options->on_connection_shutdown = s_on_test_secure_tunnel_connection_shutdown; + options->secure_tunnel_options->on_message_received = s_on_test_secure_tunnel_message_received; + options->secure_tunnel_options->on_send_data_complete = s_on_test_secure_tunnel_send_data_complete; + options->secure_tunnel_options->on_session_reset = s_on_test_secure_tunnel_on_session_reset; + options->secure_tunnel_options->on_stopped = s_on_test_secure_tunnel_on_stopped; + options->secure_tunnel_options->on_stream_reset = s_on_test_secure_tunnel_on_stream_reset; + options->secure_tunnel_options->on_stream_start = s_on_test_secure_tunnel_on_stream_start; + options->secure_tunnel_options->on_termination_complete = s_on_test_secure_tunnel_termination; + options->secure_tunnel_options->secure_tunnel_on_termination_user_data = test_fixture; + + test_fixture->secure_tunnel = aws_secure_tunnel_new(allocator, options->secure_tunnel_options); + + /* Replace Secure Tunnel's vtable functions */ + test_fixture->secure_tunnel_vtable = *aws_secure_tunnel_get_default_vtable(); + test_fixture->secure_tunnel_vtable.aws_websocket_client_connect_fn = aws_websocket_client_connect_mock_fn; + test_fixture->secure_tunnel_vtable.aws_websocket_send_frame_fn = aws_websocket_send_frame_mock_fn; + test_fixture->secure_tunnel_vtable.aws_websocket_release_fn = aws_websocket_release_mock_fn; + test_fixture->secure_tunnel_vtable.aws_websocket_close_fn = aws_websocket_close_mock_fn; + test_fixture->secure_tunnel_vtable.vtable_user_data = test_fixture; + + aws_secure_tunnel_set_vtable(test_fixture->secure_tunnel, &test_fixture->secure_tunnel_vtable); + + return AWS_OP_SUCCESS; +} + +void aws_secure_tunnel_mock_test_fixture_clean_up(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + s_wait_for_secure_tunnel_terminated(test_fixture); + aws_client_bootstrap_release(test_fixture->secure_tunnel_bootstrap); + aws_host_resolver_release(test_fixture->host_resolver); + + aws_event_loop_group_release(test_fixture->secure_tunnel_elg); + + aws_byte_buf_clean_up(&test_fixture->last_message_payload_buf); + aws_mutex_clean_up(&test_fixture->lock); + aws_condition_variable_clean_up(&test_fixture->signal); +} + +/********************************************************************************************************************* + * TESTS + ********************************************************************************************************************/ + +/* [Func-UC1] */ +int secure_tunneling_access_token_check(const struct aws_http_headers *request_headers, void *user_data) { + (void)user_data; + struct aws_byte_cursor access_token_cur; + if (aws_http_headers_get(request_headers, aws_byte_cursor_from_c_str("access-token"), &access_token_cur)) { + AWS_LOGF_ERROR( + AWS_LS_HTTP_WEBSOCKET_SETUP, + "id=static: Websocket handshake request is missing required 'access-token' header"); + return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + } + ASSERT_CURSOR_VALUE_STRING_EQUALS(access_token_cur, s_access_token); + return AWS_ERROR_SUCCESS; +} + +static int s_secure_tunneling_functionality_connect_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + aws_http_library_init(allocator); + aws_iotdevice_library_init(allocator); + + struct secure_tunnel_test_options test_options; + s_secure_tunnel_test_init_default_options(&test_options); + + struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { + .secure_tunnel_options = &test_options.secure_tunnel_options, + .websocket_function_table = &test_options.websocket_function_table, + }; + + struct aws_secure_tunnel_mock_test_fixture test_fixture; + ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); + + test_fixture.header_check = secure_tunneling_access_token_check; + + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_release(secure_tunnel); + s_wait_for_secure_tunnel_terminated(&test_fixture); + + aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); + aws_iotdevice_library_clean_up(); + aws_http_library_clean_up(); + aws_iotdevice_library_clean_up(); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_functionality_connect_test, s_secure_tunneling_functionality_connect_test_fn) + +/* [Func-UC2] */ +int secure_tunneling_client_token_check(const struct aws_http_headers *request_headers, void *user_data) { + (void)user_data; + struct aws_byte_cursor client_token_cur; + if (aws_http_headers_get(request_headers, aws_byte_cursor_from_c_str("client-token"), &client_token_cur)) { + AWS_LOGF_ERROR( + AWS_LS_HTTP_WEBSOCKET_SETUP, + "id=static: Websocket handshake request is missing required 'client-token' header"); + return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + } + ASSERT_CURSOR_VALUE_STRING_EQUALS(client_token_cur, s_client_token); + return AWS_ERROR_SUCCESS; +} + +static int s_secure_tunneling_functionality_client_token_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + aws_http_library_init(allocator); + aws_iotdevice_library_init(allocator); + + struct secure_tunnel_test_options test_options; + s_secure_tunnel_test_init_default_options(&test_options); + test_options.secure_tunnel_options.client_token = aws_byte_cursor_from_string(s_client_token); + + struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { + .secure_tunnel_options = &test_options.secure_tunnel_options, + .websocket_function_table = &test_options.websocket_function_table, + }; + + struct aws_secure_tunnel_mock_test_fixture test_fixture; + ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); + + test_fixture.header_check = secure_tunneling_client_token_check; + + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_release(secure_tunnel); + s_wait_for_secure_tunnel_terminated(&test_fixture); + + aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); + aws_iotdevice_library_clean_up(); + aws_http_library_clean_up(); + aws_iotdevice_library_clean_up(); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_functionality_client_token_test, s_secure_tunneling_functionality_client_token_test_fn) + +/* [Func-UC3] */ + +int aws_websocket_client_connect_fail_once_fn(const struct aws_websocket_client_connection_options *options) { + struct aws_secure_tunnel *secure_tunnel = options->user_data; + struct aws_secure_tunnel_mock_test_fixture *test_fixture = secure_tunnel->config->user_data; + bool is_connection_failed_once = false; + + aws_mutex_lock(&test_fixture->lock); + is_connection_failed_once = test_fixture->secure_tunnel_connection_failed; + aws_mutex_unlock(&test_fixture->lock); + + if (is_connection_failed_once) { + if (!options->handshake_request) { + AWS_LOGF_ERROR( + AWS_LS_HTTP_WEBSOCKET_SETUP, + "id=static: Invalid connection options, missing required request for websocket client handshake."); + return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); + } + + const struct aws_http_headers *request_headers = aws_http_message_get_headers(options->handshake_request); + if (test_fixture->header_check) { + ASSERT_SUCCESS(test_fixture->header_check(request_headers, test_fixture)); + } + + test_fixture->websocket_function_table->on_connection_setup_fn = options->on_connection_setup; + test_fixture->websocket_function_table->on_connection_shutdown_fn = options->on_connection_shutdown; + test_fixture->websocket_function_table->on_incoming_frame_begin_fn = options->on_incoming_frame_begin; + test_fixture->websocket_function_table->on_incoming_frame_payload_fn = options->on_incoming_frame_payload; + test_fixture->websocket_function_table->on_incoming_frame_complete_fn = options->on_incoming_frame_complete; + + void *pointer = test_fixture; + struct aws_websocket_on_connection_setup_data websocket_setup = {.error_code = AWS_ERROR_SUCCESS, + .websocket = pointer}; + + (test_fixture->websocket_function_table->on_connection_setup_fn)(&websocket_setup, secure_tunnel); + + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_byte_cursor service_2 = aws_byte_cursor_from_string(s_service_id_2); + struct aws_byte_cursor service_3 = aws_byte_cursor_from_string(s_service_id_3); + + struct aws_secure_tunnel_message_view service_ids_message = { + .type = AWS_SECURE_TUNNEL_MT_SERVICE_IDS, + .service_id = &service_1, + .service_id_2 = &service_2, + .service_id_3 = &service_3, + }; + + aws_secure_tunnel_send_mock_message(test_fixture, &service_ids_message); + + return AWS_OP_SUCCESS; + } else { + return AWS_OP_ERR; + } +} + +static int s_secure_tunneling_fail_and_retry_connection_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + aws_http_library_init(allocator); + aws_iotdevice_library_init(allocator); + + struct secure_tunnel_test_options test_options; + s_secure_tunnel_test_init_default_options(&test_options); + + struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { + .secure_tunnel_options = &test_options.secure_tunnel_options, + .websocket_function_table = &test_options.websocket_function_table, + }; + + struct aws_secure_tunnel_mock_test_fixture test_fixture; + ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); + + test_fixture.secure_tunnel_vtable = *aws_secure_tunnel_get_default_vtable(); + test_fixture.secure_tunnel_vtable.aws_websocket_client_connect_fn = aws_websocket_client_connect_fail_once_fn; + test_fixture.secure_tunnel_vtable.aws_websocket_send_frame_fn = aws_websocket_send_frame_mock_fn; + test_fixture.secure_tunnel_vtable.aws_websocket_release_fn = aws_websocket_release_mock_fn; + test_fixture.secure_tunnel_vtable.aws_websocket_close_fn = aws_websocket_close_mock_fn; + test_fixture.secure_tunnel_vtable.vtable_user_data = &test_fixture; + + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_release(secure_tunnel); + s_wait_for_secure_tunnel_terminated(&test_fixture); + + aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); + aws_iotdevice_library_clean_up(); + aws_http_library_clean_up(); + aws_iotdevice_library_clean_up(); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_fail_and_retry_connection_test, s_secure_tunneling_fail_and_retry_connection_test_fn) + +/* [Func-UC4] */ + +static int s_secure_tunneling_store_service_ids_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + aws_http_library_init(allocator); + aws_iotdevice_library_init(allocator); + + struct secure_tunnel_test_options test_options; + s_secure_tunnel_test_init_default_options(&test_options); + + struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { + .secure_tunnel_options = &test_options.secure_tunnel_options, + .websocket_function_table = &test_options.websocket_function_table, + }; + + struct aws_secure_tunnel_mock_test_fixture test_fixture; + ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); + + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* check that service ids have been stored */ + struct aws_hash_element *elem = NULL; + struct aws_byte_cursor service_id_1_cur = aws_byte_cursor_from_string(s_service_id_1); + aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_1_cur, &elem); + ASSERT_NOT_NULL(elem); + elem = NULL; + struct aws_byte_cursor service_id_2_cur = aws_byte_cursor_from_string(s_service_id_2); + aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_2_cur, &elem); + ASSERT_NOT_NULL(elem); + elem = NULL; + struct aws_byte_cursor service_id_3_cur = aws_byte_cursor_from_string(s_service_id_3); + aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_3_cur, &elem); + ASSERT_NOT_NULL(elem); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_release(secure_tunnel); + s_wait_for_secure_tunnel_terminated(&test_fixture); + + aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); + aws_iotdevice_library_clean_up(); + aws_http_library_clean_up(); + aws_iotdevice_library_clean_up(); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_store_service_ids_test, s_secure_tunneling_store_service_ids_test_fn) + +/* [Func-UC5] */ + +static int s_secure_tunneling_receive_stream_start_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + aws_http_library_init(allocator); + aws_iotdevice_library_init(allocator); + + struct secure_tunnel_test_options test_options; + s_secure_tunnel_test_init_default_options(&test_options); + + struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { + .secure_tunnel_options = &test_options.secure_tunnel_options, + .websocket_function_table = &test_options.websocket_function_table, + }; + + struct aws_secure_tunnel_mock_test_fixture test_fixture; + ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); + + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + + /* check that service id stream has been set properly */ + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->config->service_ids, stream_start_message_view.service_id, &elem); + ASSERT_NOT_NULL(elem); + struct aws_service_id_element *service_id_elem = elem->value; + ASSERT_TRUE(service_id_elem->stream_id == stream_start_message_view.stream_id); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_release(secure_tunnel); + s_wait_for_secure_tunnel_terminated(&test_fixture); + + aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); + aws_iotdevice_library_clean_up(); + aws_http_library_clean_up(); + aws_iotdevice_library_clean_up(); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_receive_stream_start_test, s_secure_tunneling_receive_stream_start_test_fn) + +/* [Func-UC6] */ + +static int s_secure_tunneling_rejected_service_id_stream_start_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + aws_http_library_init(allocator); + aws_iotdevice_library_init(allocator); + + struct secure_tunnel_test_options test_options; + s_secure_tunnel_test_init_default_options(&test_options); + + struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { + .secure_tunnel_options = &test_options.secure_tunnel_options, + .websocket_function_table = &test_options.websocket_function_table, + }; + + struct aws_secure_tunnel_mock_test_fixture test_fixture; + ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); + + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a bad stream start message from the server to the destination client */ + struct aws_byte_cursor service_id = aws_byte_cursor_from_string(s_service_id_wrong); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_id, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a bad stream request was received */ + s_wait_for_bad_stream_request(&test_fixture); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_release(secure_tunnel); + s_wait_for_secure_tunnel_terminated(&test_fixture); + + aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); + aws_iotdevice_library_clean_up(); + aws_http_library_clean_up(); + aws_iotdevice_library_clean_up(); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE( + secure_tunneling_rejected_service_id_stream_start_test, + s_secure_tunneling_rejected_service_id_stream_start_test_fn) + +/* [Func-UC7] */ + +static int s_secure_tunneling_close_stream_on_stream_reset_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + aws_http_library_init(allocator); + aws_iotdevice_library_init(allocator); + + struct secure_tunnel_test_options test_options; + s_secure_tunnel_test_init_default_options(&test_options); + + struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { + .secure_tunnel_options = &test_options.secure_tunnel_options, + .websocket_function_table = &test_options.websocket_function_table, + }; + + struct aws_secure_tunnel_mock_test_fixture test_fixture; + ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); + + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + + /* Send a stream reset message from the server to the destination client */ + stream_start_message_view.type = AWS_SECURE_TUNNEL_MT_STREAM_RESET; + + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait for a stream reset to have been received */ + s_wait_for_stream_reset_received(&test_fixture); + + /* check that service id stream has been reset */ + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->config->service_ids, stream_start_message_view.service_id, &elem); + ASSERT_NOT_NULL(elem); + struct aws_service_id_element *service_id_elem = elem->value; + ASSERT_TRUE(service_id_elem->stream_id == 0); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_release(secure_tunnel); + s_wait_for_secure_tunnel_terminated(&test_fixture); + + aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); + aws_iotdevice_library_clean_up(); + aws_http_library_clean_up(); + aws_iotdevice_library_clean_up(); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE( + secure_tunneling_close_stream_on_stream_reset_test, + s_secure_tunneling_close_stream_on_stream_reset_test_fn) + +/* [Func-UC8] */ +static int s_secure_tunneling_session_reset_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + aws_http_library_init(allocator); + aws_iotdevice_library_init(allocator); + + struct secure_tunnel_test_options test_options; + s_secure_tunnel_test_init_default_options(&test_options); + + struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { + .secure_tunnel_options = &test_options.secure_tunnel_options, + .websocket_function_table = &test_options.websocket_function_table, + }; + + struct aws_secure_tunnel_mock_test_fixture test_fixture; + ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); + + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_byte_cursor service_2 = aws_byte_cursor_from_string(s_service_id_2); + struct aws_byte_cursor service_3 = aws_byte_cursor_from_string(s_service_id_3); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + stream_start_message_view.service_id = &service_2; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + stream_start_message_view.service_id = &service_3; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + test_fixture.secure_tunnel_stream_started_count_target = 3; + s_wait_for_n_stream_started(&test_fixture); + + /* check that stream ids have been set */ + struct aws_hash_element *elem = NULL; + struct aws_byte_cursor service_id_1_cur = aws_byte_cursor_from_string(s_service_id_1); + aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_1_cur, &elem); + ASSERT_NOT_NULL(elem); + struct aws_service_id_element *service_id_elem = elem->value; + ASSERT_TRUE(service_id_elem->stream_id == stream_start_message_view.stream_id); + elem = NULL; + struct aws_byte_cursor service_id_2_cur = aws_byte_cursor_from_string(s_service_id_2); + aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_2_cur, &elem); + ASSERT_NOT_NULL(elem); + service_id_elem = elem->value; + ASSERT_TRUE(service_id_elem->stream_id == stream_start_message_view.stream_id); + elem = NULL; + struct aws_byte_cursor service_id_3_cur = aws_byte_cursor_from_string(s_service_id_3); + aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_3_cur, &elem); + ASSERT_NOT_NULL(elem); + service_id_elem = elem->value; + ASSERT_TRUE(service_id_elem->stream_id == stream_start_message_view.stream_id); + + /* Create and send a session reset message from the server to the destination client */ + struct aws_secure_tunnel_message_view reset_message_view = { + .type = AWS_SECURE_TUNNEL_MT_SESSION_RESET, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &reset_message_view); + + s_wait_for_session_reset_received(&test_fixture); + + /* Check that stream ids have been reset */ + elem = NULL; + aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_1_cur, &elem); + ASSERT_NOT_NULL(elem); + service_id_elem = elem->value; + ASSERT_TRUE(service_id_elem->stream_id == 0); + elem = NULL; + aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_2_cur, &elem); + ASSERT_NOT_NULL(elem); + service_id_elem = elem->value; + ASSERT_TRUE(service_id_elem->stream_id == 0); + elem = NULL; + aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_3_cur, &elem); + ASSERT_NOT_NULL(elem); + service_id_elem = elem->value; + ASSERT_TRUE(service_id_elem->stream_id == 0); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_release(secure_tunnel); + s_wait_for_secure_tunnel_terminated(&test_fixture); + + aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); + aws_iotdevice_library_clean_up(); + aws_http_library_clean_up(); + aws_iotdevice_library_clean_up(); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_session_reset_test, s_secure_tunneling_session_reset_test_fn) + +static int s_secure_tunneling_serializer_data_message_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + aws_http_library_init(allocator); + aws_iotdevice_library_init(allocator); + + struct secure_tunnel_test_options test_options; + s_secure_tunnel_test_init_default_options(&test_options); + + struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { + .secure_tunnel_options = &test_options.secure_tunnel_options, + .websocket_function_table = &test_options.websocket_function_table, + }; + + struct aws_secure_tunnel_mock_test_fixture test_fixture; + ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); + + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Create and send a data message from the server to the destination client */ + struct aws_byte_cursor payload_cur = aws_byte_cursor_from_string(s_payload_text); + struct aws_secure_tunnel_message_view data_message_view = { + .type = AWS_SECURE_TUNNEL_MT_DATA, + .service_id = &service_1, + .stream_id = 1, + .payload = &payload_cur, + }; + + aws_secure_tunnel_send_mock_message(&test_fixture, &data_message_view); + test_fixture.secure_tunnel_message_count_target = 1; + s_wait_for_n_messages_received(&test_fixture); + + struct aws_byte_cursor payload_comp_cur = { + .ptr = test_fixture.last_message_payload_buf.buffer, + .len = test_fixture.last_message_payload_buf.len, + }; + ASSERT_CURSOR_VALUE_STRING_EQUALS(payload_comp_cur, s_payload_text); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_release(secure_tunnel); + s_wait_for_secure_tunnel_terminated(&test_fixture); + + aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); + aws_iotdevice_library_clean_up(); + aws_http_library_clean_up(); + aws_iotdevice_library_clean_up(); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_serializer_data_message_test, s_secure_tunneling_serializer_data_message_test_fn) diff --git a/tests/secure_tunneling_tests.c b/tests/secure_tunneling_tests.c deleted file mode 100644 index 7a6bdc0a..00000000 --- a/tests/secure_tunneling_tests.c +++ /dev/null @@ -1,717 +0,0 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#define UNUSED(x) (void)(x) - -#define INVALID_STREAM_ID 0 -#define STREAM_ID 10 -#define ACCESS_TOKEN "my_super_secret_access_token" -#define ENDPOINT "data.tunneling.iot.us-west-2.amazonaws.com" -#define PAYLOAD "secure tunneling data payload" - -/* - * The tests here call these functions directly. - */ - -struct secure_tunneling_test_context { - enum aws_secure_tunneling_local_proxy_mode local_proxy_mode; - uint16_t max_threads; - struct aws_event_loop_group *elg; - struct aws_host_resolver *resolver; - struct aws_client_bootstrap *bootstrap; - struct aws_secure_tunnel *secure_tunnel; -}; -static struct secure_tunneling_test_context s_test_context = {.max_threads = 1}; - -static bool s_on_stream_start_called = false; -static void s_on_stream_start(void *user_data) { - UNUSED(user_data); - s_on_stream_start_called = true; -} - -static bool s_on_data_receive_correct_payload = false; -static void s_on_data_receive(const struct aws_byte_buf *data, void *user_data) { - UNUSED(user_data); - s_on_data_receive_correct_payload = aws_byte_buf_eq_c_str(data, PAYLOAD); -} - -static bool s_on_stream_reset_called = false; -static void s_on_stream_reset(void *user_data) { - UNUSED(user_data); - s_on_stream_reset_called = true; -} - -static bool s_on_session_reset_called = false; -static void s_on_session_reset(void *user_data) { - UNUSED(user_data); - s_on_session_reset_called = true; -} - -static bool s_on_termination_complete_called = false; -static void s_on_termination_complete(void *user_data) { - UNUSED(user_data); - s_on_termination_complete_called = true; -} - -static void s_init_secure_tunneling_connection_config( - struct aws_allocator *allocator, - struct aws_client_bootstrap *bootstrap, - struct aws_socket_options *socket_options, - const char *access_token, - enum aws_secure_tunneling_local_proxy_mode local_proxy_mode, - const char *endpoint, - struct aws_secure_tunnel_options *options) { - - AWS_ZERO_STRUCT(*options); - options->allocator = allocator; - options->bootstrap = bootstrap; - options->socket_options = socket_options; - - options->access_token = aws_byte_cursor_from_c_str(access_token); - options->local_proxy_mode = local_proxy_mode; - options->endpoint_host = aws_byte_cursor_from_c_str(endpoint); - - options->on_stream_start = s_on_stream_start; - options->on_data_receive = s_on_data_receive; - options->on_stream_reset = s_on_stream_reset; - options->on_session_reset = s_on_session_reset; - options->on_termination_complete = s_on_termination_complete; - /* TODO: Initialize the rest of the callbacks */ -} - -/* - * Mock aws websocket api used by the secure tunnel. - */ - -int s_mock_aws_websocket_client_connect(const struct aws_websocket_client_connection_options *options) { - UNUSED(options); - return AWS_OP_SUCCESS; -} - -static size_t s_mock_aws_websocket_send_frame_call_count = 0U; - -static size_t s_mock_aws_websocket_send_frame_payload_len = 0U; - -int s_mock_aws_websocket_send_frame( - struct aws_websocket *websocket, - const struct aws_websocket_send_frame_options *options) { - UNUSED(websocket); - ++s_mock_aws_websocket_send_frame_call_count; - - struct data_tunnel_pair *pair = (struct data_tunnel_pair *)options->user_data; - struct aws_byte_buf *buf = &pair->buf; - struct aws_byte_cursor cursor = aws_byte_cursor_from_buf(buf); - - /* Deserialize the wire format to obtain original payload. */ - struct aws_iot_st_msg message; - int rc = aws_iot_st_msg_deserialize_from_cursor(&message, &cursor, s_test_context.secure_tunnel->alloc); - ASSERT_INT_EQUALS(AWS_OP_SUCCESS, rc); - s_mock_aws_websocket_send_frame_payload_len += message.payload.len; - aws_byte_buf_clean_up(&message.payload); - - /* Deallocate memory for the buffer holding the wire protocol data and the tunnel context. */ - aws_byte_buf_clean_up(buf); - aws_mem_release(s_test_context.secure_tunnel->alloc, pair); - - return AWS_OP_SUCCESS; -} - -void s_mock_aws_websocket_close(struct aws_websocket *websocket, bool free_scarce_resources_immediately) { - UNUSED(websocket); - UNUSED(free_scarce_resources_immediately); -} - -void s_mock_aws_websocket_release(struct aws_websocket *websocket) { - UNUSED(websocket); - /* Release the handshake_request. In a non-mocked context this would occur after handshake completes. */ - aws_http_message_release(s_test_context.secure_tunnel->handshake_request); -} - -/* s_secure_tunnel_new_mock returns a secure_tunnel that mocks the aws websocket public api. */ -static struct aws_secure_tunnel *s_secure_tunnel_new_mock(const struct aws_secure_tunnel_options *options) { - struct aws_secure_tunnel *secure_tunnel = aws_secure_tunnel_new(options); - if (!secure_tunnel) { - return secure_tunnel; - } - secure_tunnel->websocket_vtable.client_connect = s_mock_aws_websocket_client_connect; - secure_tunnel->websocket_vtable.send_frame = s_mock_aws_websocket_send_frame; - secure_tunnel->websocket_vtable.close = s_mock_aws_websocket_close; - secure_tunnel->websocket_vtable.release = s_mock_aws_websocket_release; - - /* - * Initialize a dummy websocket when the tunnel is created. - * - * In the non-mock implementation this websocket would be initialized when - * an http upgrade request is received sometime after the tunnel is created. - * - * Since no http request is exercised by these tests we initialize a websocket - * as soon as the tunnel is created. Since the aws_websocket struct is opaque - * to this module, we use the placeholder value 1 to set the member non-null. - */ - secure_tunnel->websocket = (void *)1; - - return secure_tunnel; -} - -static int before(struct aws_allocator *allocator, void *ctx) { - struct secure_tunneling_test_context *test_context = ctx; - - /* Initialize aws-c-iot library. */ - aws_iotdevice_library_init(allocator); - - /* Initialize event loop. */ - test_context->elg = aws_event_loop_group_new_default(allocator, test_context->max_threads, NULL); - - /* Initialize dns resolver. */ - struct aws_host_resolver_default_options host_resolver_default_options; - AWS_ZERO_STRUCT(host_resolver_default_options); - host_resolver_default_options.max_entries = 8; - host_resolver_default_options.el_group = test_context->elg; - host_resolver_default_options.shutdown_options = NULL; - host_resolver_default_options.system_clock_override_fn = NULL; - test_context->resolver = aws_host_resolver_new_default(allocator, &host_resolver_default_options); - - /* Initialize client_bootstrap with event loop and dns resolver. */ - struct aws_client_bootstrap_options bootstrap_options = { - .event_loop_group = test_context->elg, - .host_resolver = test_context->resolver, - }; - test_context->bootstrap = aws_client_bootstrap_new(allocator, &bootstrap_options); - - /* Initialize socket_options for secure tunnel. */ - struct aws_socket_options socket_options; - AWS_ZERO_STRUCT(socket_options); - - /* Initialize secure_tunnel_options with client_bootstrap. */ - struct aws_secure_tunnel_options options; - s_init_secure_tunneling_connection_config( - allocator, - test_context->bootstrap, - &socket_options, - ACCESS_TOKEN, - test_context->local_proxy_mode, - ENDPOINT, - &options); - - /* Initialize secure_tunnel. */ - test_context->secure_tunnel = s_secure_tunnel_new_mock(&options); - ASSERT_NOT_NULL(test_context->secure_tunnel); - - return AWS_OP_SUCCESS; -} - -static int after(struct aws_allocator *allocator, int setup_result, void *ctx) { - UNUSED(allocator); - UNUSED(setup_result); - - struct secure_tunneling_test_context *test_context = ctx; - - aws_host_resolver_release(test_context->resolver); - aws_event_loop_group_release(test_context->elg); - aws_client_bootstrap_release(test_context->bootstrap); - - aws_secure_tunnel_release(test_context->secure_tunnel); - - /* - * Under normal circumstances you would need to do a wait on a condition variable signal for the async - * destruction process to unwind. But these tests all use mocks so termination completion happens synchronously - */ - ASSERT_TRUE(s_on_termination_complete_called); - - aws_thread_join_all_managed(); - - aws_iotdevice_library_clean_up(); - - return AWS_OP_SUCCESS; -} - -static void s_send_secure_tunneling_frame_to_websocket( - const struct aws_iot_st_msg *st_msg, - struct aws_allocator *allocator, - struct aws_secure_tunnel *secure_tunnel) { - - struct aws_byte_buf serialized_st_msg; - aws_iot_st_msg_serialize_from_struct(&serialized_st_msg, allocator, *st_msg); - - /* Prepend 2 bytes length */ - struct aws_byte_buf websocket_frame; - aws_byte_buf_init(&websocket_frame, allocator, serialized_st_msg.len + 2); - aws_byte_buf_write_be16(&websocket_frame, (uint16_t)serialized_st_msg.len); - struct aws_byte_cursor c = aws_byte_cursor_from_buf(&serialized_st_msg); - aws_byte_buf_append(&websocket_frame, &c); - c = aws_byte_cursor_from_buf(&websocket_frame); - - on_websocket_incoming_frame_payload(NULL, NULL, c, secure_tunnel); - - aws_byte_buf_clean_up(&serialized_st_msg); - aws_byte_buf_clean_up(&websocket_frame); -} - -static int s_test_sent_data( - struct secure_tunneling_test_context *test_context, - const char *expected_payload, - const int32_t expected_stream_id, - const int prefix_bytes, - const enum aws_iot_st_message_type type) { - /* - * The public api used to send data over a secure tunnel is aws_secure_tunnel_send_data. - * - * 1/ The public api accepts an aws_byte_cursor and logically splits this cursor into smaller - * nonoverlapping cursors aka frames using the private secure_tunneling_init_send_frame function. - * - * 2/ Each frame is written to the websocket connection using the public api aws_websocket_send_frame. - * The websocket api pushes the frame on a fifo queue of frames managed by the event loop. - * The event loop thread pops frames from the queue and writes the data to the websocket connection. - * - * The function implemented below differs from the public api in some important ways. - * - * 1/ This function frames a aws_byte_cursor by calling the private api secure_tunneling_init_send_frame. - * As a result, this function does not exercise the logic in the public api aws_secure_tunnel_send_data - * to split the input cursor into multiple frames. - * - * 2/ This function does not exercise any of the websocket api. Instead this function calls a private api - * secure_tunneling_send_data_call with the websocket set to NULL. Instead of queueing frames, this - * api writes frames to a second buffer in the websocket wire protocol format. The test compares the - * second buffer to what is expected from the wire format. In the public api, the functionality to - * write the frame in the websocket wire protocol format is invoked as a callback from the event loop. - * - * To summarize, the test below has value, but such value is limited by carefully avoiding the public - * api to send data over a secure tunnel. A separate group of tests are required to more directly - * exercise the public api. - * - */ - - struct aws_iot_st_msg message; - message.type = type; - message.stream_id = expected_stream_id; - message.ignorable = 0; - message.payload = aws_byte_buf_from_c_str(expected_payload); - struct aws_byte_buf serialized_st_msg; - aws_iot_st_msg_serialize_from_struct(&serialized_st_msg, test_context->secure_tunnel->options->allocator, message); - - struct aws_byte_cursor cur = aws_byte_cursor_from_c_str(expected_payload); - struct aws_websocket_send_frame_options frame_options; - ASSERT_INT_EQUALS( - AWS_OP_SUCCESS, - secure_tunneling_init_send_frame(&frame_options, test_context->secure_tunnel, &cur, message.type)); - - ASSERT_INT_EQUALS(serialized_st_msg.len + prefix_bytes, frame_options.payload_length); - - struct aws_byte_buf out_buf; - ASSERT_INT_EQUALS( - AWS_OP_SUCCESS, - aws_byte_buf_init( - &out_buf, test_context->secure_tunnel->options->allocator, (size_t)frame_options.payload_length)); - - ASSERT_TRUE(secure_tunneling_send_data_call(NULL, &out_buf, frame_options.user_data)); - struct aws_byte_cursor out_buf_cur = aws_byte_cursor_from_buf(&out_buf); - - ASSERT_UINT_EQUALS(out_buf_cur.len - prefix_bytes, serialized_st_msg.len); - - uint16_t payload_prefixed_length; - aws_byte_cursor_read_be16(&out_buf_cur, &payload_prefixed_length); - ASSERT_UINT_EQUALS((uint16_t)serialized_st_msg.len, payload_prefixed_length); - ASSERT_BIN_ARRAYS_EQUALS(serialized_st_msg.buffer, serialized_st_msg.len, out_buf_cur.ptr, out_buf_cur.len); - - struct data_tunnel_pair *pair = frame_options.user_data; - aws_byte_buf_clean_up(&pair->buf); - aws_mem_release(pair->secure_tunnel->options->allocator, (void *)pair); - aws_byte_buf_clean_up(&serialized_st_msg); - aws_byte_buf_clean_up(&out_buf); - - return AWS_OP_SUCCESS; -} - -static int s_byte_buf_init_rand(struct aws_byte_buf *buf, struct aws_allocator *allocator, size_t capacity) { - int rc = aws_byte_buf_init(buf, allocator, capacity); - if (rc != AWS_OP_SUCCESS) { - return rc; - } - return aws_device_random_buffer(buf); -} - -AWS_TEST_CASE_FIXTURE( - secure_tunneling_handle_stream_start_test, - before, - s_secure_tunneling_handle_stream_start_test, - after, - &s_test_context); -static int s_secure_tunneling_handle_stream_start_test(struct aws_allocator *allocator, void *ctx) { - /* - * When secure tunnel running in destination mode receives a StreamStart message, - * verify the stream start callback is invoked and that the stream ID is parsed from the message. - */ - - struct secure_tunneling_test_context *test_context = ctx; - test_context->secure_tunnel->options->local_proxy_mode = AWS_SECURE_TUNNELING_DESTINATION_MODE; - - struct aws_iot_st_msg st_msg; - AWS_ZERO_STRUCT(st_msg); - st_msg.type = STREAM_START; - st_msg.stream_id = STREAM_ID; - s_on_stream_start_called = false; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, test_context->secure_tunnel); - - ASSERT_TRUE(s_on_stream_start_called); - ASSERT_INT_EQUALS(STREAM_ID, test_context->secure_tunnel->stream_id); - ASSERT_UINT_EQUALS(0, test_context->secure_tunnel->received_data.len); - - return AWS_OP_SUCCESS; -} - -AWS_TEST_CASE_FIXTURE( - secure_tunneling_handle_data_receive_test, - before, - s_secure_tunneling_handle_data_receive_test, - after, - &s_test_context); -static int s_secure_tunneling_handle_data_receive_test(struct aws_allocator *allocator, void *ctx) { - /* - * When secure tunnel running in destination mode receives a Data message, - * verify the data callback is invoked with matching message payload. - */ - - struct secure_tunneling_test_context *test_context = ctx; - test_context->secure_tunnel->options->local_proxy_mode = AWS_SECURE_TUNNELING_DESTINATION_MODE; - - /* Send StreamStart first */ - struct aws_iot_st_msg st_msg; - AWS_ZERO_STRUCT(st_msg); - st_msg.type = STREAM_START; - st_msg.stream_id = STREAM_ID; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, test_context->secure_tunnel); - - /* Send data */ - AWS_ZERO_STRUCT(st_msg); - st_msg.type = DATA; - st_msg.stream_id = STREAM_ID; - st_msg.payload = aws_byte_buf_from_c_str(PAYLOAD); - s_on_data_receive_correct_payload = false; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, test_context->secure_tunnel); - - ASSERT_TRUE(s_on_data_receive_correct_payload); - ASSERT_INT_EQUALS(STREAM_ID, test_context->secure_tunnel->stream_id); - ASSERT_UINT_EQUALS(0, test_context->secure_tunnel->received_data.len); - - return AWS_OP_SUCCESS; -} - -AWS_TEST_CASE_FIXTURE( - secure_tunneling_handle_stream_reset_test, - before, - s_secure_tunneling_handle_stream_reset_test, - after, - &s_test_context); -static int s_secure_tunneling_handle_stream_reset_test(struct aws_allocator *allocator, void *ctx) { - /* - * When secure tunnel running in destination mode receives a StreamReset message, - * verify the stream reset callback is invoked and the stream ID is unset. - */ - - struct secure_tunneling_test_context *test_context = ctx; - test_context->secure_tunnel->options->local_proxy_mode = AWS_SECURE_TUNNELING_DESTINATION_MODE; - - /* Send StreamStart first */ - struct aws_iot_st_msg st_msg; - AWS_ZERO_STRUCT(st_msg); - st_msg.type = STREAM_START; - st_msg.stream_id = STREAM_ID; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, test_context->secure_tunnel); - - /* Send StreamReset */ - AWS_ZERO_STRUCT(st_msg); - st_msg.type = STREAM_RESET; - st_msg.stream_id = STREAM_ID; - s_on_stream_reset_called = false; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, test_context->secure_tunnel); - - ASSERT_TRUE(s_on_stream_reset_called); - ASSERT_INT_EQUALS(INVALID_STREAM_ID, test_context->secure_tunnel->stream_id); - ASSERT_UINT_EQUALS(0, test_context->secure_tunnel->received_data.len); - - return AWS_OP_SUCCESS; -} - -AWS_TEST_CASE_FIXTURE( - secure_tunneling_handle_session_reset_test, - before, - s_secure_tunneling_handle_session_reset_test, - after, - &s_test_context); -static int s_secure_tunneling_handle_session_reset_test(struct aws_allocator *allocator, void *ctx) { - /* - * When secure tunnel running in destination mode receives a SessionReset message with a valid stream ID, - * verify the session reset callback is invoked and the stream ID is unset. - */ - - struct secure_tunneling_test_context *test_context = ctx; - test_context->secure_tunnel->options->local_proxy_mode = AWS_SECURE_TUNNELING_DESTINATION_MODE; - - /* Send StreamStart first */ - struct aws_iot_st_msg st_msg; - AWS_ZERO_STRUCT(st_msg); - st_msg.type = STREAM_START; - st_msg.stream_id = STREAM_ID; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, test_context->secure_tunnel); - - /* Send StreamReset */ - AWS_ZERO_STRUCT(st_msg); - st_msg.type = SESSION_RESET; - st_msg.stream_id = STREAM_ID; - s_on_session_reset_called = false; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, test_context->secure_tunnel); - - ASSERT_TRUE(s_on_session_reset_called); - ASSERT_INT_EQUALS(INVALID_STREAM_ID, test_context->secure_tunnel->stream_id); - ASSERT_UINT_EQUALS(0, test_context->secure_tunnel->received_data.len); - - return AWS_OP_SUCCESS; -} - -AWS_TEST_CASE_FIXTURE( - secure_tunneling_handle_session_reset_no_stream_test, - before, - s_secure_tunneling_handle_session_reset_no_stream_test, - after, - &s_test_context); -static int s_secure_tunneling_handle_session_reset_no_stream_test(struct aws_allocator *allocator, void *ctx) { - /* - * When secure tunnel running in destination mode receives a SessionReset message without valid stream ID, - * verify the session reset callback is not invoked. - */ - - struct secure_tunneling_test_context *test_context = ctx; - test_context->secure_tunnel->options->local_proxy_mode = AWS_SECURE_TUNNELING_DESTINATION_MODE; - - /* Send StreamReset without existing stream */ - struct aws_iot_st_msg st_msg; - AWS_ZERO_STRUCT(st_msg); - st_msg.type = SESSION_RESET; - s_on_session_reset_called = false; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, test_context->secure_tunnel); - - ASSERT_FALSE(s_on_session_reset_called); - ASSERT_INT_EQUALS(INVALID_STREAM_ID, test_context->secure_tunnel->stream_id); - ASSERT_UINT_EQUALS(0, test_context->secure_tunnel->received_data.len); - - return AWS_OP_SUCCESS; -} - -AWS_TEST_CASE_FIXTURE( - secure_tunneling_init_websocket_options_test, - before, - s_secure_tunneling_init_websocket_options_test, - after, - &s_test_context); -static int s_secure_tunneling_init_websocket_options_test(struct aws_allocator *allocator, void *ctx) { - /* - * When a client connects to a websocket server, - * verify the client handshake includes the aws secure tunneling protocol string and access token - * provided by the secure tunneling service when the tunnel is provisioned. - */ - - UNUSED(allocator); - - struct secure_tunneling_test_context *test_context = ctx; - - struct aws_websocket_client_connection_options websocket_options; - init_websocket_client_connection_options(test_context->secure_tunnel, &websocket_options); - - ASSERT_TRUE(aws_byte_cursor_eq_c_str(&websocket_options.host, ENDPOINT)); - - /* - * Verify handshake request - */ - - ASSERT_TRUE(aws_http_message_is_request(websocket_options.handshake_request)); - - struct aws_byte_cursor method; - aws_http_message_get_request_method(websocket_options.handshake_request, &method); - ASSERT_TRUE(aws_byte_cursor_eq_c_str(&method, "GET")); - - /* Verify path */ - struct aws_byte_cursor path; - aws_http_message_get_request_path(websocket_options.handshake_request, &path); - ASSERT_TRUE(aws_byte_cursor_eq_c_str(&path, "/tunnel?local-proxy-mode=source")); - - /* Verify headers */ - const char *expected_headers[][2] = { - {"Sec-WebSocket-Protocol", "aws.iot.securetunneling-1.0"}, - {"access-token", ACCESS_TOKEN}, - }; - - const struct aws_http_headers *headers = aws_http_message_get_const_headers(websocket_options.handshake_request); - for (size_t i = 0; i < sizeof(expected_headers) / sizeof(expected_headers[0]); i++) { - struct aws_byte_cursor name = aws_byte_cursor_from_c_str(expected_headers[i][0]); - struct aws_byte_cursor value; - ASSERT_INT_EQUALS(AWS_OP_SUCCESS, aws_http_headers_get(headers, name, &value)); - ASSERT_TRUE(aws_byte_cursor_eq_c_str(&value, expected_headers[i][1])); - } - - aws_http_message_release(websocket_options.handshake_request); - - return AWS_OP_SUCCESS; -} - -AWS_TEST_CASE_FIXTURE( - secure_tunneling_handle_send_data, - before, - s_secure_tunneling_handle_send_data, - after, - &s_test_context); -static int s_secure_tunneling_handle_send_data(struct aws_allocator *allocator, void *ctx) { - /* - * When a secure tunnel running in source mode sends data to destination, - * verify the data are written to the tunnel in the expected websocket wire protocol format. - */ - - UNUSED(allocator); - const char *expected_payload = "Hi! I'm Paul / Some random payload\n"; - const int32_t expected_stream_id = 1; - const int prefix_bytes = 2; - const enum aws_iot_st_message_type type = DATA; - - struct secure_tunneling_test_context *test_context = ctx; - test_context->secure_tunnel->options->local_proxy_mode = AWS_SECURE_TUNNELING_SOURCE_MODE; - test_context->secure_tunnel->stream_id = expected_stream_id; - - s_test_sent_data(test_context, expected_payload, expected_stream_id, prefix_bytes, type); - - return AWS_OP_SUCCESS; -} - -AWS_TEST_CASE_FIXTURE( - secure_tunneling_handle_send_data_stream_start, - before, - s_secure_tunneling_handle_send_data_stream_start, - after, - &s_test_context); -static int s_secure_tunneling_handle_send_data_stream_start(struct aws_allocator *allocator, void *ctx) { - /* - * When a secure tunnel running in source mode sends StreamStart to destination, - * verify the data are written to the tunnel in the expected websocket wire protocol format. - */ - - UNUSED(allocator); - const char *expected_payload = ""; - const int32_t expected_stream_id = 1; - const int prefix_bytes = 2; - const enum aws_iot_st_message_type type = STREAM_START; - - struct secure_tunneling_test_context *test_context = ctx; - test_context->secure_tunnel->options->local_proxy_mode = AWS_SECURE_TUNNELING_SOURCE_MODE; - test_context->secure_tunnel->stream_id = expected_stream_id; - - s_test_sent_data(test_context, expected_payload, expected_stream_id, prefix_bytes, type); - - return AWS_OP_SUCCESS; -} - -AWS_TEST_CASE_FIXTURE( - secure_tunneling_handle_send_data_stream_reset, - before, - s_secure_tunneling_handle_send_data_stream_reset, - after, - &s_test_context); -static int s_secure_tunneling_handle_send_data_stream_reset(struct aws_allocator *allocator, void *ctx) { - /* - * When a secure tunnel running in source mode sends StreamReset to destination, - * verify the data are written to the tunnel in the expected websocket wire protocol format. - */ - - UNUSED(allocator); - const char *expected_payload = ""; - const int32_t expected_stream_id = 1; - const int prefix_bytes = 2; - const enum aws_iot_st_message_type type = STREAM_RESET; - - struct secure_tunneling_test_context *test_context = ctx; - test_context->secure_tunnel->options->local_proxy_mode = AWS_SECURE_TUNNELING_SOURCE_MODE; - test_context->secure_tunnel->stream_id = expected_stream_id; - - s_test_sent_data(test_context, expected_payload, expected_stream_id, prefix_bytes, type); - - return AWS_OP_SUCCESS; -} - -AWS_TEST_CASE_FIXTURE( - secure_tunneling_handle_send_data_public, - before, - s_secure_tunneling_handle_send_data_public, - after, - &s_test_context); -static int s_secure_tunneling_handle_send_data_public(struct aws_allocator *allocator, void *ctx) { - /* - * When a secure tunnel running in source mode sends data to destination using the public api, - * verify that the payload length matches what the client sends and the number of frames sent - * is equal to size of the payload divided by the maximum frame length. - */ - - struct secure_tunneling_test_context *test_context = ctx; - test_context->secure_tunnel->options->local_proxy_mode = AWS_SECURE_TUNNELING_SOURCE_MODE; - - /* Open the tunnel. */ - int rc = aws_secure_tunnel_connect(test_context->secure_tunnel); - ASSERT_INT_EQUALS(AWS_OP_SUCCESS, rc); - - size_t buf_sizes[] = {10, 100, 1000, AWS_IOT_ST_SPLIT_MESSAGE_SIZE + 1, 2 * AWS_IOT_ST_SPLIT_MESSAGE_SIZE + 1}; - size_t buf_sizes_len = sizeof(buf_sizes) / sizeof(buf_sizes[0]); - - for (size_t i = 0; i < buf_sizes_len; ++i) { - /* Start a stream. */ - s_mock_aws_websocket_send_frame_call_count = 0U; - s_mock_aws_websocket_send_frame_payload_len = 0U; - rc = aws_secure_tunnel_stream_start(test_context->secure_tunnel); - ASSERT_INT_EQUALS(AWS_OP_SUCCESS, rc); - ASSERT_UINT_EQUALS(1U, s_mock_aws_websocket_send_frame_call_count); - ASSERT_UINT_EQUALS(0U, s_mock_aws_websocket_send_frame_payload_len); - - /* Initialize buffer of random values to send. */ - struct aws_byte_buf buf; - rc = s_byte_buf_init_rand(&buf, allocator, buf_sizes[i]); - ASSERT_INT_EQUALS(AWS_OP_SUCCESS, rc); - - struct aws_byte_cursor cur = aws_byte_cursor_from_buf(&buf); - - /* Call public api to send data over secure tunnel. */ - s_mock_aws_websocket_send_frame_call_count = 0U; - s_mock_aws_websocket_send_frame_payload_len = 0U; - rc = aws_secure_tunnel_send_data(test_context->secure_tunnel, &cur); - ASSERT_INT_EQUALS(AWS_OP_SUCCESS, rc); - int expected_call_count = (int)buf_sizes[i] / AWS_IOT_ST_SPLIT_MESSAGE_SIZE + 1; - ASSERT_UINT_EQUALS(expected_call_count, s_mock_aws_websocket_send_frame_call_count); - ASSERT_UINT_EQUALS(buf_sizes[i], s_mock_aws_websocket_send_frame_payload_len); - - /* Free buffer. */ - aws_byte_buf_clean_up(&buf); - } - - /* Close the tunnel. */ - aws_secure_tunnel_close(test_context->secure_tunnel); - - return AWS_OP_SUCCESS; -} diff --git a/tests/tests_protobuf/aws_iot_st_pb_test.cpp b/tests/tests_protobuf/aws_iot_st_pb_test.cpp index 14c0685c..1d18bef0 100644 --- a/tests/tests_protobuf/aws_iot_st_pb_test.cpp +++ b/tests/tests_protobuf/aws_iot_st_pb_test.cpp @@ -30,7 +30,7 @@ static int execute_tests( protobufMessage.ParseFromString(pbBuffer); struct aws_iot_st_msg c_message; - c_message.type = (aws_iot_st_message_type)type; + c_message.type = (aws_secure_tunnel_message_type)type; c_message.stream_id = streamid; c_message.ignorable = ignorable; c_message.payload = aws_byte_buf_from_c_str(payload.c_str()); @@ -173,4 +173,4 @@ int main(int argc, char *argv[]) { } google::protobuf::ShutdownProtobufLibrary(); return AWS_OP_SUCCESS; -} \ No newline at end of file +}