diff --git a/spectrae/__init__.py b/spectrae/__init__.py index 4f1dbd8..1e7031f 100644 --- a/spectrae/__init__.py +++ b/spectrae/__init__.py @@ -1,3 +1,3 @@ from .spectra import Spectra, Spectra_Property_Graph_Constructor from .dataset import SpectraDataset -from .utils import Spectral_Property_Graph, FlattenedAdjacency \ No newline at end of file +from .utils import Spectral_Property_Graph, FlattenedAdjacency, output_split_stats diff --git a/spectrae/dataset.py b/spectrae/dataset.py index 1479f8d..03edba7 100644 --- a/spectrae/dataset.py +++ b/spectrae/dataset.py @@ -9,6 +9,7 @@ def __init__(self, input_file, name): self.sample_to_index = self.parse(input_file) self.samples = list(self.sample_to_index.keys()) self.samples.sort() + self.index_map = {value: idx for idx, value in enumerate(self.samples)} @abstractmethod def parse(self, input_file: str) -> Dict: @@ -35,6 +36,6 @@ def index(self, value): """ Given a value, return the index of that value """ - if value not in self.samples: + if value not in self.index_map: raise ValueError(f"{value} not in the dataset") - return self.samples.index(value) \ No newline at end of file + return self.index_map[value] \ No newline at end of file diff --git a/spectrae/spectra.py b/spectrae/spectra.py index 40f0330..913e0e1 100644 --- a/spectrae/spectra.py +++ b/spectrae/spectra.py @@ -35,24 +35,62 @@ def spectra_properties(self, sample_one, sample_two): def cross_split_overlap(self, split: List[int], - split_two: Optional[List[int]] = None) -> Tuple[float, float, float]: + split_two: Optional[List[int]] = None, + chunksize: int = 10000000, + show_progress: bool = False) -> Tuple[float, float, float]: def calculate_overlap(index_to_gather): if self.SPG.binary: - num_similar = sum(1 for i, j in index_to_gather if self.SPG.get_weight(i, j) > 0) - return num_similar / len(split), num_similar, len(split) + num_similar = 0 + + if show_progress: + index_to_gather = tqdm(index_to_gather, total = len(split)) + else: + index_to_gather = index_to_gather + + for compare_list in index_to_gather: + if self.SPG.get_weights(compare_list).sum() > 0: + num_similar += 1 + + return num_similar/(len(split)), num_similar, len(split) + else: + mean_val = 0.0 + std_val = 0.0 + max_val = float('-inf') + min_val = float('inf') + count = 0 + for i, j in index_to_gather: + weight = self.SPG.get_weight(i, j) + mean_val += weight + std_val += weight ** 2 + if weight > max_val: + max_val = weight + if weight < min_val: + min_val = weight + count += 1 + + if count > 100000000: + break + + mean_val /= count + std_val = (std_val / count - mean_val ** 2) ** 0.5 + return mean_val, std_val, max_val, min_val + + def generate_indices(split, split_two): + if split_two is not None: + for i in range(len(split)): + to_compare = [] + for j in range(len(split_two)): + to_compare.append((split[i], split_two[j])) + yield to_compare else: - if len(index_to_gather) > 100000000: - values = self.SPG.get_weights(index_to_gather) - return torch.mean(values).item(), torch.std(values).item(), torch.max(values).item(), torch.min(values).item() - index_to_gather = torch.tensor(index_to_gather).cuda() - values = self.SPG.get_weights(index_to_gather) - return torch.mean(values).item(), torch.std(values).item(), torch.max(values).item(), torch.min(values).item() + for i in range(len(split)): + to_compare = [] + for j in range(i+1, len(split)): + to_compare.append((split[i], split[j])) + yield to_compare - if split_two is None: - index_to_gather = [(split[i], split[j]) for i in range(len(split)) for j in range(i + 1, len(split))] - else: - index_to_gather = [(split[i], split_two[j]) for i in range(len(split)) for j in range(len(split_two))] + index_to_gather = generate_indices(split, split_two) return calculate_overlap(index_to_gather) @@ -107,11 +145,11 @@ def generate_spectra_split(self, stats = self.get_stats(train, test, spectral_parameter) return train, test, stats - def get_stats(self, train, test, spectral_parameter): + def get_stats(self, train, test, spectral_parameter, chunksize = 10000000, show_progress = False): train_size = len(train) test_size = len(test) if not self.binary: - cross_split_overlap, std_css, max_css, min_css = self.cross_split_overlap(self.get_sample_indices(train), self.get_sample_indices(test)) + cross_split_overlap, std_css, max_css, min_css = self.cross_split_overlap(self.get_sample_indices(train), self.get_sample_indices(test), chunksize, show_progress) stats = {'SPECTRA_parameter': spectral_parameter, 'train_size': train_size, 'test_size': test_size, @@ -120,7 +158,7 @@ def get_stats(self, train, test, spectral_parameter): 'max_css': max_css, 'min_css': min_css} else: - cross_split_overlap, num_similar, num_total = self.cross_split_overlap(self.get_sample_indices(train)) + cross_split_overlap, num_similar, num_total = self.cross_split_overlap(self.get_sample_indices(train), self.get_sample_indices(test), chunksize, show_progress) stats = {'SPECTRA_parameter': spectral_parameter, 'train_size': train_size, 'test_size': test_size, @@ -172,7 +210,9 @@ def generate_spectra_splits(self, def return_split_stats(self, spectral_parameter: float, number: int, - path_to_save: str = None): + path_to_save: str = None, + chunksize: int = 10000000, + show_progress: bool = False) -> Dict: if path_to_save is None: path_to_save = f"{self.dataset.name}_SPECTRA_splits" @@ -184,7 +224,7 @@ def return_split_stats(self, spectral_parameter: float, if not os.path.exists(f"{split_folder}/stats.pkl"): train = pickle.load(open(f"{split_folder}/train.pkl", "rb")) test = pickle.load(open(f"{split_folder}/test.pkl", "rb")) - stats = self.get_stats(train, test, spectral_parameter) + stats = self.get_stats(train, test, spectral_parameter, chunksize, show_progress) pickle.dump(stats, open(f"{split_folder}/stats.pkl", "wb")) return stats @@ -207,6 +247,7 @@ def return_split_samples(self, spectral_parameter: float, def return_all_split_stats(self, path_to_save: str = None, + chunksize: int = 10000000, show_progress: bool = False) -> Dict: if path_to_save is None: @@ -226,7 +267,7 @@ def return_all_split_stats(self, for folder in to_iterate: spectral_parameter = folder.split('_')[1] number = folder.split('_')[2] - res = self.return_split_stats(spectral_parameter, number) + res = self.return_split_stats(spectral_parameter, number, chunksize=chunksize, show_progress=True) SP.append(float(spectral_parameter)) numbers.append(int(number)) train_size.append(int(res['train_size'])) diff --git a/spectrae/utils.py b/spectrae/utils.py index 67daeb4..e78ead3 100644 --- a/spectrae/utils.py +++ b/spectrae/utils.py @@ -4,6 +4,7 @@ from typing import List, Tuple, Dict, Union, Optional import os import matplotlib.pyplot as plt +import pickle class FlattenedAdjacency: def __init__(self, @@ -183,17 +184,14 @@ def cross_split_overlap(split, g): values = g.get_weights(index_to_gather) return torch.mean(values).item(), torch.std(values).item(), torch.max(values).item(), torch.min(values).item() -def output_split_stats(split_directory, g): - spectral_parameter = [] - length = [] - css = [] +def output_split_stats(stats_file: str, name: str = None): + with open(stats_file, 'rb') as f: + stats = pickle.load(f) - for split_file in tqdm(os.listdir(split_directory)): - x = np.load(f'{split_directory}/{split_file}') - sp = split_file.split('_')[0] - spectral_parameter.append(sp) - length.append(len(x)) - css.append(cross_split_overlap(x, g)[0]) + spectral_parameter = stats['SPECTRA_parameter'] + train_length = stats['train_size'] + test_length = stats['test_size'] + css = stats['cross_split_overlap'] # Convert spectral_parameter to a numeric type if necessary spectral_parameter = list(map(float, spectral_parameter)) @@ -202,10 +200,13 @@ def output_split_stats(split_directory, g): fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8)) # Dataset size vs Spectral parameter - ax1.scatter(spectral_parameter, length, color='blue') - ax1.set_title('Dataset Size vs Spectral Parameter') + ax1.scatter(spectral_parameter, train_length, color='blue', label='Train') + ax1.scatter(spectral_parameter, test_length, color='green', label='Test') + ax1.set_title('Train and test set size vs Spectral Parameter') ax1.set_xlabel('Spectral Parameter') ax1.set_ylabel('Dataset Size') + ax1.legend() + # Cross split overlap vs Spectral parameter ax2.scatter(spectral_parameter, css, color='red') @@ -215,8 +216,8 @@ def output_split_stats(split_directory, g): # Adjust layout and save the plot plt.tight_layout() - plt.savefig('split_stats.png') + if name is not None: + plt.savefig(f'{name}.png') + else: + plt.savefig('split_stats.png') plt.show() - - - return spectral_parameter, length, css