Skip to content

Commit

Permalink
Add cuda condition
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 5, 2024
1 parent 346d4fd commit 3dde1de
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions kronfluence/factor/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def fit_covariance_matrices_with_loader(
import gc
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device("cuda"):
print(type(obj), obj.size())
except:
pass
Expand Down Expand Up @@ -245,7 +245,7 @@ def fit_covariance_matrices_with_loader(

for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device("cuda"):
print(type(obj), obj.size())
except:
pass
Expand Down Expand Up @@ -278,7 +278,7 @@ def fit_covariance_matrices_with_loader(

for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device("cuda"):
print(type(obj), obj.size())
except:
pass
Expand Down

0 comments on commit 3dde1de

Please sign in to comment.