diff --git a/.gitignore b/.gitignore index 5a5765c6..b182cbc6 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ INSTALL sniproxy-*.tar.gz tags test-driver +.DS_Store diff --git a/src/Makefile.am b/src/Makefile.am index 456a5d20..2f6fb877 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -32,6 +32,10 @@ sniproxy_SOURCES = sniproxy.c \ table.c \ table.h \ tls.c \ - tls.h + tls.h \ + dtls.c \ + dtls.h \ + sni.c \ + sni.h sniproxy_LDADD = $(LIBEV_LIBS) $(LIBPCRE_LIBS) $(LIBUDNS_LIBS) diff --git a/src/binder.c b/src/binder.c index 53ea3fca..af586b94 100644 --- a/src/binder.c +++ b/src/binder.c @@ -44,6 +44,7 @@ static int parse_ancillary_data(struct msghdr *); struct binder_request { + int type; size_t address_len; struct sockaddr address[]; }; @@ -82,7 +83,7 @@ start_binder() { } int -bind_socket(const struct sockaddr *addr, size_t addr_len) { +bind_socket(int type, const struct sockaddr *addr, size_t addr_len) { struct binder_request *request; struct msghdr msg; struct iovec iov[1]; @@ -99,6 +100,7 @@ bind_socket(const struct sockaddr *addr, size_t addr_len) { if (request_len > sizeof(data_buf)) fatal("bind_socket: request length %zu exceeds buffer", request_len); request = (struct binder_request *)data_buf; + request->type = type; request->address_len = addr_len; memcpy(&request->address, addr, addr_len); @@ -159,7 +161,7 @@ binder_main(int sockfd) { struct binder_request *req = (struct binder_request *)buffer; - int fd = socket(req->address[0].sa_family, SOCK_STREAM, 0); + int fd = socket(req->address[0].sa_family, req->type, 0); if (fd < 0) { memset(buffer, 0, sizeof(buffer)); snprintf(buffer, sizeof(buffer), "socket(): %s", strerror(errno)); diff --git a/src/binder.h b/src/binder.h index 91a838d3..90c77531 100644 --- a/src/binder.h +++ b/src/binder.h @@ -29,7 +29,7 @@ #include void start_binder(); -int bind_socket(const struct sockaddr *, size_t); +int bind_socket(int type, const struct sockaddr *, size_t); void stop_binder(); #endif diff --git a/src/buffer.c b/src/buffer.c index 23a609ca..d9ad6f2c 100644 --- a/src/buffer.c +++ b/src/buffer.c @@ -26,6 +26,7 @@ #include #include /* malloc */ #include /* memcpy */ +#include /* uint16_t */ #include #include #include @@ -38,25 +39,30 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define NOT_POWER_OF_2(x) (x == 0 || (x & (x - 1))) +#define ALIGN_16BIT(x) ((x + 1) & ~((typeof(x))1)) + +/* REMOVE */ +#define DBG_DUMP(x, y) if (x->type == SOCK_DGRAM) y static const size_t BUFFER_MAX_SIZE = 1024 * 1024 * 1024; -static size_t setup_write_iov(const struct Buffer *, struct iovec *, size_t); -static size_t setup_read_iov(const struct Buffer *, struct iovec *, size_t); +static size_t setup_write_iov(const struct Buffer *, struct iovec [2], size_t); +static size_t setup_read_iov(const struct Buffer *, struct iovec [2], size_t); static inline void advance_write_position(struct Buffer *, size_t); static inline void advance_read_position(struct Buffer *, size_t); struct Buffer * -new_buffer(size_t size, struct ev_loop *loop) { +new_buffer(int type, size_t size, struct ev_loop *loop) { if (NOT_POWER_OF_2(size)) return NULL; struct Buffer *buf = malloc(sizeof(struct Buffer)); if (buf == NULL) return NULL; + buf->type = type; buf->size_mask = size - 1; buf->len = 0; buf->head = 0; @@ -98,6 +104,15 @@ buffer_resize(struct Buffer *buf, size_t new_size) { return (ssize_t)buf->len; } +ssize_t +buffer_copy_data(struct Buffer *src, struct Buffer *dst, size_t len) +{ + /* Raw copy the entire buffer */ + memcpy(src->buffer, dst->buffer, len); + + return 0; +} + void free_buffer(struct Buffer *buf) { if (buf == NULL) @@ -109,21 +124,27 @@ free_buffer(struct Buffer *buf) { ssize_t buffer_recv(struct Buffer *buffer, int sockfd, int flags, struct ev_loop *loop) { + struct msghdr msg = { 0 }; + + return buffer_recvmsg(buffer, sockfd, &msg, flags, loop); +} + +ssize_t +buffer_recvmsg(struct Buffer *buffer, int sockfd, struct msghdr *msg, + int flags, struct ev_loop *loop) { /* coalesce when reading into an empty buffer */ if (buffer->len == 0) buffer->head = 0; struct iovec iov[2]; - struct msghdr msg = { - .msg_iov = iov, - .msg_iovlen = setup_write_iov(buffer, iov, 0) - }; + msg->msg_iov = iov; + msg->msg_iovlen = setup_write_iov(buffer, iov, 0); - ssize_t bytes = recvmsg(sockfd, &msg, flags); + ssize_t bytes = recvmsg(sockfd, msg, flags); buffer->last_recv = ev_now(loop); - if (bytes > 0) + if (bytes >= 0) advance_write_position(buffer, (size_t)bytes); return bytes; @@ -131,17 +152,22 @@ buffer_recv(struct Buffer *buffer, int sockfd, int flags, struct ev_loop *loop) ssize_t buffer_send(struct Buffer *buffer, int sockfd, int flags, struct ev_loop *loop) { + struct msghdr msg = { 0 }; + + return buffer_sendmsg(buffer, sockfd, &msg, flags, loop); +} + +ssize_t +buffer_sendmsg(struct Buffer *buffer, int sockfd, struct msghdr *msg, int flags, struct ev_loop *loop) { struct iovec iov[2]; - struct msghdr msg = { - .msg_iov = iov, - .msg_iovlen = setup_read_iov(buffer, iov, 0) - }; + msg->msg_iov = iov; + msg->msg_iovlen = setup_read_iov(buffer, iov, 0); - ssize_t bytes = sendmsg(sockfd, &msg, flags); + ssize_t bytes = sendmsg(sockfd, msg, flags); buffer->last_send = ev_now(loop); - if (bytes > 0) + if (bytes >= 0) advance_read_position(buffer, (size_t)bytes); return bytes; @@ -160,7 +186,7 @@ buffer_read(struct Buffer *buffer, int fd) { size_t iov_len = setup_write_iov(buffer, iov, 0); ssize_t bytes = readv(fd, iov, iov_len); - if (bytes > 0) + if (bytes >= 0) advance_write_position(buffer, (size_t)bytes); return bytes; @@ -175,7 +201,7 @@ buffer_write(struct Buffer *buffer, int fd) { size_t iov_len = setup_read_iov(buffer, iov, 0); ssize_t bytes = writev(fd, iov, iov_len); - if (bytes > 0) + if (bytes >= 0) advance_read_position(buffer, (size_t)bytes); return bytes; @@ -190,30 +216,65 @@ buffer_write(struct Buffer *buffer, int fd) { size_t buffer_coalesce(struct Buffer *buffer, const void **dst) { size_t buffer_tail = (buffer->head + buffer->len) & buffer->size_mask; + size_t head = buffer->head; - if (buffer_tail <= buffer->head) { + if (buffer_tail >= buffer->head) { /* buffer not wrapped */ + if (buffer->type == SOCK_DGRAM) + head += sizeof(uint16_t); if (dst != NULL) - *dst = &buffer->buffer[buffer->head]; + *dst = &buffer->buffer[head]; return buffer->len; } else { - /* buffer wrapped */ - size_t len = buffer->len; - char *temp = malloc(len); - if (temp != NULL) { - buffer_pop(buffer, temp, len); - assert(buffer->len == 0); - - buffer_push(buffer, temp, len); - assert(buffer->head == 0); - assert(buffer->len == len); - - free(temp); + if (buffer->type == SOCK_STREAM) { + /* buffer wrapped */ + size_t len = buffer->len; + char *temp = malloc(len); + if (temp != NULL) { + buffer_pop(buffer, temp, len); + assert(buffer->len == 0); + + buffer_push(buffer, temp, len); + assert(buffer->head == 0); + assert(buffer->len == len); + + free(temp); + } + } else { /* SOCK_DGRAM */ + /* buffer wrapped */ + size_t len = buffer->len; + char temp[len]; + size_t bytes = 0, total = 0, dgram_size; + struct Buffer *newbuf = new_buffer(SOCK_DGRAM, buffer_size(buffer), EV_DEFAULT); + + if (temp != NULL && newbuf != NULL) { + do { + /* Read each datagram, one at a time, and populate in new buffer */ + bytes = buffer_pop(buffer, temp, sizeof(temp)); + dgram_size = ALIGN_16BIT(bytes) + HDR_LEN(buffer); + total += dgram_size; + assert(buffer->len == len - total); + buffer_push(newbuf, temp, bytes); + assert(newbuf->len == total); + } while (bytes != 0 && total < len); + + /* Copy the data across */ + memcpy(buffer->buffer, newbuf->buffer, len); + buffer->head = newbuf->head; + buffer->len = newbuf->len; + + free_buffer(newbuf); + } } - if (dst != NULL) - *dst = buffer->buffer; + if (dst != NULL) { + if (buffer->type == SOCK_DGRAM) { + *dst = buffer->buffer+sizeof(uint16_t); + } else { + *dst = buffer->buffer; + } + } return buffer->len; } @@ -255,7 +316,7 @@ buffer_push(struct Buffer *dst, const void *src, size_t len) { if (dst->len == 0) dst->head = 0; - if (buffer_size(dst) - dst->len < len) + if (buffer_size(dst) - dst->len < len + HDR_LEN(dst)) return 0; /* insufficient room */ size_t iov_len = setup_write_iov(dst, iov, len); @@ -277,18 +338,23 @@ buffer_push(struct Buffer *dst, const void *src, size_t len) { * returns the number of entries setup */ static size_t -setup_write_iov(const struct Buffer *buffer, struct iovec *iov, size_t len) { - size_t room = buffer_size(buffer) - buffer->len; +setup_write_iov(const struct Buffer *buffer, struct iovec iov[2], size_t len) { + size_t headroom = buffer->type == SOCK_DGRAM ? sizeof(uint16_t) : 0; + size_t room = buffer_size(buffer) - buffer->len - headroom; if (room == 0) /* trivial case: no room */ return 0; + if (buffer->type == SOCK_DGRAM + && room < len) /* Complete writes only for dgram */ + return 0; + size_t write_len = room; /* Allow caller to specify maximum length */ if (len != 0) write_len = MIN(room, len); - size_t start = (buffer->head + buffer->len) & buffer->size_mask; + size_t start = (buffer->head + buffer->len + headroom) & buffer->size_mask; if (start + write_len <= buffer_size(buffer)) { iov[0].iov_base = buffer->buffer + start; @@ -319,16 +385,25 @@ setup_write_iov(const struct Buffer *buffer, struct iovec *iov, size_t len) { } static size_t -setup_read_iov(const struct Buffer *buffer, struct iovec *iov, size_t len) { +setup_read_iov(const struct Buffer *buffer, struct iovec iov[2], size_t len) { if (buffer->len == 0) return 0; - size_t read_len = buffer->len; - if (len != 0) - read_len = MIN(len, buffer->len); + size_t read_len; + if (buffer->type == SOCK_DGRAM) { + assert(buffer->head % sizeof(uint16_t) == 0); + read_len = *(uint16_t *)&buffer->buffer[buffer->head]; + } else { + read_len = buffer->len; + if (len != 0) + read_len = MIN(len, buffer->len); + } - if (buffer->head + read_len <= buffer_size(buffer)) { - iov[0].iov_base = buffer->buffer + buffer->head; + size_t start = (buffer->head + HDR_LEN(buffer)) & buffer->size_mask; + size_t end = (start + read_len) & buffer->size_mask; + + if (! end || start < end) { + iov[0].iov_base = buffer->buffer + start; iov[0].iov_len = read_len; /* assert iov are within bounds, non-zero length and non-overlapping */ @@ -338,8 +413,8 @@ setup_read_iov(const struct Buffer *buffer, struct iovec *iov, size_t len) { return 1; } else { - iov[0].iov_base = buffer->buffer + buffer->head; - iov[0].iov_len = buffer_size(buffer) - buffer->head; + iov[0].iov_base = buffer->buffer + start; + iov[0].iov_len = buffer_size(buffer) - start; iov[1].iov_base = buffer->buffer; iov[1].iov_len = read_len - iov[0].iov_len; @@ -357,13 +432,28 @@ setup_read_iov(const struct Buffer *buffer, struct iovec *iov, size_t len) { static inline void advance_write_position(struct Buffer *buffer, size_t offset) { - buffer->len += offset; + if (buffer->type == SOCK_DGRAM) { + uint16_t *dgram_len = (uint16_t *)&buffer->buffer[ + (buffer->head + buffer->len) & buffer->size_mask]; + + *dgram_len = (uint16_t)offset; + buffer->len += sizeof(uint16_t) + ALIGN_16BIT(offset); + } else { + buffer->len += offset; + } buffer->rx_bytes += offset; } static inline void advance_read_position(struct Buffer *buffer, size_t offset) { - buffer->head = (buffer->head + offset) & buffer->size_mask; - buffer->len -= offset; + if (buffer->type == SOCK_DGRAM) { + buffer->head = (buffer->head + sizeof(uint16_t) + ALIGN_16BIT(offset)) + & buffer->size_mask; + + buffer->len -= sizeof(uint16_t) + ALIGN_16BIT(offset); + } else { + buffer->head = (buffer->head + offset) & buffer->size_mask; + buffer->len -= offset; + } buffer->tx_bytes += offset; } diff --git a/src/buffer.h b/src/buffer.h index 1b0a4671..5ce0076e 100644 --- a/src/buffer.h +++ b/src/buffer.h @@ -28,11 +28,14 @@ #include #include +#include #include +#define HDR_LEN(b) (((b)->type == SOCK_DGRAM) ? sizeof(uint16_t) : 0) struct Buffer { char *buffer; + int type ; /* STREAM or DGRAM */ size_t size_mask; /* bit mask for buffer size */ size_t head; /* index of first byte of content */ size_t len; /* size of content */ @@ -42,11 +45,13 @@ struct Buffer { size_t rx_bytes; }; -struct Buffer *new_buffer(size_t, struct ev_loop *); +struct Buffer *new_buffer(int, size_t, struct ev_loop *); void free_buffer(struct Buffer *); ssize_t buffer_recv(struct Buffer *, int, int, struct ev_loop *); +ssize_t buffer_recvmsg(struct Buffer *, int, struct msghdr *, int, struct ev_loop *); ssize_t buffer_send(struct Buffer *, int, int, struct ev_loop *); +ssize_t buffer_sendmsg(struct Buffer *, int, struct msghdr *, int, struct ev_loop *); ssize_t buffer_read(struct Buffer *, int); ssize_t buffer_write(struct Buffer *, int); ssize_t buffer_resize(struct Buffer *, size_t); diff --git a/src/config.c b/src/config.c index 4ff23127..fad1d647 100644 --- a/src/config.c +++ b/src/config.c @@ -32,6 +32,7 @@ #include "config.h" #include "logger.h" #include "connection.h" +#include "protocol.h" struct LoggerBuilder { @@ -378,7 +379,10 @@ accept_pidfile(struct Config *config, const char *pidfile) { static int end_listener_stanza(struct Config *config, struct Listener *listener) { - listener->accept_cb = &accept_connection; + if (listener->protocol && listener->protocol->sock_type == SOCK_STREAM) + listener->accept_cb = &accept_stream_connection; + if (listener->protocol && listener->protocol->sock_type == SOCK_DGRAM) + listener->accept_cb = &accept_dgram_connection; if (valid_listener(listener) <= 0) { err("Invalid listener"); diff --git a/src/connection.c b/src/connection.c index c10a1090..31d00f60 100644 --- a/src/connection.c +++ b/src/connection.c @@ -32,6 +32,9 @@ #include #include #include +#ifdef __APPLE__ +#define __APPLE_USE_RFC_3542 +#endif #include #include /* getaddrinfo */ #include /* close */ @@ -80,7 +83,7 @@ static void close_connection(struct Connection *, struct ev_loop *); static void close_client_socket(struct Connection *, struct ev_loop *); static void abort_connection(struct Connection *); static void close_server_socket(struct Connection *, struct ev_loop *); -static struct Connection *new_connection(struct ev_loop *); +static struct Connection *new_connection(int type, struct ev_loop *); static void log_connection(struct Connection *); static void log_bad_request(struct Connection *, const char *, size_t, int); static void free_connection(struct Connection *); @@ -99,8 +102,8 @@ init_connections() { * Returns 1 on success or 0 on error; */ int -accept_connection(struct Listener *listener, struct ev_loop *loop) { - struct Connection *con = new_connection(loop); +accept_stream_connection(struct Listener *listener, struct ev_loop *loop) { + struct Connection *con = new_connection(listener->protocol->sock_type, loop); if (con == NULL) { err("new_connection failed"); return 0; @@ -161,6 +164,132 @@ accept_connection(struct Listener *listener, struct ev_loop *loop) { return 1; } +/** + * Accept a new UDP incoming connection + * + * Returns 1 on success or 0 on error; + */ +int +accept_dgram_connection(struct Listener *listener, struct ev_loop *loop) { + struct Connection *con = new_connection(listener->protocol->sock_type, loop); + if (con == NULL) { + err("new_connection failed"); + return 0; + } + con->listener = listener_ref_get(listener); + + char cbuf[CMSG_SPACE(sizeof(struct in6_pktinfo))]; + + struct msghdr msg = { + .msg_name = &con->client.addr, + .msg_namelen = sizeof(con->client.addr), + .msg_control = cbuf, + .msg_controllen = sizeof(cbuf), + }; + + ssize_t bytes_received = buffer_recvmsg(con->client.buffer, listener->watcher.fd, &msg, 0, loop); + if (bytes_received < 0) { + int saved_errno = errno; + + warn("recvmsg failed: %s", strerror(errno)); + free_connection(con); + + errno = saved_errno; + return 0; + } + + con->client.addr_len = msg.msg_namelen; + + for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; + cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (cmsg->cmsg_level == IPPROTO_IP + && cmsg->cmsg_type == IP_PKTINFO) { + const struct in_pktinfo *pi = (void *)CMSG_DATA(cmsg); + struct sockaddr_in *local_addr = (void *)&con->client.local_addr; + local_addr->sin_family = AF_INET; + local_addr->sin_port = htons(address_port(listener->address)); + memcpy(&local_addr->sin_addr, &pi->ipi_spec_dst, sizeof(struct in_addr)); + con->client.local_addr_len = sizeof(struct sockaddr_in); + } else if (cmsg->cmsg_level == IPPROTO_IPV6 + && cmsg->cmsg_type == IPV6_PKTINFO) { + const struct in6_pktinfo *pi = (void *)CMSG_DATA(cmsg); + struct sockaddr_in6 *local_addr = (void *)&con->client.local_addr; + local_addr->sin6_family = AF_INET6; + local_addr->sin6_port = htons(address_port(listener->address)); + memcpy(&local_addr->sin6_addr, &pi->ipi6_addr, sizeof(struct in6_addr)); + con->client.local_addr_len = sizeof(struct sockaddr_in6); + } + } + + +#ifdef HAVE_ACCEPT4 + int sockfd = socket(AF_INET, SOCK_DGRAM | SOCK_NONBLOCK, 0); +#else + int sockfd = socket(AF_INET, SOCK_DGRAM, 0); +#endif + if (sockfd < 0) { + int saved_errno = errno; + + warn("accept failed: %s", strerror(errno)); + free_connection(con); + + errno = saved_errno; + return 0; + } + + int on = 1; + int result = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); + if (result < 0) { + err("setsockopt SO_REUSEADDR failed: %s", strerror(errno)); + close(sockfd); + return result; + } + + result = bind(sockfd, (struct sockaddr *)&con->client.local_addr, + con->client.local_addr_len); + if (result < 0) { + char address[INET6_ADDRSTRLEN + 8]; + err("bind %s failed: %s", + display_sockaddr(&con->client.local_addr, address, sizeof(address)), + strerror(errno)); + close(sockfd); + return result; + } + + result = connect(sockfd, + (struct sockaddr *)&con->client.addr, + con->client.addr_len); + if (result < 0) { + char address[INET6_ADDRSTRLEN + 8]; + err("connect %s failed: %s", + display_sockaddr(&con->client.addr, address, sizeof(address)), + strerror(errno)); + close(sockfd); + return result; + } + +#ifndef HAVE_ACCEPT4 + int flags = fcntl(sockfd, F_GETFL, 0); + fcntl(sockfd, F_SETFL, flags | O_NONBLOCK); +#endif + + /* Avoiding type-punned pointer warning */ + struct ev_io *client_watcher = &con->client.watcher; + ev_io_init(client_watcher, connection_cb, sockfd, EV_READ); + con->client.watcher.data = con; + con->state = ACCEPTED; + con->established_timestamp = ev_now(loop); + + TAILQ_INSERT_HEAD(&connections, con, entries); + + ev_io_start(loop, client_watcher); + + /* Since this is a datagram socket, we should process data received */ + connection_cb(loop, client_watcher, EV_READ); + + return 1; +} + /* * Close and free all connections */ @@ -610,9 +739,9 @@ free_resolv_cb_data(struct resolv_cb_data *cb_data) { static void initiate_server_connect(struct Connection *con, struct ev_loop *loop) { #ifdef HAVE_ACCEPT4 - int sockfd = socket(con->server.addr.ss_family, SOCK_STREAM | SOCK_NONBLOCK, 0); + int sockfd = socket(con->server.addr.ss_family, con->listener->protocol->sock_type| SOCK_NONBLOCK, 0); #else - int sockfd = socket(con->server.addr.ss_family, SOCK_STREAM, 0); + int sockfd = socket(con->server.addr.ss_family, con->listener->protocol->sock_type, 0); #endif if (sockfd < 0) { char client[INET6_ADDRSTRLEN + 8]; @@ -790,7 +919,7 @@ close_connection(struct Connection *con, struct ev_loop *loop) { * Allocate and initialize a new connection */ static struct Connection * -new_connection(struct ev_loop *loop) { +new_connection(int type, struct ev_loop *loop) { struct Connection *con = calloc(1, sizeof(struct Connection)); if (con == NULL) return NULL; @@ -807,14 +936,15 @@ new_connection(struct ev_loop *loop) { con->header_len = 0; con->query_handle = NULL; con->use_proxy_header = 0; + con->type = type; - con->client.buffer = new_buffer(4096, loop); + con->client.buffer = new_buffer(con->type, 4096, loop); if (con->client.buffer == NULL) { free_connection(con); return NULL; } - con->server.buffer = new_buffer(4096, loop); + con->server.buffer = new_buffer(con->type, 4096, loop); if (con->server.buffer == NULL) { free_connection(con); return NULL; diff --git a/src/connection.h b/src/connection.h index 013f046f..b4f332a5 100644 --- a/src/connection.h +++ b/src/connection.h @@ -58,12 +58,14 @@ struct Connection { struct ResolvQuery *query_handle; ev_tstamp established_timestamp; int use_proxy_header; + int type; TAILQ_ENTRY(Connection) entries; }; void init_connections(); -int accept_connection(struct Listener *, struct ev_loop *); +int accept_stream_connection(struct Listener *, struct ev_loop *); +int accept_dgram_connection(struct Listener *, struct ev_loop *); void free_connections(struct ev_loop *); void print_connections(); diff --git a/src/dtls.c b/src/dtls.c new file mode 100644 index 00000000..1f56968c --- /dev/null +++ b/src/dtls.c @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2020 Cisco and/or its affiliates. + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ +/* + * This is a minimal DTLS implementation intended only to parse the server name + * extension. This was created based primarily on Wireshark dissection of a + * DTLS handshake and RFC4366. + */ + +#include +#include /* malloc() */ +#include +#include /* strncpy() */ +#include +#include +#include "dtls.h" +#include "sni.h" +#include "protocol.h" +#include "logger.h" + +#define DTLS_HANDSHAKE_CONTENT_TYPE 0x16 +#define DTLS_VERSION_12_MAJOR 0xfe +#define DTLS_VERSION_12_MINOR 0xfd +#define DTLS_HEADER_LEN 13 +#define DTLS_HANDSHAKE_TYPE_CLIENT_HELLO 0x01 + +#ifndef MIN +#define MIN(X, Y) ((X) < (Y) ? (X) : (Y)) +#endif + +static int parse_dtls_header(const uint8_t*, size_t, char **); + +static const unsigned char dtls_alert[] = { + 0x21, /* DTLS Alert */ + 0xfe, 0xfd, /* DTLS version */ + 0x00, 0x02, /* Payload length */ + 0x02, 0x28, /* Fatal, handshake failure */ +}; + +const struct Protocol *const dtls_protocol = &(struct Protocol){ + .name = "dtls", + .default_port = 443, + .parse_packet = (int (*const)(const char *, size_t, char **))&parse_dtls_header, + .abort_message = dtls_alert, + .abort_message_len = sizeof(dtls_alert), + .sock_type = SOCK_DGRAM, +}; + + +/* Parse a DTLS packet for the Server Name Indication extension in the client + * hello handshake, returning the first servername found (pointer to static + * array) + * + * Returns: + * >=0 - length of the hostname and updates *hostname + * caller is responsible for freeing *hostname + * -1 - Incomplete request + * -2 - No Host header included in this request + * -3 - Invalid hostname pointer + * -4 - malloc failure + * < -4 - Invalid TLS client hello + */ +static int +parse_dtls_header(const uint8_t *data, size_t data_len, char **hostname) { + uint8_t dtls_content_type; + uint8_t dtls_version_major; + uint8_t dtls_version_minor; + size_t pos = DTLS_HEADER_LEN; + size_t len; + //const uint8_t *data = &input_data[2]; /* Skip UDP length */ + + if (hostname == NULL) + return -3; + + /* Check that our UDP payload is at least large enough for a DTLS header */ + if (data_len < DTLS_HEADER_LEN) + return -1; + + + dtls_content_type = data[0]; + if (dtls_content_type != DTLS_HANDSHAKE_CONTENT_TYPE) { + debug("Request did not begin with DTLS handshake."); + return -5; + } + + dtls_version_major = data[1]; + dtls_version_minor = data[2]; + if (dtls_version_major != DTLS_VERSION_12_MAJOR && + dtls_version_minor != DTLS_VERSION_12_MINOR) { + debug("Requested version of DTLS not supported"); + return -5; + } + + /* + * Skip epoch (2 bytes) and sequence number (6 bytes). + * We want the length of this packet. + */ + len = ((size_t)data[11] << 8) + + (size_t)data[12] + DTLS_HEADER_LEN; + data_len = MIN(data_len, len); + + /* Check we received entire DTLS record length */ + if (data_len < len) { + debug("Failed to receive entire packet: len %zu data_len %zu", len, data_len); + return -1; + } + + /* + * Handshake + */ + if (pos + 1 > data_len) { + debug("handshake"); + return -5; + } + if (data[pos] != DTLS_HANDSHAKE_TYPE_CLIENT_HELLO) { + debug("Not a client hello"); + return -5; + } + + /* + * Skip past the following records: + * + * Length 3 + * Message sequence 2 + * Fragment offset 3 + * Fragment length 3 + * Version 2 + * Random 32 + */ + pos += 46; + + /* Session ID */ + if (pos + 1 > data_len) { + debug("Session ID incorrect"); + return -5; + } + len = (size_t)data[pos]; + debug("session ID length: 0x%zx", len); + pos += 1 + len; + + /* Cookie length */ + pos += 1; + + /* Cipher Suites */ + if (pos + 2 > data_len) { + debug("cipher suites"); + return -5; + } + len = ((size_t)data[pos] << 8) + (size_t)data[pos + 1]; + pos += 2 + len; + + /* Compression Methods */ + if (pos + 1 > data_len) { + debug("compression methods"); + return -5; + } + len = (size_t)data[pos]; + pos += 1 + len; + + /* Extensions */ + if (pos + 2 > data_len) { + printf("extensions"); + return -5; + } + len = ((size_t)data[pos] << 8) + (size_t)data[pos + 1]; + pos += 2; + + if (pos + len > data_len) { + debug("length error"); + return -5; + } + return parse_extensions(data + pos, len, hostname); +} diff --git a/src/dtls.h b/src/dtls.h new file mode 100644 index 00000000..053fb190 --- /dev/null +++ b/src/dtls.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2020 Cisco and/or its affiliates. + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ +#ifndef DTLS_H +#define DTLS_H + +#include "protocol.h" + +extern const struct Protocol *const dtls_protocol; + +#endif diff --git a/src/http.c b/src/http.c index 7d87c658..876b4cb4 100644 --- a/src/http.c +++ b/src/http.c @@ -28,6 +28,7 @@ #include /* strncpy() */ #include /* strncasecmp() */ #include /* isblank(), isdigit() */ +#include /* SOCK_STREAM */ #include "http.h" #include "protocol.h" @@ -39,7 +40,7 @@ static int get_header(const char *, const char *, size_t, char **); static size_t next_header(const char **, size_t *); -static const char http_503[] = +static const unsigned char http_503[] = "HTTP/1.1 503 Service Temporarily Unavailable\r\n" "Content-Type: text/html\r\n" "Connection: close\r\n\r\n" @@ -51,6 +52,7 @@ const struct Protocol *const http_protocol = &(struct Protocol){ .parse_packet = &parse_http_header, .abort_message = http_503, .abort_message_len = sizeof(http_503) - 1, + .sock_type = SOCK_STREAM, }; /* diff --git a/src/listener.c b/src/listener.c index ce5956da..b4116871 100644 --- a/src/listener.c +++ b/src/listener.c @@ -37,6 +37,9 @@ #include #include #include +#ifdef __APPLE__ +#define __APPLE_USE_RFC_3542 +#endif #include #include #include @@ -47,6 +50,7 @@ #include "protocol.h" #include "tls.h" #include "http.h" +#include "dtls.h" static void close_listener(struct ev_loop *, struct Listener *); static void accept_cb(struct ev_loop *, struct ev_io *, int); @@ -281,10 +285,15 @@ accept_listener_table_name(struct Listener *listener, const char *table_name) { int accept_listener_protocol(struct Listener *listener, const char *protocol) { + /* TODO(dl): come up with a better protocol registration method */ if (strncasecmp(protocol, http_protocol->name, strlen(protocol)) == 0) listener->protocol = http_protocol; - else + else if (strncasecmp(protocol, dtls_protocol->name, strlen(protocol)) == 0) + listener->protocol = dtls_protocol; + else if (strncasecmp(protocol, tls_protocol->name, strlen(protocol)) == 0) listener->protocol = tls_protocol; + else + return 0; if (address_port(listener->address) == 0) address_set_port(listener->address, listener->protocol->default_port); @@ -481,7 +490,7 @@ valid_listener(const struct Listener *listener) { return 0; } - if (listener->protocol != tls_protocol && listener->protocol != http_protocol) { + if (listener->protocol == NULL) { err("Invalid protocol"); return 0; } @@ -509,9 +518,11 @@ init_listener(struct Listener *listener, const struct Table_head *tables, address_port(listener->address)); #ifdef HAVE_ACCEPT4 - int sockfd = socket(address_sa(listener->address)->sa_family, SOCK_STREAM | SOCK_NONBLOCK, 0); + int sockfd = socket(address_sa(listener->address)->sa_family, + listener->protocol->sock_type | SOCK_NONBLOCK, 0); #else - int sockfd = socket(address_sa(listener->address)->sa_family, SOCK_STREAM, 0); + int sockfd = socket(address_sa(listener->address)->sa_family, + listener->protocol->sock_type, 0); #endif if (sockfd < 0) { err("socket failed: %s", strerror(errno)); @@ -572,7 +583,8 @@ init_listener(struct Listener *listener, const struct Table_head *tables, if (result < 0 && errno == EACCES) { /* Retry using binder module */ close(sockfd); - sockfd = bind_socket(address_sa(listener->address), + sockfd = bind_socket(listener->protocol->sock_type, + address_sa(listener->address), address_sa_len(listener->address)); if (sockfd < 0) { err("binder failed to bind to %s", @@ -587,11 +599,31 @@ init_listener(struct Listener *listener, const struct Table_head *tables, return result; } - result = listen(sockfd, SOMAXCONN); - if (result < 0) { - err("listen failed: %s", strerror(errno)); - close(sockfd); - return result; + if (listener->protocol->sock_type == SOCK_STREAM) { + result = listen(sockfd, SOMAXCONN); + if (result < 0) { + err("listen failed: %s", strerror(errno)); + close(sockfd); + return result; + } + } else if (listener->protocol->sock_type == SOCK_DGRAM) { + /* For UDP arrange to receive the local socket address so we can bind + * to it and established a connected UDP socket */ + switch (address_sa(listener->address)->sa_family) { + case AF_INET: + result = setsockopt(sockfd, IPPROTO_IP, IP_PKTINFO, &on, sizeof(on)); + break; + case AF_INET6: + result = setsockopt(sockfd, IPPROTO_IPV6, IPV6_RECVPKTINFO, &on, sizeof(on)); + break; + default: + result = 0; + } + if (result < 0) { + err("setsockopt IP_PKTINFO/IPV6_RECVPKTINFO failed: %s", strerror(errno)); + close(sockfd); + return result; + } } #ifndef HAVE_ACCEPT4 diff --git a/src/protocol.h b/src/protocol.h index 454c5fc0..4c3b212f 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -32,8 +32,9 @@ struct Protocol { const char *const name; const uint16_t default_port; int (*const parse_packet)(const char*, size_t, char **); - const char *const abort_message; + const unsigned char *const abort_message; const size_t abort_message_len; + const int sock_type; }; #endif diff --git a/src/sni.c b/src/sni.c new file mode 100644 index 00000000..45031599 --- /dev/null +++ b/src/sni.c @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2020 Cisco and/or its affiliates. + * Copyright (c) 2011 and 2012, Dustin Lundquist + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ +/* + * Minimal SNI extension processing, shared by TLS and DTLS. + */ +#include +#include /* malloc() */ +#include +#include /* strncpy() */ +#include +#include +#include "sni.h" +#include "protocol.h" +#include "logger.h" + +int +parse_extensions(const uint8_t *data, size_t data_len, char **hostname) { + size_t pos = 0; + size_t len; + + /* Parse each 4 bytes for the extension header */ + while (pos + 4 <= data_len) { + /* Extension Length */ + len = ((size_t)data[pos + 2] << 8) + + (size_t)data[pos + 3]; + + /* Check if it's a server name extension */ + if (data[pos] == 0x00 && data[pos + 1] == 0x00) { + /* There can be only one extension of each type, so we break + our state and move p to beinnging of the extension here */ + if (pos + 4 + len > data_len) + return -5; + return parse_server_name_extension(data + pos + 4, len, hostname); + } + pos += 4 + len; /* Advance to the next extension header */ + } + /* Check we ended where we expected to */ + if (pos != data_len) + return -5; + + return -2; +} + +int +parse_server_name_extension(const uint8_t *data, size_t data_len, + char **hostname) { + size_t pos = 2; /* skip server name list length */ + size_t len; + + while (pos + 3 < data_len) { + len = ((size_t)data[pos + 1] << 8) + + (size_t)data[pos + 2]; + + if (pos + 3 + len > data_len) + return -5; + + switch (data[pos]) { /* name type */ + case 0x00: /* host_name */ + *hostname = malloc(len + 1); + if (*hostname == NULL) { + err("malloc() failure"); + return -4; + } + + strncpy(*hostname, (const char *)(data + pos + 3), len); + + (*hostname)[len] = '\0'; + + return len; + default: + debug("Unknown server name extension name type: %" PRIu8, + data[pos]); + } + pos += 3 + len; + } + /* Check we ended where we expected to */ + if (pos != data_len) + return -5; + + return -2; +} diff --git a/src/sni.h b/src/sni.h new file mode 100644 index 00000000..cafc4d5a --- /dev/null +++ b/src/sni.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2020 Cisco and/or its affiliates. + * Copyright (c) 2011 and 2012, Dustin Lundquist + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ +#ifndef SNI_H +#define SNI_H + +int parse_extensions(const uint8_t*, size_t, char **); +int parse_server_name_extension(const uint8_t*, size_t, char **); + +#endif diff --git a/src/tls.c b/src/tls.c index ca3218b9..437234d8 100644 --- a/src/tls.c +++ b/src/tls.c @@ -35,6 +35,7 @@ #include #include #include "tls.h" +#include "sni.h" #include "protocol.h" #include "logger.h" @@ -49,11 +50,9 @@ static int parse_tls_header(const uint8_t*, size_t, char **); -static int parse_extensions(const uint8_t*, size_t, char **); -static int parse_server_name_extension(const uint8_t*, size_t, char **); -static const char tls_alert[] = { +static const unsigned char tls_alert[] = { 0x15, /* TLS Alert */ 0x03, 0x01, /* TLS version */ 0x00, 0x02, /* Payload length */ @@ -65,7 +64,8 @@ const struct Protocol *const tls_protocol = &(struct Protocol){ .default_port = 443, .parse_packet = (int (*const)(const char *, size_t, char **))&parse_tls_header, .abort_message = tls_alert, - .abort_message_len = sizeof(tls_alert) + .abort_message_len = sizeof(tls_alert), + .sock_type = SOCK_STREAM, }; @@ -186,70 +186,3 @@ parse_tls_header(const uint8_t *data, size_t data_len, char **hostname) { return -5; return parse_extensions(data + pos, len, hostname); } - -static int -parse_extensions(const uint8_t *data, size_t data_len, char **hostname) { - size_t pos = 0; - size_t len; - - /* Parse each 4 bytes for the extension header */ - while (pos + 4 <= data_len) { - /* Extension Length */ - len = ((size_t)data[pos + 2] << 8) + - (size_t)data[pos + 3]; - - /* Check if it's a server name extension */ - if (data[pos] == 0x00 && data[pos + 1] == 0x00) { - /* There can be only one extension of each type, so we break - our state and move p to beinnging of the extension here */ - if (pos + 4 + len > data_len) - return -5; - return parse_server_name_extension(data + pos + 4, len, hostname); - } - pos += 4 + len; /* Advance to the next extension header */ - } - /* Check we ended where we expected to */ - if (pos != data_len) - return -5; - - return -2; -} - -static int -parse_server_name_extension(const uint8_t *data, size_t data_len, - char **hostname) { - size_t pos = 2; /* skip server name list length */ - size_t len; - - while (pos + 3 < data_len) { - len = ((size_t)data[pos + 1] << 8) + - (size_t)data[pos + 2]; - - if (pos + 3 + len > data_len) - return -5; - - switch (data[pos]) { /* name type */ - case 0x00: /* host_name */ - *hostname = malloc(len + 1); - if (*hostname == NULL) { - err("malloc() failure"); - return -4; - } - - strncpy(*hostname, (const char *)(data + pos + 3), len); - - (*hostname)[len] = '\0'; - - return len; - default: - debug("Unknown server name extension name type: %" PRIu8, - data[pos]); - } - pos += 3 + len; - } - /* Check we ended where we expected to */ - if (pos != data_len) - return -5; - - return -2; -} diff --git a/tests/.gitignore b/tests/.gitignore index d0de50c9..c7983f7d 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -7,6 +7,7 @@ http_test resolv_test table_test tls_test +dtls_test *.log *.trs *.pcap diff --git a/tests/Makefile.am b/tests/Makefile.am index 19a4c610..24d94f70 100644 --- a/tests/Makefile.am +++ b/tests/Makefile.am @@ -7,20 +7,17 @@ TESTS = address_test \ table_test \ http_test \ tls_test \ - binder_test + dtls_test TESTS += functional_test \ bad_request_test \ - bind_source_test \ connection_reset_test \ fallback_test \ fd_limit_test \ ipv6_v6only_test \ - proxy_header_test \ reload_test \ reuseport_test \ - slow_client_test \ - transparent_proxy_test + slow_client_test if DNS_ENABLED TESTS += config_test \ resolv_test \ @@ -29,8 +26,8 @@ endif check_PROGRAMS = http_test \ tls_test \ + dtls_test \ table_test \ - binder_test \ buffer_test \ cfg_tokenizer_test \ address_test \ @@ -42,6 +39,12 @@ http_test_SOURCES = http_test.c \ tls_test_SOURCES = tls_test.c \ ../src/tls.c \ + ../src/sni.c \ + ../src/logger.c + +dtls_test_SOURCES = dtls_test.c \ + ../src/dtls.c \ + ../src/sni.c \ ../src/logger.c binder_test_SOURCES = binder_test.c \ @@ -74,6 +77,8 @@ config_test_SOURCES = config_test.c \ ../src/resolv.c \ ../src/resolv.h \ ../src/tls.c \ + ../src/dtls.c \ + ../src/sni.c \ ../src/http.c config_test_LDADD = $(LIBEV_LIBS) $(LIBPCRE_LIBS) $(LIBUDNS_LIBS) diff --git a/tests/buffer_test.c b/tests/buffer_test.c index ced1acc1..a2ba7611 100644 --- a/tests/buffer_test.c +++ b/tests/buffer_test.c @@ -12,7 +12,7 @@ static void test1() { char output[sizeof(input)]; int len, i; - buffer = new_buffer(256, EV_DEFAULT); + buffer = new_buffer(SOCK_STREAM, 256, EV_DEFAULT); assert(buffer != NULL); len = buffer_push(buffer, input, sizeof(input)); @@ -51,7 +51,7 @@ static void test2() { char output[sizeof(input)]; int len, i = 0; - buffer = new_buffer(256, EV_DEFAULT); + buffer = new_buffer(SOCK_STREAM, 256, EV_DEFAULT); assert(buffer != NULL); while (i < 236) { @@ -100,7 +100,7 @@ static void test3() { char output[sizeof(input)]; int len, i; - buffer = new_buffer(256, EV_DEFAULT); + buffer = new_buffer(SOCK_STREAM, 256, EV_DEFAULT); assert(buffer != NULL); len = buffer_push(buffer, input, sizeof(input)); @@ -127,7 +127,7 @@ static void test4() { struct Buffer *buffer; int read_fd, write_fd; - buffer = new_buffer(4096, EV_DEFAULT); + buffer = new_buffer(SOCK_STREAM, 4096, EV_DEFAULT); read_fd = open("/dev/zero", O_RDONLY); if (read_fd < 0) { @@ -155,7 +155,124 @@ static void test_buffer_coalesce() { char output[sizeof(input)]; int len; - buffer = new_buffer(4096, EV_DEFAULT); + buffer = new_buffer(SOCK_STREAM, 4096, EV_DEFAULT); + len = buffer_push(buffer, input, sizeof(input)); + assert(len == sizeof(input)); + + len = buffer_pop(buffer, output, sizeof(output)); + assert(len == sizeof(output)); + assert(buffer_len(buffer) == 0); + assert(buffer->head != 0); + + len = buffer_coalesce(buffer, NULL); + assert(len == 0); +} + +static void test5_udp1() { + struct Buffer *buffer; + char input[] = "This is a UDP test."; + char output[sizeof(input)]; + int len, i; + + buffer = new_buffer(SOCK_DGRAM, 256, EV_DEFAULT); + assert(buffer != NULL); + + len = buffer_push(buffer, input, sizeof(input)); + assert(len == sizeof(input)); + + + len = buffer_peek(buffer, output, sizeof(output)); + assert(len == sizeof(input)); + + for (i = 0; i < len; i++) { + assert(input[i] == output[i]); + } + + /* second peek to ensure the first didn't permute the state of the buffer */ + len = buffer_peek(buffer, output, sizeof(output)); + assert(len == sizeof(input)); + + for (i = 0; i < len; i++) + assert(input[i] == output[i]); + + /* test pop */ + len = buffer_pop(buffer, output, sizeof(output)); + assert(len == sizeof(input)); + + for (i = 0; i < len; i++) + assert(input[i] == output[i]); + + len = buffer_pop(buffer, output, sizeof(output)); + assert(len == 0); + + free_buffer(buffer); +} + +static void test6_udp2() { + struct Buffer *buffer; + char input[] = "Testing wrap around behaviour."; + char output[sizeof(input)]; + int len, i = 0; + + buffer = new_buffer(SOCK_DGRAM, 256, EV_DEFAULT); + assert(buffer != NULL); + + while (i < 204) { + len = buffer_push(buffer, input, sizeof(input)); + assert(len == sizeof(input)); + + i += len; + } + + while (len) { + len = buffer_pop(buffer, output, sizeof(output)); + } + + len = buffer_push(buffer, input, sizeof(input)); + assert(len == sizeof(input)); + + + len = buffer_peek(buffer, output, sizeof(output)); + assert(len == sizeof(input)); + + for (i = 0; i < len; i++) { + assert(input[i] == output[i]); + } + + len = buffer_pop(buffer, output, sizeof(output)); + assert(len == sizeof(input)); + + for (i = 0; i < len; i++) + assert(input[i] == output[i]); + + len = buffer_push(buffer, input, sizeof(input)); + assert(len == sizeof(input)); + + + len = buffer_peek(buffer, output, sizeof(output)); + assert(len == sizeof(input)); + + for (i = 0; i < len; i++) + assert(input[i] == output[i]); + + free_buffer(buffer); +} + +static void test_buffer_coalesce_udp() { + struct Buffer *buffer; + char input[] = "Test buffer resizing."; + char output[sizeof(input)]; + int len; + + buffer = new_buffer(SOCK_DGRAM, 64, EV_DEFAULT); + len = buffer_push(buffer, input, sizeof(input)); + assert(len == sizeof(input)); + + len = buffer_pop(buffer, output, sizeof(output)); + assert(len == sizeof(output)); + assert(buffer_len(buffer) == 0); + assert(buffer->head != 0); + len = buffer_push(buffer, input, sizeof(input)); assert(len == sizeof(input)); @@ -178,4 +295,10 @@ int main() { test4(); test_buffer_coalesce(); + + test5_udp1(); + + test6_udp2(); + + test_buffer_coalesce_udp(); } diff --git a/tests/dtls_test.c b/tests/dtls_test.c new file mode 100644 index 00000000..8d90bbdb --- /dev/null +++ b/tests/dtls_test.c @@ -0,0 +1,263 @@ +/* + * Copyright (c) 2020 Cisco and/or its affiliates. + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ +#include +#include +#include +#include +#include "dtls.h" + +struct test_packet { + const char *packet; + size_t len; + const char *hostname; +}; + +const char good_hostname_1[] = "nginx1.umbrella.com"; +const unsigned char good_data_1[] = { + // UDP payload length + //0x00, 0xdd, + // DTLS Record Layer + 0x16, // Content Type: Handshake + 0xfe, 0xfd, // Version: DTLS 1.2 + 0x00, 0x00, // Epoch + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Sequence Number + 0x00, 0xc8, // Length + // Handshake + 0x01, // Handshake Type: Client Hello + 0x00, 0x00, 0xbc, // Length + 0x00, 0x00, // Message sequence + 0x00, 0x00, 0x00, // Fragment offset + 0x00, 0x00, 0xbc, // Fragment length + 0xfe, 0xfd, // Version: DTLS 1.2 + 0x4a, 0x72, 0xfb, 0x78, // Unix Time + 0xbc, 0x96, // Random + 0x1e, 0xf3, 0x78, 0x01, 0xa3, 0xa8, + 0xcf, 0x84, 0x14, 0xe5, 0xec, 0x06, + 0xee, 0xdb, 0x09, 0xde, 0x27, 0x62, + 0x3c, 0xd2, 0xb8, 0x00, 0x5f, 0x14, + 0x8c, 0xfc, + 0x20, // Session ID Length + 0x51, 0x1b, 0xb8, 0xf9, 0x10, 0x79, // Session ID + 0x23, 0x2c, 0xbb, 0x88, 0x92, 0x7c, + 0xb8, 0x51, 0x70, 0x8e, 0x50, 0x02, + 0x25, 0xd9, 0x85, 0xf3, 0x49, 0xe3, + 0xdb, 0x63, 0xf7, 0x4a, 0x06, 0xcb, + 0x6a, 0x0e, + 0x00, // Cookie Length + 0x00, 0x02, // Cipher suite length + 0xc0, 0x2c, // Cipher suite + 0x01, // Compression method length + 0x00, // Compression method + 0x00, 0x70, // Extension length + 0x00, 0x00, // Sever Name + 0x00, 0x18, // Length + 0x00, 0x16, // Server Name List Length + 0x00, // Server name type: hostname + 0x00, 0x13, // Server name length + 0x6e, 0x67, 0x69, 0x6e, // Server name: nginx1.umbrella.com + 0x78, 0x31, 0x2e, 0x75, + 0x6d, 0x62, 0x72, 0x65, + 0x6c, 0x6c, 0x61, 0x2e, + 0x63, 0x6f, 0x6d, + 0x00, 0x0b, // Type: ec_point_formats + 0x00, 0x04, // Length + 0x03, // EC point format length + 0x00, // Uncompressed + 0x01, // ansiX962_compressed_prime + 0x02, // ansiX962_compressed_char2 + 0x00, 0x0a, // Type: Supported groups + 0x00, 0x0c, // Length + 0x00, 0x0a, // Supported groups list length + 0x00, 0x1d, // x25519 + 0x00, 0x17, // secp256r1 + 0x00, 0x1e, // x448 + 0x00, 0x19, // secp521r1 + 0x00, 0x18, // secp384r1 + 0x00, 0x16, // Type: encrypt_then_mac + 0x00, 0x00, // Length + 0x00, 0x17, // Type: extended_master_secret + 0x00, 0x00, // Length + 0x00, 0x0d, // Type: Signature algorithms + 0x00, 0x30, // Length + 0x00, 0x2e, // Hash Algorithms length + // ecdsa_secp256r1_sha256 + 0x04, // SHA256 + 0x03, // EDCSA + // ecdsa_secp384r1_sha384 + 0x05, // SHA384 + 0x03, // EDCSA + // ecdsa_secp521r1_sha512 + 0x06, // SHA512 + 0x03, // EDCSA + // ed25519 + 0x08, // unknown + 0x07, // unknown + // ed448 + 0x08, // unknown + 0x08, // unknown + // rsa_pss_pss_sha256 + 0x08, // unknown + 0x09, // unknown + // rsa_pss_pss_sha384 + 0x08, // unknown + 0x0a, // unknown + // rsa_pss_pss_sha512 + 0x08, // unknown + 0x0b, // unknown + // rsa_pss_pss_sha256 + 0x08, // unknown + 0x04, // unknown + // rsa_pss_rsae_sha384 + 0x08, // unknown + 0x05, // unknown + // rsa_pss_rsae_sha512 + 0x08, // unknown + 0x06, // unknown + // rsa_pkcs1_sha256 + 0x04, // SHA256 + 0x01, // RSA + // rsa_pkcs1_sha384 + 0x05, // SHA384 + 0x01, // RSA + // rsa_pkcs1_sha512 + 0x06, // SHA512 + 0x01, // RSA + // SHA224 EDCSA + 0x03, // SHA224 + 0x03, // EDCSA + // edcsa_sha1 + 0x02, // SHA1 + 0x03, // EDCSA + // SHA224 RSA + 0x03, // SHA224 + 0x01, // RSA + // rsa_pkcs1_sha1 + 0x02, // SHA1 + 0x01, // RSA + // SHA224 DSA + 0x03, // SHA224 + 0x02, // DSA + // SHA1 DSA + 0x02, // SHA1 + 0x02, // DSA + // SHA256 DSA + 0x04, // SHA256 + 0x02, // DSA + // SHA384 DSA + 0x05, // SHA384 + 0x02, // DSA + // SHA512 DSA + 0x06, // SHA512 + 0x02 // DSA +}; + +const unsigned char bad_data_1[] = { + 0x16, 0x03, 0x01, 0x00, 0x68, 0x01, 0x00, 0x00, + 0x64, 0x03, 0x01, 0x4e, 0x4e, 0xbe, 0xc2, 0xa1, + 0x21, 0xad, 0xbc, 0x28, 0x33, 0xca, 0xa1, 0xd6, + 0x6e, 0x57, 0xb9, 0x1f, 0x8c, 0x19, 0x0e, 0x44, + 0x16, 0x9e, 0x7d, 0x20, 0x35, 0x4b, 0x65, 0xb2, + 0xc0, 0xd5, 0xa8, 0x00, 0x00, 0x28, 0x00, 0x39, + 0x00, 0x38, 0x00, 0x35, 0x00, 0x16, 0x00, 0x13, + 0x00, 0x0a, 0x00, 0x33, 0x00, 0x32, 0x00, 0x2f, + 0x00, 0x05, 0x00, 0x04, 0x00, 0x15, 0x00, 0x12, + 0x00, 0x09, 0x00, 0x14, 0x00, 0x11, 0x00, 0x08, + 0x00, 0x06, 0x00, 0x03, 0x00, 0xff, 0x02, 0x01, + 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f, 0x64, 0x61, + 0x6c, 0x68, 0x6f, 0x73, 0x74 +}; + +const unsigned char bad_data_2[] = { + 0x16, 0x03, 0x01, 0x00, 0x68, 0x01, 0x00, 0x00, + 0x64, 0x03, 0x01, 0x4e, 0x4e, 0xbe, 0xc2, 0xa1, + 0x21, 0xad, 0xbc, 0x28, 0x33, 0xca, 0xa1, 0xd6, + 0x6e, 0x57, 0xb9, 0x1f, 0x8c, 0x19, 0x0e, 0x44, + 0x16, 0x9e, 0x7d, 0x20, 0x35, 0x4b, 0x65, 0xb2, + 0xc0, 0xd5, 0xa8, 0x00, 0x00, 0x28, 0x00, 0x39, + 0x00, 0x38, 0x00, 0x35, 0x00, 0x16, 0x00, 0x13, + 0x00, 0x0a, 0x00, 0x33, 0x00, 0x32, 0x00, 0x2f, + 0x00, 0x05, 0x00, 0x04, 0x00, 0x15, 0x00, 0x12, + 0x00, 0x09, 0x00, 0x14, 0x00, 0x11, 0x00, 0x08, + 0x00, 0x06, 0x00, 0x03, 0x00, 0xff, 0x02, 0x01, + 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x0e, 0x00 +}; + +const unsigned char bad_data_3[] = { + 0x16, 0x03, 0x01, 0x00 +}; + +static struct test_packet good[] = { + { (char *)good_data_1, sizeof(good_data_1), good_hostname_1 }, +}; + +static struct test_packet bad[] = { + { (char *)bad_data_1, sizeof(bad_data_1), NULL }, + { (char *)bad_data_2, sizeof(bad_data_2), NULL }, + { (char *)bad_data_3, sizeof(bad_data_3), NULL } +}; + +int main() { + unsigned int i; + int result; + char *hostname; + + for (i = 0; i < sizeof(good) / sizeof(struct test_packet); i++) { + hostname = NULL; + + printf("Testing packet of length %zu\n", good[i].len); + result = dtls_protocol->parse_packet(good[i].packet, good[i].len, &hostname); + + assert(result == 19); + + assert(NULL != hostname); + + assert(0 == strcmp(good[i].hostname, hostname)); + + free(hostname); + } + + result = dtls_protocol->parse_packet(good[0].packet, good[0].len, NULL); + assert(result == -3); + + for (i = 0; i < sizeof(bad) / sizeof(struct test_packet); i++) { + hostname = NULL; + + result = dtls_protocol->parse_packet(bad[i].packet, bad[i].len, &hostname); + + // parse failure or not "localhost" + if (bad[i].hostname != NULL) + assert(result < 0 || + hostname == NULL || + strcmp(bad[i].hostname, hostname) != 0); + + free(hostname); + } + + return 0; +} +