Memory cleanup
pomonam committed Jul 5, 2024
1 parent c78c092 commit c135ecd
Showing 5 changed files with 12 additions and 35 deletions.
30 changes: 1 addition & 29 deletions kronfluence/computer/factor_computer.py
@@ -29,7 +29,6 @@
from kronfluence.utils.dataset import DataLoaderKwargs, find_executable_batch_size
from kronfluence.utils.exceptions import FactorsNotFoundError
from kronfluence.utils.logger import get_time
- from kronfluence.utils.state import release_memory


class FactorComputer(Computer):
@@ -298,20 +297,6 @@ def fit_covariance_matrices(
total_data_examples=max_partition_examples,
)

- self._reset_memory()


- import gc
- for obj in gc.get_objects():
- try:
- if torch.is_tensor(obj) or (
- hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device(
- "cuda"):
- print(type(obj), obj.size(), obj.device)
- except:
- pass


start_time = get_time(state=self.state)
with self.profiler.profile("Fit Covariance"):
loader = self._get_dataloader(
@@ -337,17 +322,6 @@ def fit_covariance_matrices(
f"{elapsed_time:.2f} seconds."
)

print("Done")
import gc
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (
hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device(
"cuda"):
print(type(obj), obj.size(), obj.device)
except:
pass

with self.profiler.profile("Save Covariance"):
if self.state.is_main_process:
save_covariance_matrices(
@@ -456,7 +430,7 @@ def perform_eigendecomposition(
covariance_factors = load_covariance_matrices(output_dir=load_factors_output_dir)

if load_from_factors_name is not None and self.state.is_main_process:
- # Saves the loaded covariances to the current factor output directory.
+ # Save the loaded covariances to the current factor output directory.
with self.profiler.profile("Save Covariance"):
save_covariance_matrices(output_dir=factors_output_dir, factors=covariance_factors)
loaded_factor_args = self.load_factor_args(factors_name=load_from_factors_name)
@@ -470,7 +444,6 @@ def perform_eigendecomposition(

eigen_factors = None
if self.state.is_main_process:
- self._reset_memory()
start_time = time.time()
with self.profiler.profile("Perform Eigendecomposition"):
eigen_factors = perform_eigendecomposition(
@@ -672,7 +645,6 @@ def fit_lambda_matrices(
total_data_examples=max_partition_examples,
)

- release_memory()
start_time = get_time(state=self.state)
with self.profiler.profile("Fit Lambda"):
loader = self._get_dataloader(
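The gc-scanning blocks deleted above were ad-hoc debugging aids: they walk `gc.get_objects()` and print every CUDA tensor that is still reachable, a common way to hunt for lingering GPU allocations. As written, they also had an operator-precedence quirk (`a or b and c` groups as `a or (b and c)`, so the device check never applied to plain tensors) and a bare `except`. A standalone, corrected sketch of the same idea (illustrative only, not part of the repository):

```python
import gc

import torch


def list_cuda_tensors() -> None:
    """Print every CUDA tensor the garbage collector can still reach."""
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                tensor = obj
            elif hasattr(obj, "data") and torch.is_tensor(obj.data):
                tensor = obj.data
            else:
                continue
            if tensor.device.type == "cuda":
                print(type(obj), tensor.size(), tensor.device)
        except Exception:
            # Some tracked objects raise when their attributes are inspected; skip them.
            pass
```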
4 changes: 3 additions & 1 deletion kronfluence/factor/covariance.py
@@ -29,7 +29,7 @@
PARTITION_TYPE,
)
from kronfluence.utils.logger import TQDM_BAR_FORMAT
- from kronfluence.utils.state import State, no_sync
+ from kronfluence.utils.state import State, no_sync, release_memory


def covariance_matrices_save_path(
@@ -192,6 +192,7 @@ def fit_covariance_matrices_with_loader(
mode=ModuleMode.COVARIANCE,
release_memory=True,
)
+ release_memory()

total_steps = 0
num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False)
@@ -259,6 +260,7 @@ def fit_covariance_matrices_with_loader(
if enable_amp:
set_gradient_scale(model=model, gradient_scale=1.0)
set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
+ release_memory()
state.wait_for_everyone()

return num_data_processed, saved_factors
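The `release_memory()` calls added in this file come from `kronfluence.utils.state`, whose implementation is not part of this diff. A minimal sketch of what such a helper typically does (an assumption, not the repository's exact code):

```python
import gc

import torch


def release_memory() -> None:
    # Drop unreferenced Python objects first, then return cached CUDA blocks
    # from PyTorch's caching allocator to the driver. Neither step frees
    # tensors that are still referenced elsewhere.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```

Under that assumption, calling it right after `set_mode(..., release_memory=True)` returns the just-freed blocks to the driver instead of leaving them reserved by the caching allocator.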
7 changes: 5 additions & 2 deletions kronfluence/factor/eigen.py
@@ -221,6 +221,8 @@ def perform_eigendecomposition(

pbar.update(1)

+ release_memory()

return eigen_factors


@@ -389,8 +391,8 @@ def fit_lambda_matrices_with_loader(
)
if eigen_factors is not None:
for name in eigen_factors:
- set_factors(model=model, factor_name=name, factors=eigen_factors[name], clone=True)
- del eigen_factors
+ set_factors(model=model, factor_name=name, factors=eigen_factors[name])
+ release_memory()

total_steps = 0
num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False)
Expand Down Expand Up @@ -457,5 +459,6 @@ def fit_lambda_matrices_with_loader(
set_gradient_scale(model=model, gradient_scale=1.0)
set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
state.wait_for_everyone()
+ release_memory()

return num_data_processed, saved_factors
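Dropping `clone=True` (together with the explicit `del`) hands the loaded eigen factors to the modules by reference instead of duplicating them first, so only one copy of each factor stays alive; the same change appears in score/pairwise.py below. The general memory cost of cloning versus sharing can be seen with plain tensors — this sketch is illustrative only and does not reproduce `set_factors` itself:

```python
import torch

factors = torch.randn(4096, 4096)

shared = factors          # same storage: no additional memory is allocated
copied = factors.clone()  # a second 4096 x 4096 buffer is allocated

print(shared.data_ptr() == factors.data_ptr())  # True: one underlying buffer
print(copied.data_ptr() == factors.data_ptr())  # False: an independent copy
```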
2 changes: 1 addition & 1 deletion kronfluence/module/utils.py
@@ -178,7 +178,7 @@ def load_factors(
factor = module.get_factor(factor_name=factor_name)
if factor is not None:
if cpu:
- loaded_factors[module.name] = factor.cpu()
+ loaded_factors[module.name] = factor.to(device="cpu", copy=True)
module.release_factor(factor_name=factor_name)
else:
loaded_factors[module.name] = factor
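The `load_factors` change matters when a factor already lives on the CPU: `Tensor.cpu()` returns the original tensor in that case (no copy), while `.to(device="cpu", copy=True)` always materializes an independent tensor, presumably so the returned dictionary holds its own copy rather than a reference into module state that `release_factor` is about to drop. A small illustration of the two behaviours:

```python
import torch

factor = torch.randn(8, 8)  # already on the CPU

alias = factor.cpu()                              # no copy: the same tensor comes back
independent = factor.to(device="cpu", copy=True)  # always allocates a new tensor

print(alias is factor)                              # True
print(independent is factor)                        # False
print(independent.data_ptr() == factor.data_ptr())  # False: separate storage
```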
4 changes: 2 additions & 2 deletions kronfluence/score/pairwise.py
@@ -319,9 +319,9 @@ def compute_pairwise_query_aggregated_scores_with_loaders(
)
if len(loaded_factors) > 0:
for name in loaded_factors:
- set_factors(model=model, factor_name=name, factors=loaded_factors[name], clone=True)
- del loaded_factors
+ set_factors(model=model, factor_name=name, factors=loaded_factors[name])
prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device)
+ release_memory()

enable_amp = score_args.amp_dtype is not None
scaler = GradScaler(enabled=enable_amp)
