diff --git a/cornershot/cornershot.py b/cornershot/cornershot.py index df3c3b2..10fbd52 100644 --- a/cornershot/cornershot.py +++ b/cornershot/cornershot.py @@ -1,4 +1,3 @@ -import itertools import queue import threading import time @@ -33,7 +32,7 @@ def __init__(self, username, password, domain, workers=250, shots=None): self.resultQ = queue.Queue() self.runthreads = True self.results = {} - self.shot_gen = None + self.shot_list = [] self.total_shots = 0 def _takeashot(self): @@ -61,18 +60,30 @@ def add_shots(self, destinations, targets, target_ports=None, destination_ports= if target_ports is None: target_ports = TARGET_PORTS - if self.shot_gen: - self.shot_gen = itertools.chain(self.shot_gen, self._shots_generator(destinations, targets, target_ports, destination_ports)) - else: - self.shot_gen = self._shots_generator(destinations, targets, target_ports, destination_ports) + self._shots_generator(destinations, targets, target_ports, destination_ports) + + def add_many_shot_pairs(self, carrier_target_pairs, target_ports=None, destination_ports=None): + if target_ports is None: + target_ports = TARGET_PORTS + + tport_shot_class = [] + for target_port in target_ports: + tport_shot_class.append([target_port,self._get_suitable_shots(target_port, destination_ports)]) + + for ct_pair in carrier_target_pairs: + carrier = ct_pair[0] + target = ct_pair[1] + for tport_shot_class_pair in tport_shot_class: + target_port = tport_shot_class_pair[0] + for cls in tport_shot_class_pair[1]: + self.shot_list.append(cls(self.username, self.password, self.domain, carrier, target,target_port=target_port)) def _shots_generator(self, destinations, targets, target_ports, destination_ports=None): for destination in destinations: for target in targets: for target_port in target_ports: for cls in self._get_suitable_shots(target_port, destination_ports): - yield cls(self.username, self.password, self.domain, destination, target, - target_port=target_port) + self.shot_list.append(cls(self.username, self.password, self.domain, destination, target,target_port=target_port)) def _merge_result(self, dest, target, tport, state): if dest not in self.results: @@ -95,8 +106,9 @@ def _merge_result(self, dest, target, tport, state): def _shots_manager(self): remaining = MAX_QUEUE_SIZE while self.runthreads: - new_tasks = itertools.islice(self.shot_gen, remaining) - tasks = list(new_tasks) + new_tasks = self.shot_list[0:remaining] + self.shot_list = self.shot_list[remaining + 1:] + tasks = new_tasks shuffle(tasks) remaining = remaining - len(tasks) @@ -121,7 +133,6 @@ def _shots_manager(self): self.runthreads = False break - self.shot_gen = None self.total_shots = 0 def open_fire(self,blocking=True): @@ -138,15 +149,14 @@ def open_fire(self,blocking=True): main_thread = threading.Thread(target=self._shots_manager,daemon=True) main_thread.start() - def lock_and_load(self): - self.shot_gen, sg_sum = itertools.tee(self.shot_gen) - self.total_shots = sum(1 for _ in sg_sum) - def read_results(self): return self.results + def lock_and_load(self): + self.total_shots = self.remaining_shots() + def remaining_shots(self): - return self.total_shots + return len(self.shot_list) def _get_suitable_shots(self, target_port, destination_port): class_list = [] diff --git a/setup.py b/setup.py index b9b3c89..b2b41f1 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ setup( name='cornershot', python_requires='>=3', - version='0.2.0', + version='0.2.1', description='Library to test network connectivity', long_description_content_type='text/markdown', long_description=long_description,