Skip to content

Commit

Permalink
Merge pull request #10 from The-Firefighters/updates
Browse files Browse the repository at this point in the history
huristic improvement + minor changes for erel
  • Loading branch information
Almog-David authored Jun 13, 2024
2 parents 0eb5688 + e749659 commit d5cc0e7
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

logger = logging.getLogger(__name__)

def spreading_maxsave(Graph:nx.DiGraph, budget:int, source:int, targets:list, flag=None) -> list:
def spreading_maxsave(Graph:nx.DiGraph, budget:int, source:int, targets:list, stop_condition=None) -> list:
"""
"Approximability of the Firefighter Problem - Computing Cuts over Time",
by Elliot Anshelevich, Deeparnab Chakrabarty, Ameya Hate, Chaitanya Swamy (2010)
Expand Down Expand Up @@ -117,7 +117,7 @@ def spreading_maxsave(Graph:nx.DiGraph, budget:int, source:int, targets:list, fl

can_spread = spread_virus(Graph, infected_nodes)

if flag is not None:
if stop_condition is not None:
if len(targets) == 0 or any(node in infected_nodes for node in targets):
clean_graph(Graph)
logger.info(f"Returning vaccination strategy: {vaccination_strategy}. The strategy saved the nodes: {saved_target_nodes}")
Expand Down Expand Up @@ -284,7 +284,7 @@ def non_spreading_dirlaynet_minbudget(Graph:nx.DiGraph, src:int, targets:list)->
logger.info(f"Returning minimum budget: {min_budget}")
return min_budget

def heuristic_maxsave(Graph:nx.DiGraph, budget:int, source:int, targets:list, spreading:bool, flag=None) -> list:
def heuristic_maxsave(Graph:nx.DiGraph, budget:int, source:int, targets:list, spreading=True, stop_condition=None) -> list:
"""
This heuristic approach is based on the local search problem.
We will select the best neighbor that saves the most nodes from targets.
Expand All @@ -307,7 +307,7 @@ def heuristic_maxsave(Graph:nx.DiGraph, budget:int, source:int, targets:list, sp
>>> G = nx.DiGraph()
>>> G.add_nodes_from([0, 1, 2, 3], status="vulnerable")
>>> G.add_edges_from([(0, 1), (0, 2), (1, 2), (1, 3)])
>>> heuristic_maxsave(G, 1, 0, [1, 2, 3], True)
>>> heuristic_maxsave(G, 1, 0, [1, 2, 3])
[(1, 1)]
"""
if budget < 1:
Expand All @@ -322,6 +322,7 @@ def heuristic_maxsave(Graph:nx.DiGraph, budget:int, source:int, targets:list, sp
infected_nodes = []
vaccinated_nodes = []
vaccination_strategy = []
saved_target_nodes = set()
can_spread = True
Graph.nodes[source]['status'] = Status.INFECTED.value
infected_nodes.append(source)
Expand All @@ -343,6 +344,7 @@ def heuristic_maxsave(Graph:nx.DiGraph, budget:int, source:int, targets:list, sp


if nodes_saved is not None:
saved_target_nodes.update(nodes_saved)
targets[:] = [element for element in targets if element not in nodes_saved]
logger.info(f"Updated list of targets: {targets}")

Expand All @@ -351,15 +353,19 @@ def heuristic_maxsave(Graph:nx.DiGraph, budget:int, source:int, targets:list, sp

can_spread = spread_virus(Graph, infected_nodes)

if flag is not None:
if stop_condition is not None:
if len(targets) == 0 or any(node in infected_nodes for node in targets):
logger.info(f"Returning vaccination strategy: {vaccination_strategy}")
logger.info(f"Returning vaccination strategy: {vaccination_strategy}. The strategy saved the nodes: {saved_target_nodes}")
return vaccination_strategy

time_step += 1

for node in targets:
if Graph.nodes[node]['status'] != Status.INFECTED.value:
saved_target_nodes.add(node)

logger.info(f"Returning vaccination strategy: {vaccination_strategy}")
return vaccination_strategy
logger.info(f"Returning vaccination strategy: {vaccination_strategy}. The strategy saved the nodes: {saved_target_nodes}")
return vaccination_strategy, saved_target_nodes

def heuristic_minbudget(Graph:nx.DiGraph, source:int, targets:list, spreading:bool)-> int:
"""
Expand Down
15 changes: 8 additions & 7 deletions networkz/algorithms/approximation/firefighter_problem/Utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,8 @@ def spread_virus(graph:nx.DiGraph, infected_nodes:list)->bool:
graph.nodes[neighbor]['status'] = Status.INFECTED.value
new_infected_nodes.append(neighbor)
logger.debug("SPREAD VIRUS: Node " + f'{neighbor}' + " has been infected from node " + f'{node}')
display_graph(graph)
#display_graph(graph)

infected_nodes.clear()
for node in new_infected_nodes:
infected_nodes.append(node)
Expand Down Expand Up @@ -376,7 +377,7 @@ def spread_vaccination(graph:nx.DiGraph, vaccinated_nodes:list)->None:
graph.nodes[neighbor]['status'] = Status.VACCINATED.value
new_vaccinated_nodes.append(neighbor)
logger.debug("SPREAD VACCINATION: Node " + f'{neighbor}' + " has been vaccinated from node " + f'{node}')
display_graph(graph)
#display_graph(graph)
vaccinated_nodes.clear()
for node in new_vaccinated_nodes:
vaccinated_nodes.append(node)
Expand All @@ -403,7 +404,7 @@ def vaccinate_node(graph:nx.DiGraph, node:int)->None:
"""
graph.nodes[node]['status'] = Status.DIRECTLY_VACCINATED.value
logger.info("Node " + f'{node}' + " has been directly vaccinated")
display_graph(graph)
#display_graph(graph)
return

def clean_graph(graph:nx.DiGraph)->None:
Expand Down Expand Up @@ -710,13 +711,13 @@ def find_best_neighbor(graph:nx.DiGraph, infected_nodes:list, targets:list)->int
if graph.nodes[node]['status'] == Status.VULNERABLE.value:
# for each node that is target, we will add only his nighbors that are target as well
neighbors_list = list(graph.neighbors(node))
target_neighbors = set()
vulnerable_neighbors = set()
for neighbor in neighbors_list:
if graph.nodes[neighbor]['status'] == Status.VULNERABLE.value:
target_neighbors.add(neighbor)
vulnerable_neighbors.add(neighbor)
if node in targets:
target_neighbors.add(node)
common_elements = set(target_neighbors) & set(targets)
vulnerable_neighbors.add(node)
common_elements = set(vulnerable_neighbors) & set(targets)
logger.info("node " + f'{node}' + " is saving the nodes " + str(common_elements))
if len(common_elements) > max_number:
best_node = node
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""
The Paper -
Approximability of the Firefighter Problem Computing Cuts over Time
Paper Link -
https://github.com/The-Firefighters/networkz/blob/master/networkz/algorithms/approximation/firefighter_problem/Approximability_of_the_Firefighter_Problem.pdf
Authors -
Elliot Anshelevich
Deeparnab Chakrabarty
Ameya Hate
Chaitanya Swamy
Developers -
Yuval Bubnovsky
Almog David
Shaked Levi
"""

import pytest
import networkx as nx
import json
import random

from networkz.algorithms.approximation.firefighter_problem.Firefighter_Problem import heuristic_maxsave, spreading_maxsave
from networkz.algorithms.approximation.firefighter_problem.Utils import find_best_neighbor, parse_json_to_networkx, Status

with open("networkz/algorithms/tests/test_firefighter_problem/graphs.json", "r") as file:
json_data = json.load(file)
graphs = parse_json_to_networkx(json_data)

@pytest.mark.parametrize("graph_key, budget, source, targets", [
("RegularGraph_Graph-1", 1, -2, [1, 2, 3, 4, 5, 6]),
("RegularGraph_Graph-4", 1, 8, [1, 2, 4, 6, 7]),
("RegularGraph_Graph-6", 1, 10, [0, 2, 3, 5, 6, 7, 8, 9]),
("RegularGraph_Graph-8", 1, 17, [1, 7, 12, 14, 8, 3, 11, 2]),
("RegularGraph_Graph-3", 1, 6, [1, 3, 5]),
])
def test_source_not_in_graph(graph_key, budget, source, targets):
with pytest.raises(ValueError):
heuristic_maxsave(graphs[graph_key], budget, source, targets)

@pytest.mark.parametrize("graph_key, budget, source, targets", [
("RegularGraph_Graph-2", 1, 0, [1, 2, 3, 9, 5, 16]),
("RegularGraph_Graph-3", 1, 4, [1, 2, 3, 6, 7]),
("RegularGraph_Graph-6", 1, 3, [0, 2, 5, 6, 7, 8, 10]),
("RegularGraph_Graph-8", 1, 11, [1, 3, 12, 19, 8, 10, 4, 2]),
("RegularGraph_Graph-7", 1, 2, [1, 3, -1, 5]),
])
def test_target_not_in_graph(graph_key, budget, source, targets):
with pytest.raises(ValueError):
heuristic_maxsave(graphs[graph_key], budget, source, targets)

@pytest.mark.parametrize("graph_key, budget, source, targets", [
("RegularGraph_Graph-1", 1, 0, [1, 2, 3, 0, 4, 5, 6]),
("RegularGraph_Graph-3", 1, 1, [5, 1, 4]),
("RegularGraph_Graph-4", 1, 4, [1, 2, 3, 4, 5, 6, 7]),
("RegularGraph_Graph-6", 1, 0, [0, 3, 5, 6, 7, 8, 9]),
("RegularGraph_Graph-8", 1, 0, [13, 10, 8, 6, 5, 4, 3, 0, 1, 2]),
])
def test_source_is_target(graph_key, budget, source, targets):
with pytest.raises(ValueError):
heuristic_maxsave(graphs[graph_key], budget, source, targets)

@pytest.mark.parametrize("graph_key, budget, source, targets, expected_length", [
("RegularGraph_Graph-1", 1, 0, [1, 2, 3, 4, 5, 6], 2),
("Dirlay_Graph-5", 2, 0, [1, 2, 3, 4, 5, 6, 7, 8], 3),
])
def test_strategy_length(graph_key, budget, source, targets, expected_length):
graph = graphs[graph_key]
calculated_strategy = spreading_maxsave(graph, budget, source, targets)[0]
print(calculated_strategy)

assert len(calculated_strategy) == expected_length


@pytest.mark.parametrize("graph_key, budget, source, targets, expected_strategy", [
("RegularGraph_Graph-1", 1, 0, [1, 2, 3, 4, 5, 6], [(1, 1), (6, 2)]),
("Dirlay_Graph-5", 2, 0, [1, 2, 3, 4, 5, 6, 7, 8], [(5, 1), (2, 1)]),
])
def test_save_all_vertices(graph_key, budget, source, targets, expected_strategy):
graph = graphs[graph_key]
calculated_strategy = heuristic_maxsave(graph, budget, source, targets)[0]
print(calculated_strategy)

assert calculated_strategy == expected_strategy

@pytest.mark.parametrize("graph_key, budget, source, targets, expected_strategy", [
("RegularGraph_Graph-6", 2, 1, [3, 9, 0, 5, 6], [(2, 1)]),
("RegularGraph_Graph-4", 1, 0, [2, 6, 4], [(1, 1)]),
])
def test_save_subgroup_vertices(graph_key, budget, source, targets, expected_strategy):
graph = graphs[graph_key]
calculated_strategy = heuristic_maxsave(graph, budget, source, targets)[0]
print(calculated_strategy)

assert calculated_strategy == expected_strategy

def test_random_graph_comparison():
for i in range(10):
num_nodes = random.randint(2,100)
nodes = list(range(num_nodes+1))
num_edges = 1000
save_amount = random.randint(1,num_nodes)
targets = []
G = nx.DiGraph()

G.add_nodes_from(nodes, status=Status.VULNERABLE.value)
for _ in range(num_edges):
source = random.randint(0, num_nodes - 1)
target = random.randint(0, num_nodes - 1)
if source != target: # Ensure no self-loops
G.add_edge(source, target)
for node in range(save_amount):
probability = random.random()
if probability < 0.75 and node!=0:
targets.append(node)

print(targets)
spreading_answer = spreading_maxsave(G,1,0,targets)[1]
heuristic_answer = heuristic_maxsave(G,1,0,targets)[1]
print(spreading_answer)
print(heuristic_answer)

assert len(spreading_answer) <= len(heuristic_answer)

print("All tests have passed!")
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def test_save_subgroup_vertices_nodes_list(graph_key, budget, source, targets, e

assert calculated_nodes_saved_list == expected_nodes_saved_list

def random_graph_test():
def test_random_graph():
for i in range(10):
num_nodes = random.randint(2,100)
nodes = list(range(num_nodes+1))
Expand Down

0 comments on commit d5cc0e7

Please sign in to comment.