diff --git a/kronfluence/computer/factor_computer.py b/kronfluence/computer/factor_computer.py index fbe9657..db8a76d 100644 --- a/kronfluence/computer/factor_computer.py +++ b/kronfluence/computer/factor_computer.py @@ -299,6 +299,19 @@ def fit_covariance_matrices( ) self._reset_memory() + + + import gc + for obj in gc.get_objects(): + try: + if torch.is_tensor(obj) or ( + hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device( + "cuda"): + print(type(obj), obj.size()) + except: + pass + + start_time = get_time(state=self.state) with self.profiler.profile("Fit Covariance"): loader = self._get_dataloader( @@ -323,6 +336,18 @@ def fit_covariance_matrices( f"Fitted covariance matrices with {num_data_processed.item()} data points in " f"{elapsed_time:.2f} seconds." ) + self._reset_memory() + + print("Done") + import gc + for obj in gc.get_objects(): + try: + if torch.is_tensor(obj) or ( + hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device( + "cuda"): + print(type(obj), obj.size()) + except: + pass with self.profiler.profile("Save Covariance"): if self.state.is_main_process: diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py index 00da00c..e71b550 100644 --- a/kronfluence/factor/covariance.py +++ b/kronfluence/factor/covariance.py @@ -183,14 +183,6 @@ def fit_covariance_matrices_with_loader( - Number of data points processed. - Computed covariance matrices (nested dict: factor_name -> module_name -> tensor). """ - import gc - for obj in gc.get_objects(): - try: - if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device("cuda"): - print(type(obj), obj.size()) - except: - pass - update_factor_args(model=model, factor_args=factor_args) if tracked_module_names is None: tracked_module_names = get_tracked_module_names(model=model) @@ -243,13 +235,6 @@ def fit_covariance_matrices_with_loader( total_steps += 1 pbar.update(1) - for obj in gc.get_objects(): - try: - if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device("cuda"): - print(type(obj), obj.size()) - except: - pass - if state.use_distributed: synchronize_modules(model=model, tracked_module_names=tracked_module_names) num_data_processed = num_data_processed.to(device=state.device) @@ -276,11 +261,4 @@ def fit_covariance_matrices_with_loader( set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True) state.wait_for_everyone() - for obj in gc.get_objects(): - try: - if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device("cuda"): - print(type(obj), obj.size()) - except: - pass - return num_data_processed, saved_factors