Skip to content

Commit

Permalink
Call torch.cuda.empty_cache to release device memory (#114663)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#114663
Approved by: https://github.com/eellison

Reviewed By: osalpekar

Differential Revision: D52062377

Pulled By: desertfire

fbshipit-source-id: a4d6fdc530046d8d7c2ab259398f6deb1daa3ded
  • Loading branch information
desertfire authored and facebook-github-bot committed Dec 12, 2023
1 parent d01dc3a commit cd3cfb9
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1942,8 +1942,6 @@ def validate_model(self, model, example_inputs):
raise RuntimeError("Eager run failed") from e

def maybe_cast(self, model, example_inputs):
model = self.deepcopy_model(model)
example_inputs = clone_inputs(example_inputs)
model, example_inputs = self.cast_based_on_args(model, example_inputs)
return model, example_inputs

Expand Down Expand Up @@ -2145,7 +2143,8 @@ def record_status(accuracy_status, dynamo_start_stats):
self.args.cosine = True
fp64_outputs = None
finally:
del model_fp64
del model_fp64, inputs_fp64
torch.cuda.empty_cache()

tolerance, cos_similarity = self.get_tolerance_and_cosine_flag(
self.args.training, current_device, name
Expand Down Expand Up @@ -2174,6 +2173,7 @@ def record_status(accuracy_status, dynamo_start_stats):
return record_status(accuracy_status, dynamo_start_stats=start_stats)
finally:
del model_copy
torch.cuda.empty_cache()

# Rerun native pytorch
reset_rng_state()
Expand All @@ -2194,6 +2194,7 @@ def record_status(accuracy_status, dynamo_start_stats):
return record_status(accuracy_status, dynamo_start_stats=start_stats)
finally:
del model_copy
torch.cuda.empty_cache()

# Two eager runs should have exactly same result
is_same = True
Expand Down

0 comments on commit cd3cfb9

Please sign in to comment.