Skip to content

Commit

Permalink
developing spectra 2_0
Browse files Browse the repository at this point in the history
  • Loading branch information
Yasha Ektefaie committed Dec 21, 2024
1 parent 7c02c4b commit 2aa1fd0
Show file tree
Hide file tree
Showing 5 changed files with 498 additions and 257 deletions.
3 changes: 2 additions & 1 deletion spectrae/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .spectra import Spectra
from .dataset import SpectraDataset
from .dataset import SpectraDataset
from .utils import Spectral_Property_Graph, FlattenedAdjacency
36 changes: 19 additions & 17 deletions spectrae/dataset.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,40 @@
from abc import ABC, abstractmethod
from typing import List, Dict

class SpectraDataset(ABC):

def __init__(self, input_file, name):
self.input_file = input_file
self.name = name
self.samples = self.parse(input_file)

@abstractmethod
def sample_to_index(self, idx):
"""
Given a sample, return the data idx
"""
pass

self.sample_to_index = self.parse(input_file)
self.samples = list(self.sample_to_index.keys())
self.samples.sort()

@abstractmethod
def parse(self, input_file):
def parse(self, input_file: str) -> Dict:
"""
Given a dataset file, parse the dataset file.
Make sure there are only unique entries!
Given a dataset file, parse the dataset file to return a dictionary mapping a sample ID to the data
"""
pass
raise NotImplementedError("Must implement parse method to use SpectraDataset, see documentation for more information")

@abstractmethod
def __len__(self):
"""
Return the length of the dataset
"""
pass
return len(self.samples)

@abstractmethod
def __getitem__(self, idx):
"""
Given a dataset idx, return the element at that index
"""
pass
if isinstance(idx, int):
return self.sample_to_index[self.samples[idx]]
return self.sample_to_index[idx]

def index(self, value):
"""
Given a value, return the index of that value
"""
if value not in self.samples:
raise ValueError(f"{value} not in the dataset")
return self.samples.index(value)
150 changes: 79 additions & 71 deletions spectrae/independent_set_algo.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,91 @@
import random
import networkx as nx
from .utils import is_clique, connected_components, is_integer
from scipy import stats
import random
import numpy as np
from tqdm import tqdm
import torch
from .utils import FlattenedAdjacency, Spectral_Property_Graph, cross_split_overlap

def run_independent_set(spectral_parameter, input_G, seed = None,
debug=False, distribution = None, binary = True):
total_deleted = 0
independent_set = []

if seed is not None:
random.seed(seed)
def run_independent_set(spectral_parameter: int,
input_G: Spectral_Property_Graph,
seed: int = 42,
binary: bool = True,
minimum: int = None,
degree_choosing: bool = False,
num_splits: int = None):

G = input_G.copy()

if binary:
#First check if any connected component of the graph is a clique, if so, add it as one unit to the independent set
components = list(connected_components(G))
deleted = 0
for i, component in enumerate(components):
subgraph = G.subgraph(component)
if is_clique(subgraph):
print(f"Component {i} is too densly connected, adding samples as a single unit to independent set and deleting them from the graph")
independent_set.append(list(subgraph.nodes()))
G.remove_nodes_from(subgraph.nodes())
else:
for node in list(subgraph.nodes()):
if subgraph.degree(node) == len(subgraph.nodes()) - 1:
deleted += 1
G.remove_node(node)

print(f"Deleted {deleted} nodes from the graph since they were connected to all other nodes")

iterations = 0
total_num_deleted = 0
independent_set = []
random.seed(seed)

n = input_G.num_nodes()
indices_to_scan = list(range(n))
if spectral_parameter == 0:
return indices_to_scan
pbar = tqdm(total = len(indices_to_scan))

#Trying a non-percentile approach
#Note this assumes there are 20
if not binary:
if num_splits is None:
raise Exception("Num splits must be specified for non-binary graphs, see documentation for more information")
threshold = spectral_parameter*(torch.max(input_G) - torch.min(input_G))/num_splits
else:
threshold = 0
print(f"Threshold is {threshold}")
indices_deleted = []

expected_number_delete = int(n * spectral_parameter)
print(expected_number_delete)

while not nx.is_empty(G):
chosen_node = random.sample(list(G.nodes()), 1)[0]
while len(indices_to_scan) > 0:
print(len(indices_deleted))
indices_deleted = []
if degree_choosing:
chosen_node, _ = input_G.get_minimum_degree_node(indices_to_scan)
else:
chosen_node = random.sample(indices_to_scan, 1)[0]

indices_to_scan.remove(chosen_node)

to_iterate = indices_to_scan[:]

independent_set.append(chosen_node)
neighbors = G.neighbors(chosen_node)
neighbors_to_delete = []
indices_to_gather = []

for index in to_iterate:
indices_to_gather.append((chosen_node, index))

values = input_G.get_weights(indices_to_gather)

indices_deleted.extend(list(torch.tensor(to_iterate).cuda()[values > threshold].cpu().numpy()))

indices_deleted = list(set(indices_deleted))
indices_to_scan = set(indices_to_scan)

for neighbor in neighbors:
if not binary:
if spectral_parameter == 1.0:
neighbors_to_delete.append(neighbor)
else:
edge_weight = G[chosen_node][neighbor]['weight']
if distribution is None:
raise Exception("Distribution must be provided if binary is set to False, must precompute similarities")
if random.random() < spectral_parameter and (1-spectral_parameter)*100 < stats.percentileofscore(distribution, edge_weight):
neighbors_to_delete.append(neighbor)
else:
if spectral_parameter == 1.0:
neighbors_to_delete.append(neighbor)
elif spectral_parameter != 0.0:
if len(indices_deleted) > expected_number_delete:
indices_deleted = [chosen_node]
total_num_deleted += 1
else:
independent_set.append(chosen_node)
for i in indices_deleted:
if binary:
if random.random() < spectral_parameter:
neighbors_to_delete.append(neighbor)
indices_to_scan.remove(i)
total_num_deleted += 1
else:
indices_to_scan.remove(i)
total_num_deleted += 1

if minimum is not None:
if n - total_num_deleted <= minimum - len(independent_set):
independent_set.extend(indices_to_scan)
return independent_set

if debug:
print(f"Iteration {iterations} Stats")
print(f"Deleted {len(neighbors_to_delete)} nodes from {G.degree(chosen_node)} neighbors of node {chosen_node}")
total_deleted += len(neighbors_to_delete)

for neighbor in neighbors_to_delete:
G.remove_node(neighbor)

if chosen_node not in neighbors_to_delete:
G.remove_node(chosen_node)

iterations += 1

for node in list(G.nodes()):
#Append the nodes left to G
independent_set.append(node)
indices_deleted.append(chosen_node)
indices_to_scan = list(indices_to_scan)
pbar.update(len(indices_deleted))

if debug:
print(f"{len(input_G.nodes())} nodes in the original graph")
print(f"Total deleted {total_deleted}")
print(f"{len(independent_set)} nodes in the independent set")

pbar.close()

return independent_set
return independent_set
Loading

0 comments on commit 2aa1fd0

Please sign in to comment.