From 1f3564636de73fb74d79f3a50f6eda2c9b04992e Mon Sep 17 00:00:00 2001 From: TheTechromancer <20261699+TheTechromancer@users.noreply.github.com> Date: Sun, 10 Oct 2021 04:50:40 -0400 Subject: [PATCH] Shared threadpool (#1489) * do not store return values by default * pass callback with every submit() * flake8 fixes * shared threadpool between all modules * flake8 fixes * change taskName to help avoid conflicts * minor bug fixes + performance improvements * move threadpool tests to new file * better error handling in scan status checks --- sf.py | 2 + sfscan.py | 90 +++--- spiderfoot/__init__.py | 1 + spiderfoot/plugin.py | 216 ++------------ spiderfoot/threadpool.py | 271 ++++++++++++++++++ test/unit/spiderfoot/test_spiderfootplugin.py | 44 --- .../spiderfoot/test_spiderfootthreadpool.py | 84 ++++++ 7 files changed, 427 insertions(+), 281 deletions(-) create mode 100644 spiderfoot/threadpool.py create mode 100644 test/unit/spiderfoot/test_spiderfootthreadpool.py diff --git a/sf.py b/sf.py index 4e0e9164ed..22edd6a233 100755 --- a/sf.py +++ b/sf.py @@ -52,6 +52,7 @@ def main(): # be overridden from saved configuration settings stored in the DB. sfConfig = { '_debug': False, # Debug + '_maxthreads': 3, # Number of modules to run concurrently '__logging': True, # Logging in general '__outputfilter': None, # Event types to filter from modules' output '_useragent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:62.0) Gecko/20100101 Firefox/62.0', # User-Agent to use for HTTP requests @@ -71,6 +72,7 @@ def main(): sfOptdescs = { '_debug': "Enable debugging?", + '_maxthreads': "Max number of modules to run concurrently", '_useragent': "User-Agent string to use for HTTP requests. Prefix with an '@' to randomly select the User Agent from a file containing user agent strings for each request, e.g. @C:\\useragents.txt or @/home/bob/useragents.txt. Or supply a URL to load the list from there.", '_dnsserver': "Override the default resolver with another DNS server. For example, 8.8.8.8 is Google's open DNS server.", '_fetchtimeout': "Number of seconds before giving up on a HTTP request.", diff --git a/sfscan.py b/sfscan.py index a0d64dfc42..f2d2438b46 100644 --- a/sfscan.py +++ b/sfscan.py @@ -22,7 +22,7 @@ import dns.resolver from sflib import SpiderFoot -from spiderfoot import SpiderFootDb, SpiderFootEvent, SpiderFootPlugin, SpiderFootTarget, SpiderFootHelpers, logger +from spiderfoot import SpiderFootDb, SpiderFootEvent, SpiderFootPlugin, SpiderFootTarget, SpiderFootHelpers, SpiderFootThreadPool, logger def startSpiderFootScanner(loggingQueue, *args, **kwargs): @@ -210,6 +210,8 @@ def __init__(self, scanName, scanId, targetValue, targetType, moduleList, global self.__setStatus("INITIALIZING", time.time() * 1000, None) + self.__sharedThreadPool = SpiderFootThreadPool(threads=self.__config.get("_maxthreads", 3), name='sharedThreadPool') + # Used when module threading is enabled self.eventQueue = None @@ -255,19 +257,17 @@ def __setStatus(self, status, started=None, ended=None): self.__status = status self.__dbh.scanInstanceSet(self.__scanId, started, ended, status) - def __startScan(self, threaded=True): + def __startScan(self): """Start running a scan. - - Args: - threaded (bool): whether to thread modules """ aborted = False self.__setStatus("STARTING", time.time() * 1000, None) self.__sf.status(f"Scan [{self.__scanId}] for '{self.__target.targetValue}' initiated.") - if threaded: - self.eventQueue = queue.Queue() + self.eventQueue = queue.Queue() + + self.__sharedThreadPool.start() # moduleList = list of modules the user wants to run self.__sf.debug(f"Loading {len(self.__moduleList)} modules ...") @@ -305,6 +305,7 @@ def __startScan(self, threaded=True): mod.setScanId(self.__scanId) mod.setup(self.__sf, self.__modconfig[modName]) mod.setDbh(self.__dbh) + mod.setSharedThreadPool(self.__sharedThreadPool) except Exception: self.__sf.error(f"Module {modName} initialization failed: {traceback.format_exc()}") mod.errorState = True @@ -343,13 +344,12 @@ def __startScan(self, threaded=True): continue # Set up the outgoing event queue - if threaded: - try: - mod.outgoingEventQueue = self.eventQueue - mod.incomingEventQueue = queue.Queue() - except Exception as e: - self.__sf.error(f"Module {modName} event queue setup failed: {e}") - continue + try: + mod.outgoingEventQueue = self.eventQueue + mod.incomingEventQueue = queue.Queue() + except Exception as e: + self.__sf.error(f"Module {modName} event queue setup failed: {e}") + continue self.__moduleInstances[modName] = mod self.__sf.status(f"{modName} module loaded.") @@ -365,19 +365,6 @@ def __startScan(self, threaded=True): # sort modules by priority self.__moduleInstances = OrderedDict(sorted(self.__moduleInstances.items(), key=lambda m: m[-1]._priority)) - if not threaded: - # Register listener modules and then start all modules sequentially - for module in list(self.__moduleInstances.values()): - for listenerModule in list(self.__moduleInstances.values()): - # Careful not to register twice or you will get duplicate events - if listenerModule in module._listenerModules: - continue - # Note the absence of a check for whether a module can register - # to itself. That is intentional because some modules will - # act on their own notifications (e.g. sfp_dns)! - if listenerModule.watchedEvents() is not None: - module.registerListener(listenerModule) - # Now we are ready to roll.. self.__setStatus("RUNNING") @@ -387,13 +374,8 @@ def __startScan(self, threaded=True): psMod.setTarget(self.__target) psMod.setDbh(self.__dbh) psMod.clearListeners() - if threaded: - psMod.outgoingEventQueue = self.eventQueue - psMod.incomingEventQueue = queue.Queue() - else: - for mod in list(self.__moduleInstances.values()): - if mod.watchedEvents() is not None: - psMod.registerListener(mod) + psMod.outgoingEventQueue = self.eventQueue + psMod.incomingEventQueue = queue.Queue() # Create the "ROOT" event which un-triggered modules will link events to rootEvent = SpiderFootEvent("ROOT", self.__targetValue, "", None) @@ -422,13 +404,9 @@ def __startScan(self, threaded=True): break # start threads - if threaded and not aborted: + if not aborted: self.waitForThreads() - if not threaded: - for mod in list(self.__moduleInstances.values()): - mod.finish() - if aborted: self.__sf.status(f"Scan [{self.__scanId}] aborted.") self.__setStatus("ABORTED", None, time.time() * 1000) @@ -514,18 +492,40 @@ def waitForThreads(self): # tell the modules to stop for mod in self.__moduleInstances.values(): mod._stopScanning = True + self.__sharedThreadPool.shutdown(wait=True) def threadsFinished(self, log_status=False): if self.eventQueue is None: return True - modules_waiting = { - m.__name__: m.incomingEventQueue.qsize() for m in - self.__moduleInstances.values() if m.incomingEventQueue is not None - } + modules_waiting = dict() + for m in self.__moduleInstances.values(): + try: + if m.incomingEventQueue is not None: + modules_waiting[m.__name__] = m.incomingEventQueue.qsize() + except Exception: + with suppress(Exception): + m.errorState = True modules_waiting = sorted(modules_waiting.items(), key=lambda x: x[-1], reverse=True) - modules_running = [m.__name__ for m in self.__moduleInstances.values() if m.running] - modules_errored = [m.__name__ for m in self.__moduleInstances.values() if m.errorState] + + modules_running = [] + for m in self.__moduleInstances.values(): + try: + if m.running: + modules_running.append(m.__name__) + except Exception: + with suppress(Exception): + m.errorState = True + + modules_errored = [] + for m in self.__moduleInstances.values(): + try: + if m.errorState: + modules_errored.append(m.__name__) + except Exception: + with suppress(Exception): + m.errorState = True + queues_empty = [qsize == 0 for m, qsize in modules_waiting] for mod in self.__moduleInstances.values(): diff --git a/spiderfoot/__init__.py b/spiderfoot/__init__.py index 5b855c8c92..b235469bfb 100644 --- a/spiderfoot/__init__.py +++ b/spiderfoot/__init__.py @@ -1,5 +1,6 @@ from .db import SpiderFootDb from .event import SpiderFootEvent +from .threadpool import SpiderFootThreadPool from .plugin import SpiderFootPlugin from .target import SpiderFootTarget from .helpers import SpiderFootHelpers diff --git a/spiderfoot/plugin.py b/spiderfoot/plugin.py index b160181eea..a53c512443 100644 --- a/spiderfoot/plugin.py +++ b/spiderfoot/plugin.py @@ -8,6 +8,8 @@ from time import sleep import traceback +from .threadpool import SpiderFootThreadPool + # begin logging overrides # these are copied from the python logging module # https://github.com/python/cpython/blob/main/Lib/logging/__init__.py @@ -127,14 +129,16 @@ class SpiderFootPlugin(): sf = None # Configuration, set in each module's setup() function opts = dict() + # Maximum threads + maxThreads = 1 def __init__(self): - # Whether the module is currently processing data - self._running = False # Holds the thread object when module threading is enabled self.thread = None # logging overrides self._log = None + # Shared thread pool for all modules + self.sharedThreadPool = None @property def log(self): @@ -432,7 +436,7 @@ def running(self): Returns: bool """ - return self._running + return self.sharedThreadPool.countQueuedTasks(f"{self.__name__}_threadWorker") > 0 def watchedEvents(self): """What events is this module interested in for input. The format is a list @@ -510,14 +514,12 @@ def threadWorker(self): except queue.Empty: sleep(.3) continue - self._running = True if sfEvent == 'FINISHED': self.sf.debug(f"{self.__name__}.threadWorker() got \"FINISHED\" from incomingEventQueue.") - self.finish() + self.poolExecute(self.finish) else: self.sf.debug(f"{self.__name__}.threadWorker() got event, {sfEvent.eventType}, from incomingEventQueue.") - self.handleEvent(sfEvent) - self._running = False + self.poolExecute(self.handleEvent, sfEvent) except KeyboardInterrupt: self.sf.debug(f"Interrupted module {self.__name__}.") self._stopScanning = True @@ -538,195 +540,25 @@ def threadWorker(self): # if there are leftover objects in the queue, the scan will hang. self.incomingEventQueue = None - finally: - self._running = False + def poolExecute(self, callback, *args, **kwargs): + """Execute a callback with the given args. + If we're in a storage module, execute normally. + Otherwise, use the shared thread pool. - class ThreadPool: - """ - A spiderfoot-integrated threading pool - Each thread in the pool is spawned only once, and reused for best performance. - - Example 1: using map() - with self.threadPool(self.opts["_maxthreads"]) as pool: - # callback("a", "arg1", kwarg1="kwarg1"), callback("b", "arg1" ...) - for result in pool.map( - callback, - ["a", "b", "c", "d"], - args=("arg1",) - kwargs={kwarg1: "kwarg1"} - ): - yield result - - Example 2: using submit() - with self.threadPool(self.opts["_maxthreads"]) as pool: - pool.start(callback, "arg1", kwarg1="kwarg1") - # callback(a, "arg1", kwarg1="kwarg1"), callback(b, "arg1" ...) - pool.submit(a) - pool.submit(b) - for result in pool.shutdown(): - yield result + Args: + callback: function to call + args: args (passed through to callback) + kwargs: kwargs (passed through to callback) """ - - def __init__(self, sfp, threads=100, qsize=None, name=None): - if name is None: - name = "" - - self.sfp = sfp - self.threads = int(threads) - try: - self.qsize = int(qsize) - except (TypeError, ValueError): - self.qsize = self.threads * 5 - self.pool = [None] * self.threads - self.name = str(name) - self.inputThread = None - self.inputQueue = queue.Queue(self.qsize) - self.outputQueue = queue.Queue(self.qsize) - self.stop = False - - def start(self, callback, *args, **kwargs): - self.sfp.sf.debug(f'Starting thread pool "{self.name}" with {self.threads:,} threads') - for i in range(self.threads): - name = kwargs.get('name', 'worker') - t = ThreadPoolWorker(self.sfp, target=callback, args=args, kwargs=kwargs, - inputQueue=self.inputQueue, outputQueue=self.outputQueue, - name=f"{self.name}_{name}_{i + 1}") - t.start() - self.pool[i] = t - - def shutdown(self, wait=True): - self.sfp.sf.debug(f'Shutting down thread pool "{self.name}" with wait={wait}') - if wait: - while not self.finished and not self.sfp.checkForStop(): - yield from self.results - sleep(.1) - self.stop = True - for t in self.pool: - with suppress(Exception): - t.stop = True - # make sure input queue is empty - with suppress(Exception): - while 1: - self.inputQueue.get_nowait() - with suppress(Exception): - self.inputQueue.close() - yield from self.results - with suppress(Exception): - self.outputQueue.close() - - def submit(self, arg, wait=True): - self.inputQueue.put(arg) - - def map(self, callback, iterable, args=None, kwargs=None, name=""): # noqa: A003 - """ - Args: - iterable: each entry will be passed as the first argument to the function - callback: the function to thread - args: additional arguments to pass to callback function - kwargs: keyword arguments to pass to callback function - name: base name to use for all the threads - - Yields: - return values from completed callback function - """ - - if args is None: - args = tuple() - - if kwargs is None: - kwargs = dict() - - self.inputThread = threading.Thread(target=self.feedQueue, args=(iterable, self.inputQueue)) - self.inputThread.start() - - self.start(callback, *args, **kwargs) - yield from self.shutdown() - - @property - def results(self): - while 1: - try: - yield self.outputQueue.get_nowait() - except Exception: - break - - def feedQueue(self, iterable, q): - for i in iterable: - if self.stop: - break - while not self.stop: - try: - q.put_nowait(i) - break - except queue.Full: - sleep(.1) - continue - - @property - def finished(self): - if self.sfp.checkForStop(): - return True - else: - finishedThreads = [not t.busy for t in self.pool if t is not None] - try: - inputThreadAlive = self.inputThread.is_alive() - except AttributeError: - inputThreadAlive = False - return not inputThreadAlive and self.inputQueue.empty() and all(finishedThreads) - - def __enter__(self): - return self - - def __exit__(self, exception_type, exception_value, traceback): - self.shutdown() - # Make sure queues are empty before exiting - with suppress(Exception): - for q in (self.outputQueue, self.inputQueue): - while 1: - try: - q.get_nowait() - except queue.Empty: - break + if self.__name__.startswith('sfp__stor_'): + callback(*args, **kwargs) + else: + self.sharedThreadPool.submit(callback, *args, taskName=f"{self.__name__}_threadWorker", maxThreads=self.maxThreads, **kwargs) def threadPool(self, *args, **kwargs): - return self.ThreadPool(self, *args, **kwargs) - - -class ThreadPoolWorker(threading.Thread): - - def __init__(self, sfp, inputQueue, outputQueue, group=None, target=None, - name=None, args=None, kwargs=None, verbose=None): - if args is None: - args = tuple() - - if kwargs is None: - kwargs = dict() + return SpiderFootThreadPool(*args, **kwargs) - self.sfp = sfp - self.inputQueue = inputQueue - self.outputQueue = outputQueue - self.busy = False - self.stop = False - - super().__init__(group, target, name, args, kwargs) - - def run(self): - while not self.stop: - try: - entry = self.inputQueue.get_nowait() - self.busy = True - try: - result = self._target(entry, *self._args, **self._kwargs) - except Exception: - import traceback - self.sfp.sf.error(f'Error in thread worker {self.name}: {traceback.format_exc()}') - break - self.outputQueue.put(result) - except queue.Empty: - self.busy = False - # sleep briefly to save CPU - sleep(.1) - finally: - self.busy = False + def setSharedThreadPool(self, sharedThreadPool): + self.sharedThreadPool = sharedThreadPool # end of SpiderFootPlugin class diff --git a/spiderfoot/threadpool.py b/spiderfoot/threadpool.py new file mode 100644 index 0000000000..deb17d2603 --- /dev/null +++ b/spiderfoot/threadpool.py @@ -0,0 +1,271 @@ +import queue +import logging +import threading +from time import sleep +from contextlib import suppress + + +class SpiderFootThreadPool: + """ + Each thread in the pool is spawned only once, and reused for best performance. + + Example 1: using map() + with SpiderFootThreadPool(self.opts["_maxthreads"]) as pool: + # callback("a", "arg1"), callback("b", "arg1"), ... + for result in pool.map( + callback, + ["a", "b", "c", "d"], + "arg1", + taskName="sfp_testmodule" + saveResult=True + ): + yield result + + Example 2: using submit() + with SpiderFootThreadPool(self.opts["_maxthreads"]) as pool: + pool.start() + # callback("arg1"), callback("arg2") + pool.submit(callback, "arg1", taskName="sfp_testmodule", saveResult=True) + pool.submit(callback, "arg2", taskName="sfp_testmodule", saveResult=True) + for result in pool.shutdown()["sfp_testmodule"]: + yield result + """ + + def __init__(self, threads=100, qsize=10, name=None): + """Initialize the SpiderFootThreadPool class. + + Args: + threads: Max number of threads + qsize: Queue size + name: Name + """ + if name is None: + name = "" + + self.log = logging.getLogger(f"spiderfoot.{__name__}") + self.threads = int(threads) + self.qsize = int(qsize) + self.pool = [None] * self.threads + self.name = str(name) + self.inputThread = None + self.inputQueues = dict() + self.outputQueues = dict() + self._stop = False + self._lock = threading.Lock() + + def start(self): + self.log.debug(f'Starting thread pool "{self.name}" with {self.threads:,} threads') + for i in range(self.threads): + t = ThreadPoolWorker(pool=self, name=f"{self.name}_worker_{i + 1}") + t.start() + self.pool[i] = t + + @property + def stop(self): + return self._stop + + @stop.setter + def stop(self, val): + assert val in (True, False), "stop must be either True or False" + for t in self.pool: + with suppress(Exception): + t.stop = val + self._stop = val + + def shutdown(self, wait=True): + """Shut down the pool. + + Args: + wait (bool): Whether to wait for the pool to finish executing + + Returns: + results (dict): (unordered) results in the format: {"taskName": [returnvalue1, returnvalue2, ...]} + """ + results = dict() + self.log.debug(f'Shutting down thread pool "{self.name}" with wait={wait}') + if wait: + while not self.finished and not self.stop: + with self._lock: + outputQueues = list(self.outputQueues) + for taskName in outputQueues: + moduleResults = list(self.results(taskName)) + try: + results[taskName] += moduleResults + except KeyError: + results[taskName] = moduleResults + sleep(.1) + self.stop = True + # make sure input queues are empty + with self._lock: + inputQueues = list(self.inputQueues.values()) + for q in inputQueues: + with suppress(Exception): + while 1: + q.get_nowait() + with suppress(Exception): + q.close() + # make sure output queues are empty + with self._lock: + outputQueues = list(self.outputQueues.items()) + for taskName, q in outputQueues: + moduleResults = list(self.results(taskName)) + try: + results[taskName] += moduleResults + except KeyError: + results[taskName] = moduleResults + with suppress(Exception): + q.close() + return results + + def submit(self, callback, *args, **kwargs): + """Submit a function call to the pool. + The "taskName" and "maxThreads" arguments are optional. + + Args: + callback (function): callback function + *args: Passed through to callback + **kwargs: Passed through to callback, except for taskName and maxThreads + """ + taskName = kwargs.get('taskName', 'default') + maxThreads = kwargs.pop('maxThreads', 100) + # block if this module's thread limit has been reached + while self.countQueuedTasks(taskName) >= maxThreads: + sleep(.01) + continue + self.log.debug(f"Submitting function \"{callback.__name__}\" from module \"{taskName}\" to thread pool \"{self.name}\"") + self.inputQueue(taskName).put((callback, args, kwargs)) + + def countQueuedTasks(self, taskName): + """For the specified task, returns the number of queued function calls + plus the number of functions which are currently executing + + Args: + taskName (str): Name of task + + Returns: + the number of queued function calls plus the number of functions which are currently executing + """ + queuedTasks = 0 + with suppress(Exception): + queuedTasks += self.inputQueues[taskName].qsize() + runningTasks = 0 + for t in self.pool: + with suppress(Exception): + if t.taskName == taskName: + runningTasks += 1 + return queuedTasks + runningTasks + + def inputQueue(self, taskName="default"): + try: + return self.inputQueues[taskName] + except KeyError: + self.inputQueues[taskName] = queue.Queue(self.qsize) + return self.inputQueues[taskName] + + def outputQueue(self, taskName="default"): + try: + return self.outputQueues[taskName] + except KeyError: + self.outputQueues[taskName] = queue.Queue(self.qsize) + return self.outputQueues[taskName] + + def map(self, callback, iterable, *args, **kwargs): # noqa: A003 + """ + Args: + iterable: each entry will be passed as the first argument to the function + callback: the function to thread + args: additional arguments to pass to callback function + kwargs: keyword arguments to pass to callback function + + Yields: + return values from completed callback function + """ + taskName = kwargs.get("taskName", "default") + self.inputThread = threading.Thread(target=self.feedQueue, args=(callback, iterable, args, kwargs)) + self.inputThread.start() + self.start() + sleep(.1) + yield from self.results(taskName, wait=True) + + def results(self, taskName="default", wait=False): + while 1: + result = False + with suppress(Exception): + while 1: + yield self.outputQueue(taskName).get_nowait() + result = True + if self.countQueuedTasks(taskName) == 0 or not wait: + break + if not result: + # sleep briefly to save CPU + sleep(.1) + + def feedQueue(self, callback, iterable, args, kwargs): + for i in iterable: + if self.stop: + break + self.submit(callback, i, *args, **kwargs) + + @property + def finished(self): + if self.stop: + return True + else: + finishedThreads = [not t.busy for t in self.pool if t is not None] + try: + inputThreadAlive = self.inputThread.is_alive() + except AttributeError: + inputThreadAlive = False + inputQueuesEmpty = [q.empty() for q in self.inputQueues.values()] + return not inputThreadAlive and all(inputQueuesEmpty) and all(finishedThreads) + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + self.shutdown() + + +class ThreadPoolWorker(threading.Thread): + + def __init__(self, pool, name=None): + + self.log = logging.getLogger(f"spiderfoot.{__name__}") + self.pool = pool + self.taskName = "" # which module submitted the callback + self.busy = False + self.stop = False + + super().__init__(name=name) + + def run(self): + # Round-robin through each module's input queue + while not self.stop: + ran = False + with self.pool._lock: + inputQueues = list(self.pool.inputQueues.values()) + for q in inputQueues: + if self.stop: + break + try: + self.busy = True + callback, args, kwargs = q.get_nowait() + self.taskName = kwargs.pop("taskName", "default") + saveResult = kwargs.pop("saveResult", False) + try: + result = callback(*args, **kwargs) + ran = True + except Exception: # noqa: B902 + import traceback + self.log.error(f'Error in thread worker {self.name}: {traceback.format_exc()}') + break + if saveResult: + self.pool.outputQueue(self.taskName).put(result) + except queue.Empty: + self.busy = False + finally: + self.busy = False + self.taskName = "" + # sleep briefly to save CPU + if not ran: + sleep(.05) diff --git a/test/unit/spiderfoot/test_spiderfootplugin.py b/test/unit/spiderfoot/test_spiderfootplugin.py index 2bb8e706c1..2bc8856b29 100644 --- a/test/unit/spiderfoot/test_spiderfootplugin.py +++ b/test/unit/spiderfoot/test_spiderfootplugin.py @@ -369,47 +369,3 @@ def test_start(self): sfp.sf = sf sfp.start() - - def test_threadPool(self): - """ - Test ThreadPool(sfp, threads=10) - """ - sf = SpiderFoot(self.default_options) - sfp = SpiderFootPlugin() - sfp.sf = sf - threads = 10 - - def callback(x, *args, **kwargs): - return (x, args, list(kwargs.items())[0]) - - iterable = ["a", "b", "c"] - args = ("arg1",) - kwargs = {"kwarg1": "kwarg1"} - expectedOutput = [ - ("a", ("arg1",), ("kwarg1", "kwarg1")), - ("b", ("arg1",), ("kwarg1", "kwarg1")), - ("c", ("arg1",), ("kwarg1", "kwarg1")) - ] - # Example 1: using map() - with sfp.threadPool(threads) as pool: - map_results = sorted( - list(pool.map( - callback, - iterable, - args=args, - kwargs=kwargs - )), - key=lambda x: x[0] - ) - self.assertEqual(map_results, expectedOutput) - - # Example 2: using submit() - with sfp.threadPool(threads) as pool: - pool.start(callback, *args, **kwargs) - for i in iterable: - pool.submit(i) - submit_results = sorted( - list(pool.shutdown()), - key=lambda x: x[0] - ) - self.assertEqual(submit_results, expectedOutput) diff --git a/test/unit/spiderfoot/test_spiderfootthreadpool.py b/test/unit/spiderfoot/test_spiderfootthreadpool.py new file mode 100644 index 0000000000..728e61e56c --- /dev/null +++ b/test/unit/spiderfoot/test_spiderfootthreadpool.py @@ -0,0 +1,84 @@ +# test_spiderfootplugin.py +import pytest +import unittest + +from spiderfoot import SpiderFootThreadPool + + +@pytest.mark.usefixtures +class TestSpiderFootThreadPool(unittest.TestCase): + """ + Test SpiderFoot + """ + + def test_threadPool(self): + """ + Test ThreadPool(sfp, threads=10) + """ + threads = 10 + + def callback(x, *args, **kwargs): + return (x, args, list(kwargs.items())[0]) + + iterable = ["a", "b", "c"] + args = ("arg1",) + kwargs = {"kwarg1": "kwarg1"} + expectedOutput = [ + ("a", ("arg1",), ("kwarg1", "kwarg1")), + ("b", ("arg1",), ("kwarg1", "kwarg1")), + ("c", ("arg1",), ("kwarg1", "kwarg1")) + ] + # Example 1: using map() + with SpiderFootThreadPool(threads) as pool: + map_results = sorted( + list(pool.map( + callback, + iterable, + *args, + saveResult=True, + **kwargs + )), + key=lambda x: x[0] + ) + self.assertEqual(map_results, expectedOutput) + + # Example 2: using submit() + with SpiderFootThreadPool(threads) as pool: + pool.start() + for i in iterable: + pool.submit(callback, *((i,) + args), saveResult=True, **kwargs) + submit_results = sorted( + list(pool.shutdown()["default"]), + key=lambda x: x[0] + ) + self.assertEqual(submit_results, expectedOutput) + + # Example 3: using both + threads = 1 + iterable2 = ["d", "e", "f"] + expectedOutput2 = [ + ("d", ("arg1",), ("kwarg1", "kwarg1")), + ("e", ("arg1",), ("kwarg1", "kwarg1")), + ("f", ("arg1",), ("kwarg1", "kwarg1")) + ] + pool = SpiderFootThreadPool(threads) + pool.start() + for i in iterable2: + pool.submit(callback, *((i,) + args), taskName="submitTest", saveResult=True, **kwargs) + map_results = sorted( + list(pool.map( + callback, + iterable, + *args, + taskName="mapTest", + saveResult=True, + **kwargs + )), + key=lambda x: x[0] + ) + submit_results = sorted( + list(pool.shutdown()["submitTest"]), + key=lambda x: x[0] + ) + self.assertEqual(map_results, expectedOutput) + self.assertEqual(submit_results, expectedOutput2)