Skip to content

Commit

Permalink
some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Yasha Ektefaie committed Dec 23, 2024
1 parent 2aa1fd0 commit b5a773d
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 38 deletions.
2 changes: 1 addition & 1 deletion spectrae/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .spectra import Spectra
from .dataset import SpectraDataset
from .utils import Spectral_Property_Graph, FlattenedAdjacency
from .utils import Spectral_Property_Graph, FlattenedAdjacency, output_split_stats
5 changes: 3 additions & 2 deletions spectrae/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def __init__(self, input_file, name):
self.sample_to_index = self.parse(input_file)
self.samples = list(self.sample_to_index.keys())
self.samples.sort()
self.index_map = {value: idx for idx, value in enumerate(self.samples)}

@abstractmethod
def parse(self, input_file: str) -> Dict:
Expand All @@ -35,6 +36,6 @@ def index(self, value):
"""
Given a value, return the index of that value
"""
if value not in self.samples:
if value not in self.index_map:
raise ValueError(f"{value} not in the dataset")
return self.samples.index(value)
return self.index_map[value]
79 changes: 60 additions & 19 deletions spectrae/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,62 @@ def spectra_properties(self, sample_one, sample_two):

def cross_split_overlap(self,
split: List[int],
split_two: Optional[List[int]] = None) -> Tuple[float, float, float]:
split_two: Optional[List[int]] = None,
chunksize: int = 10000000,
show_progress: bool = False) -> Tuple[float, float, float]:

def calculate_overlap(index_to_gather):
if self.SPG.binary:
num_similar = sum(1 for i, j in index_to_gather if self.SPG.get_weight(i, j) > 0)
return num_similar / len(split), num_similar, len(split)
num_similar = 0

if show_progress:
index_to_gather = tqdm(index_to_gather, total = len(split))
else:
index_to_gather = index_to_gather

for compare_list in index_to_gather:
if self.SPG.get_weights(compare_list).sum() > 0:
num_similar += 1

return num_similar/(len(split)), num_similar, len(split)
else:
mean_val = 0.0
std_val = 0.0
max_val = float('-inf')
min_val = float('inf')
count = 0
for i, j in index_to_gather:
weight = self.SPG.get_weight(i, j)
mean_val += weight
std_val += weight ** 2
if weight > max_val:
max_val = weight
if weight < min_val:
min_val = weight
count += 1

if count > 100000000:
break

mean_val /= count
std_val = (std_val / count - mean_val ** 2) ** 0.5
return mean_val, std_val, max_val, min_val

def generate_indices(split, split_two):
if split_two is not None:
for i in range(len(split)):
to_compare = []
for j in range(len(split_two)):
to_compare.append((split[i], split_two[j]))
yield to_compare
else:
if len(index_to_gather) > 100000000:
values = self.SPG.get_weights(index_to_gather)
return torch.mean(values).item(), torch.std(values).item(), torch.max(values).item(), torch.min(values).item()
index_to_gather = torch.tensor(index_to_gather).cuda()
values = self.SPG.get_weights(index_to_gather)
return torch.mean(values).item(), torch.std(values).item(), torch.max(values).item(), torch.min(values).item()
for i in range(len(split)):
to_compare = []
for j in range(i+1, len(split)):
to_compare.append((split[i], split[j]))
yield to_compare

if split_two is None:
index_to_gather = [(split[i], split[j]) for i in range(len(split)) for j in range(i + 1, len(split))]
else:
index_to_gather = [(split[i], split_two[j]) for i in range(len(split)) for j in range(len(split_two))]
index_to_gather = generate_indices(split, split_two)

return calculate_overlap(index_to_gather)

Expand Down Expand Up @@ -106,11 +144,11 @@ def generate_spectra_split(self,
stats = self.get_stats(train, test, spectral_parameter)
return train, test, stats

def get_stats(self, train, test, spectral_parameter):
def get_stats(self, train, test, spectral_parameter, chunksize = 10000000, show_progress = False):
train_size = len(train)
test_size = len(test)
if not self.binary:
cross_split_overlap, std_css, max_css, min_css = self.cross_split_overlap(self.get_sample_indices(train), self.get_sample_indices(test))
cross_split_overlap, std_css, max_css, min_css = self.cross_split_overlap(self.get_sample_indices(train), self.get_sample_indices(test), chunksize, show_progress)
stats = {'SPECTRA_parameter': spectral_parameter,
'train_size': train_size,
'test_size': test_size,
Expand All @@ -119,7 +157,7 @@ def get_stats(self, train, test, spectral_parameter):
'max_css': max_css,
'min_css': min_css}
else:
cross_split_overlap, num_similar, num_total = self.cross_split_overlap(self.get_sample_indices(train))
cross_split_overlap, num_similar, num_total = self.cross_split_overlap(self.get_sample_indices(train), self.get_sample_indices(test), chunksize, show_progress)
stats = {'SPECTRA_parameter': spectral_parameter,
'train_size': train_size,
'test_size': test_size,
Expand Down Expand Up @@ -171,7 +209,9 @@ def generate_spectra_splits(self,

def return_split_stats(self, spectral_parameter: float,
number: int,
path_to_save: str = None):
path_to_save: str = None,
chunksize: int = 10000000,
show_progress: bool = False) -> Dict:

if path_to_save is None:
path_to_save = f"{self.dataset.name}_SPECTRA_splits"
Expand All @@ -183,7 +223,7 @@ def return_split_stats(self, spectral_parameter: float,
if not os.path.exists(f"{split_folder}/stats.pkl"):
train = pickle.load(open(f"{split_folder}/train.pkl", "rb"))
test = pickle.load(open(f"{split_folder}/test.pkl", "rb"))
stats = self.get_stats(train, test, spectral_parameter)
stats = self.get_stats(train, test, spectral_parameter, chunksize, show_progress)
pickle.dump(stats, open(f"{split_folder}/stats.pkl", "wb"))
return stats

Expand All @@ -206,6 +246,7 @@ def return_split_samples(self, spectral_parameter: float,

def return_all_split_stats(self,
path_to_save: str = None,
chunksize: int = 10000000,
show_progress: bool = False) -> Dict:

if path_to_save is None:
Expand All @@ -225,7 +266,7 @@ def return_all_split_stats(self,
for folder in to_iterate:
spectral_parameter = folder.split('_')[1]
number = folder.split('_')[2]
res = self.return_split_stats(spectral_parameter, number)
res = self.return_split_stats(spectral_parameter, number, chunksize=chunksize, show_progress=True)
SP.append(float(spectral_parameter))
numbers.append(int(number))
train_size.append(int(res['train_size']))
Expand Down
33 changes: 17 additions & 16 deletions spectrae/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Tuple, Dict, Union, Optional
import os
import matplotlib.pyplot as plt
import pickle

class FlattenedAdjacency:
def __init__(self,
Expand Down Expand Up @@ -183,17 +184,14 @@ def cross_split_overlap(split, g):
values = g.get_weights(index_to_gather)
return torch.mean(values).item(), torch.std(values).item(), torch.max(values).item(), torch.min(values).item()

def output_split_stats(split_directory, g):
spectral_parameter = []
length = []
css = []
def output_split_stats(stats_file: str, name: str = None):
with open(stats_file, 'rb') as f:
stats = pickle.load(f)

for split_file in tqdm(os.listdir(split_directory)):
x = np.load(f'{split_directory}/{split_file}')
sp = split_file.split('_')[0]
spectral_parameter.append(sp)
length.append(len(x))
css.append(cross_split_overlap(x, g)[0])
spectral_parameter = stats['SPECTRA_parameter']
train_length = stats['train_size']
test_length = stats['test_size']
css = stats['cross_split_overlap']

# Convert spectral_parameter to a numeric type if necessary
spectral_parameter = list(map(float, spectral_parameter))
Expand All @@ -202,10 +200,13 @@ def output_split_stats(split_directory, g):
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

# Dataset size vs Spectral parameter
ax1.scatter(spectral_parameter, length, color='blue')
ax1.set_title('Dataset Size vs Spectral Parameter')
ax1.scatter(spectral_parameter, train_length, color='blue', label='Train')
ax1.scatter(spectral_parameter, test_length, color='green', label='Test')
ax1.set_title('Train and test set size vs Spectral Parameter')
ax1.set_xlabel('Spectral Parameter')
ax1.set_ylabel('Dataset Size')
ax1.legend()


# Cross split overlap vs Spectral parameter
ax2.scatter(spectral_parameter, css, color='red')
Expand All @@ -215,8 +216,8 @@ def output_split_stats(split_directory, g):

# Adjust layout and save the plot
plt.tight_layout()
plt.savefig('split_stats.png')
if name is not None:
plt.savefig(f'{name}.png')
else:
plt.savefig('split_stats.png')
plt.show()


return spectral_parameter, length, css

0 comments on commit b5a773d

Please sign in to comment.