From 63b2af6b4f83e09fcc6ea6b992cca67c3bec3b98 Mon Sep 17 00:00:00 2001 From: AllenWrong <884691896@qq.com> Date: Tue, 26 Mar 2024 13:50:10 +0800 Subject: [PATCH] udpate --- utils.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/utils.py b/utils.py index c662d5c..3e680d5 100644 --- a/utils.py +++ b/utils.py @@ -6,6 +6,7 @@ import random import json import json +from typing import List def get_parser(): @@ -81,3 +82,24 @@ def setup_seed(seed): np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True + + +def tensor_to_device(tensor_x, device): + def _list_to_device(tensor_x: List[torch.Tensor], device): + for i in range(len(tensor_x)): + tensor_x[i] = tensor_x[i].to(device) + + def _dict_to_device(fea_dict: dict, device): + for k, v in fea_dict.items(): + if isinstance(v, torch.Tensor): + fea_dict[k] = v.to(device) + + if isinstance(tensor_x, list): + _list_to_device(tensor_x, device) + elif isinstance(tensor_x, dict): + _dict_to_device(tensor_x, device) + elif isinstance(tensor_x, torch.Tensor): + return tensor_x.to(device) + else: + raise ValueError(f"unsupported type {type(tensor_x)}") + return tensor_x