diff --git a/spectrae/spectra.py b/spectrae/spectra.py index 913e0e1..375b8be 100644 --- a/spectrae/spectra.py +++ b/spectrae/spectra.py @@ -129,7 +129,8 @@ def generate_spectra_split(self, random_seed: int = 42, test_size: float = 0.2, degree_choosing: bool = False, - minimum: int = None): + minimum: int = None, + path_to_save: str = None): print(f"Generating SPECTRA split for spectral parameter {spectral_parameter} and dataset {self.dataset.name}") result = run_independent_set(spectral_parameter, self.SPG, @@ -143,13 +144,67 @@ def generate_spectra_split(self, print(f"Number of samples in independent set: {len(result)}") train, test = self.spectra_train_test_split(result, test_size=test_size, random_state=random_seed) stats = self.get_stats(train, test, spectral_parameter) - return train, test, stats + if path_to_save is None: + return train, test, stats + else: + i = 0 + if not os.path.exists(f"{path_to_save}/SP_{spectral_parameter}_{i}"): + os.makedirs(f"{path_to_save}/SP_{spectral_parameter}_{i}") + + pickle.dump(train, open(f"{path_to_save}/SP_{spectral_parameter}_{i}/train.pkl", "wb")) + pickle.dump(test, open(f"{path_to_save}/SP_{spectral_parameter}_{i}/test.pkl", "wb")) + pickle.dump(stats, open(f"{path_to_save}/SP_{spectral_parameter}_{i}/stats.pkl", "wb")) - def get_stats(self, train, test, spectral_parameter, chunksize = 10000000, show_progress = False): + def get_stats(self, train: List, + test: List, + spectral_parameter: float, + chunksize: int = 10000000, + show_progress: bool = False, + sample_values: bool = False): + + """ + Computes statistics for the given train and test splits. + + Args: + train (List): A list of training sample IDs or sample indices. (see sample_values) + test (List): A list of test sample IDs or sample indices. (see sample_values) + spectral_parameter (float): The spectral parameter used for computation. + chunksize (int, optional): The size of chunks to process at a time. Default is 10,000,000. Decrease if you get a OOM error. + show_progress (bool, optional): Whether to show progress during computation. Default is False. + sample_values (bool, optional): True if you are passing sample IDs, False if you are passing sample indices. Default is False. + + Returns: + Dict[str, Any]: A dictionary containing the computed statistics. The keys and values depend on whether the data is binary or not. + If not binary: + - 'SPECTRA_parameter' (float): The spectral parameter used. + - 'train_size' (int): The size of the training set. + - 'test_size' (int): The size of the testing set. + - 'cross_split_overlap' (float): The cross-split overlap value. + - 'std_css' (float): The standard deviation of the cross-split similarity. + - 'max_css' (float): The maximum cross-split similarity. + - 'min_css' (float): The minimum cross-split similarity. + If binary: + - 'SPECTRA_parameter' (float): The spectral parameter used. + - 'train_size' (int): The size of the training set. + - 'test_size' (int): The size of the testing set. + - 'cross_split_overlap' (float): The cross-split overlap value. + - 'num_similar' (int): The number of similar items. + - 'num_total' (int): The total number of items. + + Raises: + ValueError: If the train or test lists are empty. + + """ + train_size = len(train) test_size = len(test) + + if sample_values: + train = self.get_sample_indices(train) + test = self.get_sample_indices(test) + if not self.binary: - cross_split_overlap, std_css, max_css, min_css = self.cross_split_overlap(self.get_sample_indices(train), self.get_sample_indices(test), chunksize, show_progress) + cross_split_overlap, std_css, max_css, min_css = self.cross_split_overlap(train, test, chunksize, show_progress) stats = {'SPECTRA_parameter': spectral_parameter, 'train_size': train_size, 'test_size': test_size, @@ -158,7 +213,7 @@ def get_stats(self, train, test, spectral_parameter, chunksize = 10000000, show_ 'max_css': max_css, 'min_css': min_css} else: - cross_split_overlap, num_similar, num_total = self.cross_split_overlap(self.get_sample_indices(train), self.get_sample_indices(test), chunksize, show_progress) + cross_split_overlap, num_similar, num_total = self.cross_split_overlap(train, test, chunksize, show_progress) stats = {'SPECTRA_parameter': spectral_parameter, 'train_size': train_size, 'test_size': test_size,