Skip to content

Commit

Permalink
udpate
Browse files Browse the repository at this point in the history
  • Loading branch information
AllenWrong committed Mar 26, 2024
1 parent be24e24 commit 63b2af6
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import random
import json
import json
from typing import List


def get_parser():
Expand Down Expand Up @@ -81,3 +82,24 @@ def setup_seed(seed):
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True


def tensor_to_device(tensor_x, device):
def _list_to_device(tensor_x: List[torch.Tensor], device):
for i in range(len(tensor_x)):
tensor_x[i] = tensor_x[i].to(device)

def _dict_to_device(fea_dict: dict, device):
for k, v in fea_dict.items():
if isinstance(v, torch.Tensor):
fea_dict[k] = v.to(device)

if isinstance(tensor_x, list):
_list_to_device(tensor_x, device)
elif isinstance(tensor_x, dict):
_dict_to_device(tensor_x, device)
elif isinstance(tensor_x, torch.Tensor):
return tensor_x.to(device)
else:
raise ValueError(f"unsupported type {type(tensor_x)}")
return tensor_x

0 comments on commit 63b2af6

Please sign in to comment.