diff --git a/sulley/fuzz_logger.py b/sulley/fuzz_logger.py new file mode 100644 index 0000000..375ba03 --- /dev/null +++ b/sulley/fuzz_logger.py @@ -0,0 +1,79 @@ +import ifuzz_logger +import os +import errno + + +class FuzzLogger(ifuzz_logger.IFuzzLogger): + """ + IFuzzLogger that saves sent and received data to files within a directory. + + File format is: -(rx|tx)-.txt + """ + + def __init__(self, path): + """ + :param path: Directory in which to save fuzz data. + """ + self._path = path + self._current_id = '' + self._rx_count = 0 + self._tx_count = 0 + + # mkdir -p self._path + try: + os.makedirs(self._path) + except OSError as exc: + if exc.errno == errno.EEXIST and os.path.isdir(path): + pass + else: + raise + + def open_test_case(self, test_case_id): + """ + Open a test case - i.e., a fuzzing mutation. + + :param test_case_id: Test case name/number. Should be unique. + + :return: None + """ + self._current_id = str(test_case_id) + self._rx_count = 0 + self._tx_count = 0 + + def log_send(self, data): + """ + Records data as about to be sent to the target. + + :param data: Transmitted data + :type data: buffer + + :return: None + :rtype: None + """ + self._tx_count += 1 + + filename = "{0}-tx-{1}.txt".format(self._current_id, self._tx_count) + full_name = os.path.join(self._path, filename) + + # Write data in binary mode to avoid newline conversion + with open(full_name, "wb") as file_handle: + file_handle.write(data) + + def log_recv(self, data): + """ + Records data as having been received from the target. + + :param data: Received data. + :type data: buffer + + :return: None + :rtype: None + """ + self._rx_count += 1 + + filename = "{0}-rx-{1}.txt".format(self._current_id, self._tx_count) + full_name = os.path.join(self._path, filename) + + # Write data in binary mode to avoid newline conversion + with open(full_name, "wb") as file_handle: + file_handle.write(data) diff --git a/sulley/ifuzz_logger.py b/sulley/ifuzz_logger.py new file mode 100644 index 0000000..309c44a --- /dev/null +++ b/sulley/ifuzz_logger.py @@ -0,0 +1,45 @@ +import abc + + +class IFuzzLogger(object): + """ + Abstract class for logging fuzz data. Allows for logging approaches. + """ + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def open_test_case(self, test_case_id): + """ + Open a test case - i.e., a fuzzing mutation. + + :param test_case_id: Test case name/number. Should be unique. + + :return: None + """ + raise NotImplementedError + + @abc.abstractmethod + def log_send(self, data): + """ + Records data as about to be sent to the target. + + :param data: Transmitted data + :type data: buffer + + :return: None + :rtype: None + """ + raise NotImplementedError + + @abc.abstractmethod + def log_recv(self, data): + """ + Records data as having been received from the target. + + :param data: Received data. + :type data: buffer + + :return: None + :rtype: None + """ + raise NotImplementedError diff --git a/sulley/iserial_like.py b/sulley/iserial_like.py new file mode 100644 index 0000000..cf236fa --- /dev/null +++ b/sulley/iserial_like.py @@ -0,0 +1,52 @@ +import abc + + +class ISerialLike(object): + """ + A serial-like interface, based on the pySerial module, + the notable difference being that open() must always be called after the object is first created. + + Facilitates dependency injection in modules that use pySerial. + """ + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def close(self): + """ + Close connection to the target. + + :return: None + """ + raise NotImplementedError + + @abc.abstractmethod + def open(self): + """ + Opens connection to the target. Make sure to call close! + + :return: None + """ + raise NotImplementedError + + @abc.abstractmethod + def recv(self, max_bytes): + """ + Receive up to max_bytes data from the target. + + :param max_bytes: Maximum number of bytes to receive. + :type max_bytes: int + + :return: Received data. + """ + raise NotImplementedError + + @abc.abstractmethod + def send(self, data): + """ + Send data to the target. Only valid after calling open! + + :param data: Data to send. + + :return: Number of bytes actually sent. + """ + raise NotImplementedError diff --git a/sulley/itarget_connection.py b/sulley/itarget_connection.py new file mode 100644 index 0000000..e10b0fd --- /dev/null +++ b/sulley/itarget_connection.py @@ -0,0 +1,51 @@ +import abc + + +class ITargetConnection(object): + """ + Interface for connections to fuzzing targets. + Target connections may be opened and closed multiple times. You must open before using send/recv and close + afterwards. + """ + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def close(self): + """ + Close connection. + + :return: None + """ + raise NotImplementedError + + @abc.abstractmethod + def open(self): + """ + Opens connection to the target. Make sure to call close! + + :return: None + """ + raise NotImplementedError + + @abc.abstractmethod + def recv(self, max_bytes): + """ + Receive up to max_bytes data. + + :param max_bytes: Maximum number of bytes to receive. + :type max_bytes: int + + :return: Received data. bytes('') if no data is received. + """ + raise NotImplementedError + + @abc.abstractmethod + def send(self, data): + """ + Send data to the target. + + :param data: Data to send. + + :return: None + """ + raise NotImplementedError diff --git a/sulley/serial_connection.py b/sulley/serial_connection.py new file mode 100644 index 0000000..908acf0 --- /dev/null +++ b/sulley/serial_connection.py @@ -0,0 +1,140 @@ +import time +import itarget_connection +import iserial_like + + +class SerialConnection(itarget_connection.ITargetConnection): + """ + ITargetConnection implementation for generic serial ports. + Designed to utilize SerialConnectionLowLevel (see __init__). + + Since serial ports provide no default functionality for separating messages/packets, this class provides + several means: + - timeout: Return received bytes after timeout seconds. + - msg_separator_time: + Return received bytes after the wire is silent for a given time. + This is useful, e.g., for terminal protocols without a machine-readable delimiter. + A response may take a long time to send its information, and you know the message is done + when data stops coming. + - content_check: + A user-defined function takes the data received so far and checks for a packet. + The function should return 0 if the packet isn't finished yet, or n if a valid message of n + bytes has been received. Remaining bytes are stored for next call to recv(). + + Example: + def content_check_newline(data): + if data.find('\n') >= 0: + return data.find('\n') + else: + return 0 + If none of these methods are used, your connection may hang forever. + """ + + def __init__(self, connection, timeout=None, message_separator_time=None, content_checker=None): + """ + @type connection: iserial_like.ISerialLike + @param connection: Low level connection, e.g., SerialConnectionLowLevel. + @type timeout: float + @param timeout: For recv(). After timeout seconds from receive start, + recv() will return all received data, if any. + @type message_separator_time: float + @param message_separator_time: (Optional, def=None) + After message_separator_time seconds _without receiving any more data_, + recv() will return. + @type content_checker: function(str) -> int + @param content_checker: (Optional, def=None) User-defined function. + recv() will pass all bytes received so far to this method. + If the method returns n > 0, recv() will return n bytes. + If it returns 0, recv() will keep on reading. + """ + self._connection = connection + self._logger = None + self.timeout = timeout + self.message_separator_time = message_separator_time + self.content_checker = content_checker + + self._leftover_bytes = b'' + + def close(self): + """ + Close connection to the target. + + :return: None + """ + self._connection.close() + + def open(self): + """ + Opens connection to the target. Make sure to call close! + + :return: None + """ + self._connection.open() + + def recv(self, max_bytes): + """ + Receive up to max_bytes data from the target. + + :param max_bytes: Maximum number of bytes to receive. + :type max_bytes: int + + :return: Received data. + """ + + self._connection.timeout = min(.001, self.message_separator_time, self.timeout) + + start_time = last_byte_time = time.time() + + data = self._leftover_bytes + self._leftover_bytes = b'' + + while len(data) < max_bytes: + # Update timer for message_separator_time + if len(data) > 0: + last_byte_time = time.time() + + # Try recv again + fragment = self._connection.recv(max_bytes=max_bytes-len(data)) + data += fragment + + # User-supplied content_checker function + if self.content_checker is not None: + num_valid_bytes = self.content_checker(data) + if num_valid_bytes > 0: + self._leftover_bytes = data[num_valid_bytes:] + return data[0:num_valid_bytes] + + # Check timeout and message_separator_time + cur_time = time.time() + if self.timeout is not None and cur_time - start_time >= self.timeout: + return data + if self.message_separator_time is not None and cur_time - last_byte_time >= self.message_separator_time: + return data + + return data + + def send(self, data): + """ + Send data to the target. Only valid after calling open! + + :param data: Data to send. + + :return: None + """ + bytes_sent = 0 + while bytes_sent < len(data): + bytes_sent_this_round = self._connection.send(data[bytes_sent:]) + if bytes_sent_this_round is not None: + bytes_sent += bytes_sent_this_round + return bytes_sent + + def set_logger(self, logger): + """ + Set this object's (and it's aggregated classes') logger. + + :param logger: Logger to use. + :type logger: logging.Logger + + :return: None + """ + self._logger = logger diff --git a/sulley/serial_connection_low_level.py b/sulley/serial_connection_low_level.py new file mode 100644 index 0000000..5fa1256 --- /dev/null +++ b/sulley/serial_connection_low_level.py @@ -0,0 +1,62 @@ +import iserial_like +import serial + + +class SerialConnectionLowLevel(iserial_like.ISerialLike): + """ + A basic wrapper for a serial object. + Separated from SerialConnection to allow for effective unit testing. + Implements serial_like.ISerialLike. + """ + + def __init__(self, port, baudrate, timeout=None): + """ + @type port: int | str + @param port: Serial port name or number. + @type baudrate: int + @param baudrate: Baud rate for port. + @type timeout: float + @param timeout: Serial port timeout. See pySerial docs. May be updated after creation. + """ + self._device = None + self.port = port + self.baudrate = baudrate + self.timeout = timeout + + def close(self): + """ + Close connection to the target. + + :return: None + """ + self._device.close() + + def open(self): + """ + Opens connection to the target. Make sure to call close! + + :return: None + """ + self._device = serial.Serial(port=self.port, baudrate=self.baudrate) + + def recv(self, max_bytes): + """ + Receive up to max_bytes data from the target. + + :param max_bytes: Maximum number of bytes to receive. + :type max_bytes: int + + :return: Received data. + """ + self._device.timeout = self.timeout + return self._device.read(size=max_bytes) + + def send(self, data): + """ + Send data to the target. Only valid after calling open! + + :param data: Data to send. + + :return: Number of bytes actually sent. + """ + return self._device.write(data) diff --git a/sulley/serial_target.py b/sulley/serial_target.py new file mode 100644 index 0000000..d73c773 --- /dev/null +++ b/sulley/serial_target.py @@ -0,0 +1,55 @@ +import sessions +import serial_connection +import serial_connection_low_level + + +class SerialTarget(sessions.Target): + """ + Target class that uses a serial_connection.SerialConnection. Serial messages are assumed to be time-separated, + terminated by a separator string/regex, or both. + Encapsulates connection logic for the target. Inherits pedrpc connection logic from sessions.Target. + + Contains a logger which is configured by Session.add_target(). + """ + + def __init__(self, port=0, baudrate=9600, timeout=5, message_separator_time=0.300, content_checker=None): + """ + See serial_connection.SerialConnection for details on timeout, message_separator_time, and content_checker. + + @type port: int | str + @param port: Serial port name or number. + + @type baudrate: int + @param baudrate: Baud rate for port. + + @type timeout: float + @param timeout: For recv(). After timeout seconds from receive start, + recv() will return all received data, if any. + + @type message_separator_time: float + @param message_separator_time: (Optional, def=None) + After message_separator_time seconds _without receiving any more data_, + recv() will return. + + @type content_checker: function(str) -> int + @param content_checker: (Optional, def=None) User-defined function. + recv() will pass all bytes received so far to this method. + If the method returns n > 0, recv() will return n bytes. + If it returns 0, recv() will keep on reading. + """ + super(SerialTarget, self).__init__(host="", port=1) + + self._target_connection = serial_connection.SerialConnection( + connection=serial_connection_low_level.SerialConnectionLowLevel(port=port, baudrate=baudrate), + timeout=timeout, + message_separator_time=message_separator_time, + content_checker=content_checker + ) + + # set these manually once target is instantiated. + self.netmon = None + self.procmon = None + self.vmcontrol = None + self.netmon_options = {} + self.procmon_options = {} + self.vmcontrol_options = {} diff --git a/sulley/sessions.py b/sulley/sessions.py index 9ba31c8..4177442 100644 --- a/sulley/sessions.py +++ b/sulley/sessions.py @@ -2,18 +2,16 @@ import zlib import time import socket -import ssl import signal import cPickle import threading -import httplib import logging import blocks import pgraph import sex import primitives - -from helpers import get_max_udp_size +import socket_connection +import ifuzz_logger from tornado.wsgi import WSGIContainer from tornado.httpserver import HTTPServer @@ -24,18 +22,29 @@ class Target(object): """ Target descriptor container. + Encapsulates connection logic for the target, as well as pedrpc connection logic. + + Contains a logger which is configured by Session.add_target(). """ - def __init__(self, host, port): + def __init__(self, host, port, proto="tcp", bind=None, timeout=5.0): """ @type host: str @param host: Hostname or IP address of target system @type port: int @param port: Port of target service + @type proto: str + @kwarg proto: (Optional, def="tcp") Communication protocol ("tcp", "udp", "ssl") + @type bind: tuple (host, port) + @kwarg bind: (Optional, def=random) Socket bind address and port + @type timeout: float + @kwarg timeout: (Optional, def=5.0) Seconds to wait for a send/recv prior to timing out """ + self._logger = None + self._fuzz_data_logger = None - self.host = host - self.port = port + self._target_connection = socket_connection.SocketConnection( + host=host, port=port, proto=proto, bind=bind, timeout=timeout) # set these manually once target is instantiated. self.netmon = None @@ -45,6 +54,22 @@ def __init__(self, host, port): self.procmon_options = {} self.vmcontrol_options = {} + def close(self): + """ + Close connection to the target. + + :return: None + """ + self._target_connection.close() + + def open(self): + """ + Opens connection to the target. Make sure to call close! + + :return: None + """ + self._target_connection.open() + def pedrpc_connect(self): """ Pass specified target parameters to the PED-RPC server. @@ -79,6 +104,57 @@ def pedrpc_connect(self): for key in self.netmon_options.keys(): eval('self.netmon.set_%s(self.netmon_options["%s"])' % (key, key)) + def recv(self, max_bytes): + """ + Receive up to max_bytes data from the target. + + :param max_bytes: Maximum number of bytes to receive. + :type max_bytes: int + + :return: Received data. + """ + data = self._target_connection.recv(max_bytes=max_bytes) + + if self._fuzz_data_logger is not None: + self._fuzz_data_logger.log_recv(data) + + return data + + def send(self, data): + """ + Send data to the target. Only valid after calling open! + + :param data: Data to send. + + :return: None + """ + if self._fuzz_data_logger is not None: + self._fuzz_data_logger.log_send(data) + self._target_connection.send(data=data) + + def set_logger(self, logger): + """ + Set this object's (and it's aggregated classes') logger. + + :param logger: Logger to use. + :type logger: logging.Logger + + :return: None + """ + self._logger = logger + self._target_connection.set_logger(logger=logger) + + def set_fuzz_data_logger(self, fuzz_data_logger): + """ + Set this object's fuzz data logger -- for sent and received fuzz data. + + :param fuzz_data_logger: New logger. + :type fuzz_data_logger: ifuzz_logger.IFuzzLogger + + :return: None + """ + self._fuzz_data_logger = fuzz_data_logger + class Connection(pgraph.Edge): def __init__(self, src, dst, callback=None): @@ -109,8 +185,8 @@ def callback(session, node, edge, sock) class Session(pgraph.Graph): def __init__(self, session_filename=None, skip=0, sleep_time=1.0, log_level=logging.INFO, logfile=None, - logfile_level=logging.DEBUG, proto="tcp", bind=None, restart_interval=0, timeout=5.0, web_port=26000, - crash_threshold=3, restart_sleep_time=300): + logfile_level=logging.DEBUG, restart_interval=0, web_port=26000, crash_threshold=3, + restart_sleep_time=300, fuzz_data_logger=None): """ Extends pgraph.graph and provides a container for architecting protocol dialogs. @@ -126,25 +202,20 @@ def __init__(self, session_filename=None, skip=0, sleep_time=1.0, log_level=logg @kwarg logfile: (Optional, def=None) Name of log file @type logfile_level: int @kwarg logfile_level: (Optional, def=logger.INFO) Set the log level for the logfile - @type proto: str - @kwarg proto: (Optional, def="tcp") Communication protocol ("tcp", "udp", "ssl") - @type bind: tuple (host, port) - @kwarg bind: (Optional, def=random) Socket bind address and port - @type timeout: float - @kwarg timeout: (Optional, def=5.0) Seconds to wait for a send/recv prior to timing out @type restart_interval: int @kwarg restart_interval (Optional, def=0) Restart the target after n test cases, disable by setting to 0 @type crash_threshold: int @kwarg crash_threshold (Optional, def=3) Maximum number of crashes allowed before a node is exhaust @type restart_sleep_time: int - @kwarg restart_sleep_time: Optional, def=300) Time in seconds to sleep when target can't be restarted - @type web_port: int + @kwarg restart_sleep_time: (Optional, def=300) Time in seconds to sleep when target can't be restarted + @type web_port: int @kwarg web_port: (Optional, def=26000) Port for monitoring fuzzing campaign via a web browser + @type fuzz_data_logger: ifuzz_logger.IFuzzLogger + @kwarg fuzz_data_logger: (Optional, def=None) For saving data sent to and from the target. """ super(Session, self).__init__() - self.max_udp = get_max_udp_size() try: import signal @@ -156,14 +227,11 @@ def __init__(self, session_filename=None, skip=0, sleep_time=1.0, log_level=logg self.session_filename = session_filename self.skip = skip self.sleep_time = sleep_time - self.proto = proto.lower() - self.bind = bind - self.ssl = False self.restart_interval = restart_interval - self.timeout = timeout self.web_port = web_port self.crash_threshold = crash_threshold self.restart_sleep_time = restart_sleep_time + self._fuzz_data_logger = fuzz_data_logger # Initialize logger self.logger = logging.getLogger("Sulley_logger") @@ -192,19 +260,6 @@ def __init__(self, session_filename=None, skip=0, sleep_time=1.0, log_level=logg self.is_paused = False self.crashing_primitives = {} - if self.proto == "tcp": - self.proto = socket.SOCK_STREAM - - elif self.proto == "ssl": - self.proto = socket.SOCK_STREAM - self.ssl = True - - elif self.proto == "udp": - self.proto = socket.SOCK_DGRAM - - else: - raise sex.SullyRuntimeError("INVALID PROTOCOL SPECIFIED: %s" % self.proto) - # import settings if they exist. self.import_file() @@ -214,6 +269,7 @@ def __init__(self, session_filename=None, skip=0, sleep_time=1.0, log_level=logg self.root.name = "__ROOT_NODE__" self.root.label = self.root.name self.last_recv = None + self.last_send = None self.add_node(self.root) @@ -238,12 +294,14 @@ def add_target(self, target): """ Add a target to the session. Multiple targets can be added for parallel fuzzing. - @type target: session.target + @type target: Target @param target: Target to add to session """ # pass specified target parameters to the PED-RPC server. target.pedrpc_connect() + target.set_logger(logger=self.logger) + target.set_fuzz_data_logger(fuzz_data_logger=self._fuzz_data_logger) # add target to internal list. self.targets.append(target) @@ -325,9 +383,7 @@ def export_file(self): "skip": self.total_mutant_index, "sleep_time": self.sleep_time, "restart_sleep_time": self.restart_sleep_time, - "proto": self.proto, "restart_interval": self.restart_interval, - "timeout": self.timeout, "web_port": self.web_port, "crash_threshold": self.crash_threshold, "total_num_mutations": self.total_num_mutations, @@ -428,6 +484,8 @@ def error_handler(error, msg, error_target, error_sock=None): # if we don't need to skip the current test case. if self.total_mutant_index > self.skip: self.logger.info("fuzzing %d of %d" % (self.fuzz_node.mutant_index, num_mutations)) + if self._fuzz_data_logger is not None: + self._fuzz_data_logger.open_test_case(self.total_mutant_index) # attempt to complete a fuzz transmission. keep trying until we are successful, whenever a failure # occurs, restart the target. @@ -448,58 +506,32 @@ def error_handler(error, msg, error_target, error_sock=None): continue try: - # establish a connection to the target. - sock = socket.socket(socket.AF_INET, self.proto) - except Exception, e: - error_handler(e, "failed creating socket", target) - continue - - if self.bind: - try: - sock.bind(self.bind) - except Exception, e: - error_handler(e, "failed binding on socket", target, sock) - continue - - try: - sock.settimeout(self.timeout) - # Connect is needed only for TCP stream - if self.proto == socket.SOCK_STREAM: - sock.connect((target.host, target.port)) - except Exception, e: - error_handler(e, "failed connecting on socket", target, sock) + target.open() + except socket.error, e: + error_handler(e, "socket connection failed", target, target) continue - # if SSL is requested, then enable it. - if self.ssl: - try: - ssl_sock = ssl.wrap_socket(sock) - sock = httplib.FakeSocket(sock, ssl_sock) - except Exception, e: - error_handler(e, "failed ssl setup", target, sock) - continue - # if the user registered a pre-send function, pass it the sock and let it do the deed. try: - self.pre_send(sock) + self.pre_send(target) except Exception, e: - error_handler(e, "pre_send() failed", target, sock) + error_handler(e, "pre_send() failed", target, target) continue # send out valid requests for each node in the current path up to the node we are fuzzing. try: for e in path[:-1]: node = self.nodes[e.dst] - self.transmit(sock, node, e) + self.transmit(target, node, e) except Exception, e: - error_handler(e, "failed transmitting a node up the path", target, sock) - continue + error_handler(e, "failed transmitting a node up the path", target, target) + raise # now send the current node we are fuzzing. try: - self.transmit(sock, self.fuzz_node, edge) + self.transmit(target, self.fuzz_node, edge) except Exception, e: - error_handler(e, "failed transmitting fuzz node", target, sock) + error_handler(e, "failed transmitting fuzz node", target, target) continue # if we reach this point the send was successful for break out of the while(1). @@ -509,12 +541,12 @@ def error_handler(error, msg, error_target, error_sock=None): # We do this outside the try/except loop because if our fuzz causes a crash then the post_send() # will likely fail and we don't want to sit in an endless loop. try: - self.post_send(sock) + self.post_send(target) except Exception, e: - error_handler(e, "post_send() failed", target, sock) + error_handler(e, "post_send() failed", target, target) # done with the socket. - sock.close() + target.close() # delay in between test cases. self.logger.info("sleeping for %f seconds" % self.sleep_time) @@ -561,9 +593,7 @@ def import_file(self): self.session_filename = data["session_filename"] self.sleep_time = data["sleep_time"] self.restart_sleep_time = data["restart_sleep_time"] - self.proto = data["proto"] self.restart_interval = data["restart_interval"] - self.timeout = data["timeout"] self.web_port = data["web_port"] self.crash_threshold = data["crash_threshold"] self.total_num_mutations = data["total_num_mutations"] @@ -804,34 +834,24 @@ def transmit(self, sock, node, edge): # Try to send payload down-range try: - # TCP/SSL - if self.proto == socket.SOCK_STREAM: - sock.send(data) - # UDP - elif self.proto == socket.SOCK_DGRAM: - # TODO: this logic does not prevent duplicate test cases, need to address this in the future. - # If our data is over the max UDP size for this platform, truncate before sending - if len(data) > self.max_udp: - self.logger.debug("Too much data for UDP, truncating to %d bytes" % self.max_udp) - data = data[:self.max_udp] - - sock.sendto(data, (self.targets[0].host, self.targets[0].port)) - - self.logger.debug("Packet sent : " + repr(data)) + self.targets[0].send(data) + self.last_send = data + except socket.error, inst: + self.logger.error("Socket error on send: %s" % inst) + try: # Receive data # TODO: Remove magic number (10000) - self.last_recv = sock.recv(10000) - - except Exception, inst: - self.logger.error("Socket error, send: %s" % inst) + self.last_recv = self.targets[0].recv(10000) + except socket.error, inst: + self.logger.error("Socket error on receive: %s" % inst) # If we have data in our recv buffer if self.last_recv: self.logger.debug("received: [%d] %s" % (len(self.last_recv), repr(self.last_recv))) # Assume a crash? else: - self.logger.warning("Nothing received on socket.") + self.logger.warning("Nothing received from target.") # Increment individual crash count self.crashing_primitives[self.fuzz_node.mutant] = self.crashing_primitives.get(self.fuzz_node.mutant, 0) + 1 # Note crash information diff --git a/sulley/socket_connection.py b/sulley/socket_connection.py new file mode 100644 index 0000000..32b6fe1 --- /dev/null +++ b/sulley/socket_connection.py @@ -0,0 +1,127 @@ +import itarget_connection +import socket +import ssl +import httplib + +import sex +from helpers import get_max_udp_size + + +class SocketConnection(itarget_connection.ITargetConnection): + """ + ITargetConnection implementation using sockets. Supports UDP, TCP, SSL. + """ + def __init__(self, host, port, proto="tcp", bind=None, timeout=5.0): + """ + @type host: str + @param host: Hostname or IP address of target system + @type port: int + @param port: Port of target service + @type proto: str + @kwarg proto: (Optional, def="tcp") Communication protocol ("tcp", "udp", "ssl") + @type bind: tuple (host, port) + @kwarg bind: (Optional, def=random) Socket bind address and port + @type timeout: float + @kwarg timeout: (Optional, def=5.0) Seconds to wait for a send/recv prior to timing out + """ + self.max_udp = get_max_udp_size() + + self.host = host + self.port = port + self.bind = bind + self.ssl = False + self.timeout = timeout + self.proto = proto.lower() + self._sock = None + self.logger = None + + if self.proto == "tcp": + self.proto = socket.SOCK_STREAM + + elif self.proto == "ssl": + self.proto = socket.SOCK_STREAM + self.ssl = True + + elif self.proto == "udp": + self.proto = socket.SOCK_DGRAM + + else: + raise sex.SullyRuntimeError("INVALID PROTOCOL SPECIFIED: %s" % self.proto) + + def close(self): + """ + Close connection to the target. + + :return: None + """ + self._sock.close() + + def open(self): + """ + Opens connection to the target. Make sure to call close! + + :return: None + """ + self._sock = socket.socket(socket.AF_INET, self.proto) + + if self.bind: + self._sock.bind(self.bind) + + self._sock.settimeout(self.timeout) + + # Connect is needed only for TCP stream + if self.proto == socket.SOCK_STREAM: + self._sock.connect((self.host, self.port)) + + # if SSL is requested, then enable it. + if self.ssl: + ssl_sock = ssl.wrap_socket(self._sock) + self._sock = httplib.FakeSocket(self._sock, ssl_sock) + + def recv(self, max_bytes): + """ + Receive up to max_bytes data from the target. + + :param max_bytes: Maximum number of bytes to receive. + :type max_bytes: int + + :return: Received data. + """ + try: + return self._sock.recv(max_bytes) + except socket.timeout: + return bytes('') + + def send(self, data): + """ + Send data to the target. Only valid after calling open! + + :param data: Data to send. + + :return: None + """ + # TCP/SSL + if self.proto == socket.SOCK_STREAM: + self._sock.send(data) + # UDP + elif self.proto == socket.SOCK_DGRAM: + # TODO: this logic does not prevent duplicate test cases, need to address this in the future. + # If our data is over the max UDP size for this platform, truncate before sending + if len(data) > self.max_udp: + self.logger.debug("Too much data for UDP, truncating to %d bytes" % self.max_udp) + data = data[:self.max_udp] + + self._sock.sendto(data, (self.host, self.port)) + + self.logger.debug("Packet sent : " + repr(data)) + + def set_logger(self, logger): + """ + Set this object's (and it's aggregated classes') logger. + + :param logger: Logger to use. + :type logger: logging.Logger + + :return: None + """ + self.logger = logger diff --git a/unit_tests/test_serial_connection_generic.py b/unit_tests/test_serial_connection_generic.py new file mode 100644 index 0000000..1f183b2 --- /dev/null +++ b/unit_tests/test_serial_connection_generic.py @@ -0,0 +1,488 @@ +import unittest +from sulley import iserial_like +from sulley.serial_connection import SerialConnection +import time + + +class MockSerial(iserial_like.ISerialLike): + """ + Mock ISerialLike class. + Methods include code for unit testing. See each method for details. + """ + + def __init__(self): + self.close_called = False + self.open_called = False + self.send_data_list = [] + self.send_return_queue = [] + self.recv_max_bytes_lengths = [] + self.recv_return_queue = [] + self.recv_return_nothing_by_default = False + self.recv_wait_times = [] + + def close(self): + """ + Close connection. + + :return: None + """ + self.close_called = True + + def open(self): + """ + Opens connection to the target. Make sure to call close! + + :return: None + """ + self.open_called = True + + def recv(self, max_bytes): + """ + Receive up to max_bytes data. + + Mock method: + - Waits some amount of time according to self.recv_wait_times + - Appends max_bytes to self.recv_max_bytes_lengths + - Returns based on self.recv_return_queue + * If empty, returns b'' if self.recv_return_nothing_by_default is True, or + b'0'*max_bytes otherwise. + + :param max_bytes: Maximum number of bytes to receive. + :type max_bytes: int + + :return: Received data. bytes('') if no data is received. + """ + # Wait if needed + if len(self.recv_wait_times) > 0: + time.sleep(self.recv_wait_times.pop(0)) + + # Save argument + self.recv_max_bytes_lengths.append(max_bytes) + + # Return data + if len(self.recv_return_queue) > 0: + return self.recv_return_queue.pop(0) + elif self.recv_return_nothing_by_default: + return b'' + else: + return b'0' * max_bytes + + def send(self, data): + """ + Send data to the target. + + :param data: Data to send. + + :return: None + """ + self.send_data_list.append(data) + if len(self.send_return_queue) > 0: + return self.send_return_queue.pop(0) + else: + return len(data) + + +class TestSerialConnection(unittest.TestCase): + def setUp(self): + self.mock = MockSerial() + + def test_open(self): + """ + Given: A SerialConnection using MockSerial. + When: Calling SerialConnection.open(). + Then: MockSerial.open() is called. + """ + uut = SerialConnection(connection=self.mock) + uut.open() + self.assertTrue(self.mock.open_called) + + def test_close(self): + """ + Given: A SerialConnection using MockSerial. + When: Calling SerialConnection.close(). + Then: MockSerial.close() is called. + """ + uut = SerialConnection(connection=self.mock) + uut.close() + self.assertTrue(self.mock.close_called) + + ########################################################################### + # Send tests + ########################################################################### + def test_send_basic(self): + """ + Given: A SerialConnection using MockSerial. + + When: Calling SerialConnection.send(data) + and: MockSerial.send() returns len(data). + + Then: Verify MockSerial.send() was called only once. + and: Verify MockSerial.send() received the expected data. + """ + uut = SerialConnection(connection=self.mock) + # When + data = b'ABCDEFG' + uut.send(data=data) + # Then + self.assertEqual(len(self.mock.send_data_list), 1) + self.assertEqual(self.mock.send_data_list[0], b'ABCDEFG') + + def test_send_return_none(self): + """ + Verify that MockSerial.send() is called again when it returns None. + + Given: A SerialConnection using MockSerial. + + When: Calling SerialConnection.send(data) with 10 bytes. + and: MockSerial.send() returns: None, 10. + + Then: Verify MockSerial.send() was called exactly 2 times. + and: Verify MockSerial.send() received the expected data each time. + """ + uut = SerialConnection(connection=self.mock) + # When + data = b'123456789A' + self.mock.send_return_queue = [None, 10] + uut.send(data=data) + # Then + self.assertEqual(len(self.mock.send_data_list), 2) + self.assertEqual(self.mock.send_data_list, [b'123456789A', + b'123456789A']) + + def test_send_multiple(self): + """ + Verify that MockSerial.send() is called repeatedly until it sends all the data. + + Given: A SerialConnection using MockSerial. + + When: Calling SerialConnection.send(data) with 9 bytes. + and: MockSerial.send() returns: 0, None, 0, 1, 2, 3, 2, 1. + + Then: Verify MockSerial.send() was called exactly 7 times. + and: Verify MockSerial.send() received the expected data each time. + """ + uut = SerialConnection(connection=self.mock) + # When + data = b'123456789' + self.mock.send_return_queue = [0, None, 0, 1, 2, 3, 2, 1] + uut.send(data=data) + # Then + self.assertEqual(len(self.mock.send_data_list), 8) + self.assertEqual(self.mock.send_data_list, [b'123456789', + b'123456789', + b'123456789', + b'123456789', + b'23456789', + b'456789', + b'789', + b'9']) + + def test_send_off_by_one(self): + """ + Verify that MockSerial.send() is called again when it sends all but 1 byte. + + Given: A SerialConnection using MockSerial. + + When: Calling SerialConnection.send(data) with 9 bytes. + and: MockSerial.send() returns: 8, 1. + + Then: Verify MockSerial.send() was called exactly 2 times. + and: Verify MockSerial.send() received the expected data each time. + """ + uut = SerialConnection(connection=self.mock) + # When + data = b'123456789' + self.mock.send_return_queue = [8, 1] + uut.send(data=data) + # Then + self.assertEqual(len(self.mock.send_data_list), 2) + self.assertEqual(self.mock.send_data_list, [b'123456789', + b'9']) + + def test_send_one_byte(self): + """ + Verify that MockSerial.send() is called again when it returns 0 after being given 1 byte. + + Given: A SerialConnection using MockSerial. + + When: Calling SerialConnection.send(data) with 1 byte. + and: MockSerial.send() returns: 0, 1. + + Then: Verify MockSerial.send() was called exactly 2 times. + and: Verify MockSerial.send() received the expected data each time. + """ + uut = SerialConnection(connection=self.mock) + # When + data = b'1' + self.mock.send_return_queue = [0, 1] + uut.send(data=data) + # Then + self.assertEqual(len(self.mock.send_data_list), 2) + self.assertEqual(self.mock.send_data_list, [b'1', + b'1']) + + def test_send_many(self): + """ + Verify that send works properly when MockSerial.send() sends 1 byte at a time. + + Given: A SerialConnection using MockSerial. + + When: Calling SerialConnection.send(data) with 9 bytes. + and: MockSerial.send() returns: 0, 500 times, followed by len(data). + + Then: Verify MockSerial.send() was called exactly 501 times. + and: Verify MockSerial.send() received the expected data each time. + """ + uut = SerialConnection(connection=self.mock) + # When + data = b'123456789' + self.mock.send_return_queue = [0] * 500 + [len(data)] + uut.send(data=data) + # Then + self.assertEqual(len(self.mock.send_data_list), 501) + self.assertEqual(self.mock.send_data_list, [b'123456789'] * 501) + + def test_send_zero_bytes(self): + """ + Verify that send() doesn't fail when given 0 bytes. + + Given: A SerialConnection using MockSerial. + + When: Calling SerialConnection.send(data) with 0 bytes. + and: MockSerial.send() set to return len(data). + + Then: Verify MockSerial.send() was called either 0 or 1 times. + and: Verify MockSerial.send() received 0 bytes, if anything. + """ + uut = SerialConnection(connection=self.mock) + # When + data = b'' + self.mock.send_return_queue = [0, 1] + uut.send(data=data) + # Then + self.assertLessEqual(len(self.mock.send_data_list), 1) + if len(self.mock.send_data_list) == 0: + self.assertEqual(self.mock.send_data_list, []) + else: + self.assertEqual(self.mock.send_data_list, [b'']) + + ########################################################################### + # Receive tests + ########################################################################### + def test_recv_simple(self): + """ + Verify that recv() works in the normal case. + + Given: A SerialConnection using MockSerial, + with no timeout/message_separator_time/content_checker. + + When: User calls SerialConnection.recv. + and: MockSerial.recv set to return data of length max_bytes. + + Then: SerialConnection calls MockSerial.recv exactly once. + and: SerialConnection.recv returns exactly what MockSerial.recv returned. + """ + uut = SerialConnection(connection=self.mock) + # When + self.mock.recv_return_queue = [b'0123456'] + data = uut.recv(max_bytes=7) + # Then + self.assertEqual(self.mock.recv_max_bytes_lengths, [7]) + self.assertEqual(data, b'0123456') + + def test_recv_max_bytes_only(self): + """ + Verify that recv() calls MockSerial.recv() repeatedly until it gets max_bytes of data. + + Given: A SerialConnection using MockSerial, + with no timeout/message_separator_time/content_checker. + + When: User calls SerialConnection.recv(10). + and: MockSerial.recv set to return 0, 0, 0, 1, 2, 3, 4 bytes. + + Then: SerialConnection calls MockSerial.recv exactly 7 times, + with max_bytes decreasing as appropriate. + and: SerialConnection.recv returns the concatenation of MockSerial.recv() return values. + """ + uut = SerialConnection(connection=self.mock) + # When + self.mock.recv_return_queue = [b'', b'', b'', b'1', b'22', b'123', b'1234'] + data = uut.recv(max_bytes=10) + # Then + self.assertEqual(self.mock.recv_max_bytes_lengths, [10, 10, 10, 10, 9, 7, 4]) + self.assertEqual(data, b'1221231234') + + def test_recv_timeout(self): + """ + Verify that recv() returns partial messages after the timeout expires. + + Given: A SerialConnection using MockSerial, + with timeout set to a smallish value. + + When: User calls SerialConnection.recv(n) several times with different values of n. + and: MockSerial.recv set to return a single message, then repeatedly return nothing. + + Then: SerialConnection.recv calls MockSerial.recv at least once. + and: SerialConnection.recv returns the MockSerial.recv() return value after the timeout. + + Note: Timeout functionality is tested, but not the precise timing. + """ + uut = SerialConnection(connection=self.mock, timeout=.001) # 1ms + + # n == 1 + self.mock.recv_return_nothing_by_default = True + self.mock.recv_return_queue = [b''] + data = uut.recv(max_bytes=1) + self.assertGreaterEqual(len(self.mock.recv_max_bytes_lengths), 1) + self.assertEqual(data, b'') + + # n == 2 + self.mock.recv_return_nothing_by_default = True + self.mock.recv_return_queue = [b'1'] + data = uut.recv(max_bytes=2) + self.assertGreaterEqual(len(self.mock.recv_max_bytes_lengths), 1) + self.assertEqual(data, b'1') + + # n == 3, len(data) == 1 + self.mock.recv_return_nothing_by_default = True + self.mock.recv_return_queue = [b'1'] + data = uut.recv(max_bytes=5) + self.assertGreaterEqual(len(self.mock.recv_max_bytes_lengths), 1) + self.assertEqual(data, b'1') + + # n == 3, len(data) == 2 + self.mock.recv_return_nothing_by_default = True + self.mock.recv_return_queue = [b'12'] + data = uut.recv(max_bytes=3) + self.assertGreaterEqual(len(self.mock.recv_max_bytes_lengths), 1) + self.assertEqual(data, b'12') + + # # n == 2**16, len(data) == 2**16 - 1 + # self.mock.recv_return_nothing_by_default = True + # self.mock.recv_return_queue = [b'\0'] * (2**16 - 1) + # data = uut.recv(max_bytes=2**16) + # self.assertGreaterEqual(len(self.mock.recv_max_bytes_lengths), 1) + # self.assertEqual(data, [b'\0'] * (2**16 - 1)) + + def test_recv_message_separator_time(self): + """ + Verify that message_separator_time works correctly. + Receive a message over time t, where t > message_separator_time, and each part of the message is delayed by + t' < message_separator_time. + + Given: A SerialConnection using MockSerial, + and: timeout set to 60ms. + and: message_separator_time set 20ms + + When: User calls SerialConnection.recv(60). + and: MockSerial.recv set to return increasing bytes. + and: MockSerial.recv set to delay 1ms on each call. + + Then: SerialConnection.recv calls MockSerial.recv more than 20 times. + and: SerialConnection.recv returns data with more than 20 bytes. + """ + # Given + uut = SerialConnection(connection=self.mock, timeout=.060, message_separator_time=.020) + + # When + self.mock.recv_return_queue = [b'1'] * 60 + self.mock.recv_wait_times = [.001] * 60 + data = uut.recv(max_bytes=60) + + # Then + self.assertGreater(len(self.mock.recv_max_bytes_lengths), 20) + self.assertGreater(len(data), 20) + + def test_recv_message_separator_time_2(self): + """ + Verify that message_separator_time works correctly. + Receive a message that times out with message_separator_time, but which would not time out with only a timeout. + + Given: A SerialConnection using MockSerial, + and: timeout set to 60ms. + and: message_separator_time set 20ms + + When: User calls SerialConnection.recv(60). + and: MockSerial.recv set to return 1 byte, then 1 byte, then 58 bytes. + and: MockSerial.recv set to delay 1ms, then 40ms, then 1ms. + + Then: SerialConnection.recv calls MockSerial.recv twice. + and: SerialConnection.recv returns only the first two bytes. + """ + # Given + uut = SerialConnection(connection=self.mock, timeout=.060, message_separator_time=.020) + + # When + self.mock.recv_return_queue = [b'1', b'2', b'3' * 58] + self.mock.recv_wait_times = [.001, .040, .001] + data = uut.recv(max_bytes=60) + + # Then + self.assertEqual(len(self.mock.recv_max_bytes_lengths), 2) + self.assertEqual(data, b'12') + + def test_recv_message_content_checker(self): + """ + Verify that content_checker is used correctly. + The content_checker indicates how much of a message is valid, if any. + Verify behavior when the content_checker consumes a part of the buffer, the full buffer, and then part of it + again. + + Given: A SerialConnection using MockSerial, + and: timeout set to 100ms. + and: message_separator_time set 20ms + and: content_checker set to a function that returns 0, 3, 0, 5, 0, 3 + + When: User calls SerialConnection.recv(100) 3 times. + and: MockSerial.recv set to return 2 bytes repeatedly. + + Then: SerialConnection.recv calls MockSerial.recv 6 times. + and: SerialConnection.recv returns only the first 3 bytes, then the next 5 bytes, then the next 3. + """ + # Given + # PyUnusedLocal suppression: args/kwargs make the method callable by SerialConnection, but are not used. + # noinspection PyUnusedLocal + def test_checker(*args, **kwargs): + """ + :param args: Ignored. Makes method callable with arguments. + :param kwargs: Ignored. Makes method callable with arguments. + + :return: 0, 3, 0, 5, 0, 3, 0, 0... + """ + if not hasattr(test_checker, "counter"): + test_checker.counter = 0 + + test_checker.counter += 1 + + if test_checker.counter == 2: + return 3 + elif test_checker.counter == 4: + return 5 + elif test_checker.counter == 6: + return 3 + else: + return 0 + + uut = SerialConnection(connection=self.mock, + timeout=.100, + message_separator_time=.020, + content_checker=test_checker) + + # When + self.mock.recv_return_queue = [b'12', b'34', b'56', b'78', b'9A', b'BC'] + + data = uut.recv(max_bytes=100) + self.assertEqual(len(self.mock.recv_max_bytes_lengths), 2) + self.assertEqual(data, b'123') + + data = uut.recv(max_bytes=100) + self.assertEqual(len(self.mock.recv_max_bytes_lengths), 4) + self.assertEqual(data, b'45678') + + data = uut.recv(max_bytes=100) + self.assertEqual(len(self.mock.recv_max_bytes_lengths), 6) + self.assertEqual(data, b'9AB') + + +if __name__ == '__main__': + unittest.main()