From 3dde1de393143aa70eb66d852d5c13e8aad152fb Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 5 Jul 2024 02:06:15 -0400 Subject: [PATCH] Add cuda condition --- kronfluence/factor/covariance.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py index 5b216ef..00da00c 100644 --- a/kronfluence/factor/covariance.py +++ b/kronfluence/factor/covariance.py @@ -186,7 +186,7 @@ def fit_covariance_matrices_with_loader( import gc for obj in gc.get_objects(): try: - if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): + if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device("cuda"): print(type(obj), obj.size()) except: pass @@ -245,7 +245,7 @@ def fit_covariance_matrices_with_loader( for obj in gc.get_objects(): try: - if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): + if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device("cuda"): print(type(obj), obj.size()) except: pass @@ -278,7 +278,7 @@ def fit_covariance_matrices_with_loader( for obj in gc.get_objects(): try: - if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): + if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device("cuda"): print(type(obj), obj.size()) except: pass