diff --git a/examples/igbh/dataset.py b/examples/igbh/dataset.py
index a0d64da1..ad61d88a 100644
--- a/examples/igbh/dataset.py
+++ b/examples/igbh/dataset.py
@@ -17,8 +17,6 @@
 import torch
 import os.path as osp
 
-import graphlearn_torch as glt
-
 from torch_geometric.utils import add_self_loops, remove_self_loops
 from download import download_dataset
 from typing import Literal
diff --git a/examples/igbh/dist_train_rgnn.py b/examples/igbh/dist_train_rgnn.py
index d1a43245..3e21581a 100644
--- a/examples/igbh/dist_train_rgnn.py
+++ b/examples/igbh/dist_train_rgnn.py
@@ -29,9 +29,6 @@
 from rgnn import RGNN
 
-torch.manual_seed(42)
-
-
 def evaluate(model, dataloader, current_device, use_fp16):
   predictions = []
   labels = []
@@ -59,7 +56,7 @@ def evaluate(model, dataloader, current_device, use_fp16):
 def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
     split_training_sampling, hidden_channels, num_classes, num_layers,
     model_type, num_heads, fan_out,
-    epochs, batch_size, learning_rate, log_every,
+    epochs, batch_size, learning_rate, log_every, random_seed,
     dataset, train_idx, val_idx,
     master_addr,
     training_pg_master_port,
@@ -67,6 +64,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
     val_loader_master_port,
     with_gpu, trim_to_layer, use_fp16,
     edge_dir, rpc_timeout):
+  glt.utils.common.seed_everything(random_seed)
   # Initialize graphlearn_torch distributed worker group context.
   glt.distributed.init_worker_group(
     world_size=num_nodes*num_training_procs,
@@ -107,6 +105,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
     edge_dir=edge_dir,
     collect_features=True,
     to_device=current_device,
+    random_seed=random_seed,
     worker_options = glt.distributed.MpDistSamplingWorkerOptions(
       num_workers=1,
       worker_devices=sampling_device,
@@ -130,6 +129,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
     edge_dir=edge_dir,
     collect_features=True,
     to_device=current_device,
+    random_seed=random_seed,
     worker_options = glt.distributed.MpDistSamplingWorkerOptions(
       num_workers=1,
       worker_devices=sampling_device,
@@ -261,6 +261,7 @@
   parser.add_argument('--num_layers', type=int, default=3)
   parser.add_argument('--num_heads', type=int, default=4)
   parser.add_argument('--log_every', type=int, default=2)
+  parser.add_argument('--random_seed', type=int, default=42)
   # Distributed settings.
   parser.add_argument("--num_nodes", type=int, default=2,
     help="Number of distributed nodes.")
@@ -292,6 +293,7 @@
     help="load node/edge feature using fp16 format to reduce memory usage")
   args = parser.parse_args()
   assert args.layout in ['COO', 'CSC', 'CSR']
+  glt.utils.common.seed_everything(args.random_seed)
   # when set --cpu_mode or GPU is not available, use cpu only mode.
   args.with_gpu = (not args.cpu_mode) and torch.cuda.is_available()
   if args.with_gpu:
@@ -324,7 +326,7 @@
       run_training_proc,
       args=(args.num_nodes, args.node_rank, args.num_training_procs,
             args.split_training_sampling, args.hidden_channels, args.num_classes,
            args.num_layers, args.model, args.num_heads, args.fan_out,
-           args.epochs, args.batch_size, args.learning_rate, args.log_every,
+           args.epochs, args.batch_size, args.learning_rate, args.log_every, args.random_seed,
            dataset, train_idx, val_idx,
            args.master_addr, args.training_pg_master_port,
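Net effect for dist_train_rgnn.py: the launcher seeds once up front, every spawned training process re-seeds itself, and the seed rides into both distributed loaders. A minimal sketch of the resulting call pattern (argument values are illustrative; dataset, train_idx and current_device are assumed to be built as in the script):

import graphlearn_torch as glt

glt.utils.common.seed_everything(42)  # seeds Python random, numpy and torch (CPU + CUDA)

train_loader = glt.distributed.DistNeighborLoader(
  data=dataset,                       # a glt.distributed.DistDataset built elsewhere
  num_neighbors=[15, 10, 5],          # parsed from --fan_out
  input_nodes=('paper', train_idx),
  batch_size=1024,
  shuffle=True,
  collect_features=True,
  to_device=current_device,
  random_seed=42,                     # new argument; becomes SamplingConfig.seed below
  worker_options=glt.distributed.MpDistSamplingWorkerOptions(num_workers=1),
)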
diff --git a/examples/igbh/train_rgnn_multi_gpu.py b/examples/igbh/train_rgnn_multi_gpu.py
index 765557e0..ebb3285b 100644
--- a/examples/igbh/train_rgnn_multi_gpu.py
+++ b/examples/igbh/train_rgnn_multi_gpu.py
@@ -28,7 +28,6 @@
 from dataset import IGBHeteroDataset
 from rgnn import RGNN
 
-torch.manual_seed(42)
 warnings.filterwarnings("ignore")
 
 def evaluate(model, dataloader, current_device):
@@ -54,16 +53,16 @@ def evaluate(model, dataloader, current_device):
 def run_training_proc(rank, world_size,
     hidden_channels, num_classes, num_layers, model_type, num_heads, fan_out,
-    epochs, batch_size, learning_rate, log_every,
+    epochs, batch_size, learning_rate, log_every, random_seed,
     dataset, train_idx, val_idx, with_gpu):
   os.environ['MASTER_ADDR'] = 'localhost'
   os.environ['MASTER_PORT'] = '12355'
   dist.init_process_group('nccl', rank=rank, world_size=world_size)
   torch.cuda.set_device(rank)
-  torch.manual_seed(42)
+  glt.utils.common.seed_everything(random_seed)
   current_device =torch.device(rank)
-  
+
   print(f'Rank {rank} init graphlearn_torch NeighborLoader...')
   # Create rank neighbor loader for training
   train_idx = train_idx.split(train_idx.size(0) // world_size)[rank]
@@ -74,7 +73,8 @@
     batch_size=batch_size,
     shuffle=True,
     drop_last=False,
-    device=current_device
+    device=current_device,
+    seed=random_seed
   )
 
   # Create rank neighbor loader for validation.
@@ -86,7 +86,8 @@
     batch_size=batch_size,
     shuffle=True,
     drop_last=False,
-    device=current_device
+    device=current_device,
+    seed=random_seed
   )
 
   # Define model and optimizer.
@@ -192,7 +193,7 @@ def run_training_proc(rank, world_size,
   parser.add_argument('--model', type=str, default='rgat',
                       choices=['rgat', 'rsage'])
   # Model parameters
-  parser.add_argument('--fan_out', type=str, default='15,10')
+  parser.add_argument('--fan_out', type=str, default='15,10,5')
   parser.add_argument('--batch_size', type=int, default=1024)
   parser.add_argument('--hidden_channels', type=int, default=128)
   parser.add_argument('--learning_rate', type=float, default=0.01)
@@ -200,6 +201,7 @@
   parser.add_argument('--num_layers', type=int, default=2)
   parser.add_argument('--num_heads', type=int, default=4)
   parser.add_argument('--log_every', type=int, default=5)
+  parser.add_argument('--random_seed', type=int, default=42)
   parser.add_argument("--cpu_mode", action="store_true",
                       help="Only use CPU for sampling and training, default is False.")
   parser.add_argument("--edge_dir", type=str, default='in')
@@ -212,6 +214,7 @@
   args = parser.parse_args()
   args.with_gpu = (not args.cpu_mode) and torch.cuda.is_available()
   assert args.layout in ['COO', 'CSC', 'CSR']
+  glt.utils.common.seed_everything(args.random_seed)
   igbh_dataset = IGBHeteroDataset(args.path, args.dataset_size, args.in_memory,
                                   args.num_classes==2983, True, args.layout,
                                   args.use_fp16)
@@ -240,7 +243,7 @@
       run_training_proc,
      args=(world_size, args.hidden_channels, args.num_classes, args.num_layers,
            args.model, args.num_heads, args.fan_out, args.epochs, args.batch_size,
-           args.learning_rate, args.log_every,
+           args.learning_rate, args.log_every, args.random_seed,
            glt_dataset, train_idx, val_idx, args.with_gpu),
      nprocs=world_size,
      join=True
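train_rgnn_multi_gpu.py gets the same treatment through the new seed argument on the single-machine NeighborLoader; note the fan_out default also grows to three hops. A sketch (glt_dataset and train_idx as built in the script; values illustrative):

import torch
import graphlearn_torch as glt

glt.utils.common.seed_everything(42)

train_loader = glt.loader.NeighborLoader(
  glt_dataset,                 # built from IGBHeteroDataset as in the script
  [15, 10, 5],                 # the new --fan_out default
  ('paper', train_idx),
  batch_size=1024,
  shuffle=True,
  drop_last=False,
  device=torch.device(0),
  seed=42,                     # new argument; forwarded to the underlying NeighborSampler
)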
diff --git a/graphlearn_torch/csrc/cpu/random_negative_sampler.cc b/graphlearn_torch/csrc/cpu/random_negative_sampler.cc
index 05f40e1d..f9d3e127 100644
--- a/graphlearn_torch/csrc/cpu/random_negative_sampler.cc
+++ b/graphlearn_torch/csrc/cpu/random_negative_sampler.cc
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "graphlearn_torch/include/common.h"
 #include "graphlearn_torch/csrc/cpu/random_negative_sampler.h"
 
 #include 
-#include <random>
 
 namespace graphlearn_torch {
 
@@ -27,8 +27,8 @@ std::tuple<torch::Tensor, torch::Tensor> CPURandomNegativeSampler::Sample(
   const int64_t* col_idx = graph_->GetColIdx();
   int64_t row_num = graph_->GetRowCount();
   int64_t col_num = graph_->GetColCount();
-  thread_local static std::random_device rd;
-  thread_local static std::mt19937 engine(rd());
+  uint32_t seed = RandomSeedManager::getInstance().getSeed();
+  thread_local static std::mt19937 engine(seed);
   std::uniform_int_distribution<int64_t> row_dist(0, row_num - 1);
   std::uniform_int_distribution<int64_t> col_dist(0, col_num - 1);
   int64_t row_data[req_num];
diff --git a/graphlearn_torch/csrc/cpu/random_sampler.cc b/graphlearn_torch/csrc/cpu/random_sampler.cc
index b9b6a1d7..49536802 100644
--- a/graphlearn_torch/csrc/cpu/random_sampler.cc
+++ b/graphlearn_torch/csrc/cpu/random_sampler.cc
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "graphlearn_torch/include/common.h"
 #include "graphlearn_torch/csrc/cpu/random_sampler.h"
 
 #include 
-#include <random>
 #include 
 
 namespace graphlearn_torch {
 
@@ -141,8 +141,8 @@ void CPURandomSampler::UniformSample(const int64_t* col_begin,
   // with replacement
   const auto cap = col_end - col_begin;
   if (req_num < cap) {
-    thread_local static std::random_device rd;
-    thread_local static std::mt19937 engine(rd());
+    uint32_t seed = RandomSeedManager::getInstance().getSeed();
+    thread_local static std::mt19937 engine(seed);
     std::uniform_int_distribution<> dist(0, cap-1);
     for (int32_t i = 0; i < req_num; ++i) {
       out_nbrs[i] = col_begin[dist(engine)];
@@ -162,8 +162,8 @@ void CPURandomSampler::UniformSample(const int64_t* col_begin,
   // with replacement
   const auto cap = col_end - col_begin;
   if (req_num < cap) {
-    thread_local static std::random_device rd;
-    thread_local static std::mt19937 engine(rd());
+    uint32_t seed = RandomSeedManager::getInstance().getSeed();
+    thread_local static std::mt19937 engine(seed);
     std::uniform_int_distribution<> dist(0, cap-1);
     for (int32_t i = 0; i < req_num; ++i) {
       auto idx = dist(engine);
diff --git a/graphlearn_torch/csrc/cpu/weighted_sampler.cc b/graphlearn_torch/csrc/cpu/weighted_sampler.cc
index 7080d277..f4ac43ed 100644
--- a/graphlearn_torch/csrc/cpu/weighted_sampler.cc
+++ b/graphlearn_torch/csrc/cpu/weighted_sampler.cc
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "graphlearn_torch/include/common.h"
 #include "graphlearn_torch/csrc/cpu/weighted_sampler.h"
 
 #include 
-#include <random>
 #include 
 
@@ -153,8 +153,8 @@ void CPUWeightedSampler::WeightedSample(const int64_t* col_begin,
   // with replacement
   const auto cap = col_end - col_begin;
   if (req_num < cap) {
-    thread_local static std::random_device rd;
-    thread_local static std::mt19937 engine(rd());
+    uint32_t seed = RandomSeedManager::getInstance().getSeed();
+    thread_local static std::mt19937 engine(seed);
     std::discrete_distribution<> dist(prob_begin, prob_end);
     for (int32_t i = 0; i < req_num; ++i) {
       out_nbrs[i] = col_begin[dist(engine)];
@@ -176,8 +176,8 @@ void CPUWeightedSampler::WeightedSample(const int64_t* col_begin,
   // with replacement
   const auto cap = col_end - col_begin;
   if (req_num < cap) {
-    thread_local static std::random_device rd;
-    thread_local static std::mt19937 engine(rd());
+    uint32_t seed = RandomSeedManager::getInstance().getSeed();
+    thread_local static std::mt19937 engine(seed);
     std::discrete_distribution<> dist(prob_begin, prob_end);
     for (int32_t i = 0; i < req_num; ++i) {
       auto idx = dist(engine);
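The C++ change is the same in every sampler above: the std::random_device seeding is swapped for whatever RandomSeedManager holds. Because the engines stay thread_local static, each sampling thread constructs its engine the first time it samples, so the seed must be set before any sampling work starts. The manager is exported to Python further down in this diff; a small behavioral sketch (the import path of the compiled extension is an assumption and may differ by build):

from graphlearn_torch import py_graphlearn_torch as pywrap  # assumed module name

mgr = pywrap.RandomSeedManager.getInstance()
mgr.setSeed(42)             # must run before the first Sample() call on any thread
assert mgr.getSeed() == 42  # without setSeed(), getSeed() falls back to std::random_device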
diff --git a/graphlearn_torch/csrc/cuda/random_negative_sampler.cu b/graphlearn_torch/csrc/cuda/random_negative_sampler.cu
index 481dcd7a..9db098af 100644
--- a/graphlearn_torch/csrc/cuda/random_negative_sampler.cu
+++ b/graphlearn_torch/csrc/cuda/random_negative_sampler.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "graphlearn_torch/include/common.h"
 #include "graphlearn_torch/csrc/cuda/random_negative_sampler.cuh"
 
 #include 
@@ -21,7 +22,6 @@ limitations under the License.
 #include 
 #include 
 #include 
-#include <random>
 #include 
 #include 
 #include 
@@ -135,8 +135,8 @@ CUDARandomNegativeSampler::Sample(int32_t req_num,
   int block_size = 0;
   int grid_size = 0;
-  thread_local static std::random_device rd;
-  thread_local static std::mt19937 engine(rd());
+  uint32_t seed = RandomSeedManager::getInstance().getSeed();
+  thread_local static std::mt19937 engine(seed);
   std::uniform_int_distribution<int64_t> dist(0, 1e10);
   cudaOccupancyMaxPotentialBlockSize(
       &grid_size, &block_size, RandomNegativeSampleKernel);
diff --git a/graphlearn_torch/csrc/cuda/random_sampler.cu b/graphlearn_torch/csrc/cuda/random_sampler.cu
index c5e86b85..84f16058 100644
--- a/graphlearn_torch/csrc/cuda/random_sampler.cu
+++ b/graphlearn_torch/csrc/cuda/random_sampler.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "graphlearn_torch/include/common.h"
 #include "graphlearn_torch/csrc/cuda/random_sampler.cuh"
 
 #include 
@@ -21,7 +22,6 @@ limitations under the License.
 #include 
 #include 
 #include 
-#include <random>
 
 #include "graphlearn_torch/include/common.cuh"
 
@@ -233,8 +233,8 @@ void CSRRowWiseSample(cudaStream_t stream,
                       int64_t* out_nbrs) {
   const dim3 block(WARP_SIZE, BLOCK_WARPS);
   const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
-  thread_local static std::random_device rd;
-  thread_local static std::mt19937 engine(rd());
+  uint32_t seed = RandomSeedManager::getInstance().getSeed();
+  thread_local static std::mt19937 engine(seed);
   std::uniform_int_distribution<int64_t> dist(0, 1e10);
   CSRRowWiseSampleKernel<<<grid, block, 0, stream>>>(
       dist(engine), req_num, bs, row_count,
@@ -255,8 +255,8 @@ void CSRRowWiseSample(cudaStream_t stream,
                       int64_t* out_eid) {
   const dim3 block(WARP_SIZE, BLOCK_WARPS);
   const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
-  thread_local static std::random_device rd;
-  thread_local static std::mt19937 engine(rd());
+  uint32_t seed = RandomSeedManager::getInstance().getSeed();
+  thread_local static std::mt19937 engine(seed);
   std::uniform_int_distribution<int64_t> dist(0, 1e10);
   CSRRowWiseSampleKernel<<<grid, block, 0, stream>>>(
       dist(engine), req_num, bs, row_count, nodes, row_ptr, col_idx, edge_ids,
diff --git a/graphlearn_torch/include/common.h b/graphlearn_torch/include/common.h
index 471f8b88..156969d7 100644
--- a/graphlearn_torch/include/common.h
+++ b/graphlearn_torch/include/common.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef GRAPHLEARN_TORCH_INCLUDE_COMMON_H_
 #define GRAPHLEARN_TORCH_INCLUDE_COMMON_H_
 
+#include <random>
 #include 
 
 namespace graphlearn_torch
@@ -32,6 +33,38 @@ inline void CheckEq(const T &x, const T &y) {
     throw std::runtime_error(std::string("CheckEq failed"));
 }
 
+class RandomSeedManager {
+public:
+  static RandomSeedManager& getInstance() {
+    static RandomSeedManager instance;
+    return instance;
+  }
+
+  void setSeed(uint32_t seed) {
+    this->is_set = true;
+    this->seed = seed;
+  }
+
+  uint32_t getSeed() const {
+    if (this->is_set) {
+      return seed;
+    } else {
+      std::random_device rd;
+      return rd();
+    }
+  }
+
+private:
+  RandomSeedManager() {}                                 // Constructor is private
+  RandomSeedManager(RandomSeedManager const&) = delete;  // Prevent copies
+  void operator=(RandomSeedManager const&) = delete;     // Prevent assignments
+
+  uint32_t seed;
+  bool is_set = false;
+};
+
+
 } // namespace graphlearn_torch
 
 #endif // GRAPHLEARN_TORCH_INCLUDE_COMMON_H_
diff --git a/graphlearn_torch/python/distributed/dist_neighbor_loader.py b/graphlearn_torch/python/distributed/dist_neighbor_loader.py
index 6a33bf7e..da4ed3d2 100644
--- a/graphlearn_torch/python/distributed/dist_neighbor_loader.py
+++ b/graphlearn_torch/python/distributed/dist_neighbor_loader.py
@@ -83,6 +83,7 @@ def __init__(self,
                edge_dir: Literal['in', 'out'] = 'out',
                collect_features: bool = False,
                to_device: Optional[torch.device] = None,
+               random_seed: int = None,
               worker_options: Optional[AllDistSamplingWorkerOptions] = None):
 
     if isinstance(input_nodes, tuple):
@@ -109,7 +110,7 @@ def __init__(self,
     sampling_config = SamplingConfig(
       SamplingType.NODE, num_neighbors, batch_size, shuffle, drop_last,
       with_edge, collect_features, with_neg=False,
-      with_weight=with_weight, edge_dir=edge_dir
+      with_weight=with_weight, edge_dir=edge_dir, seed=random_seed
     )
 
     super().__init__(
diff --git a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py
index 4093624c..79d2cb70 100644
--- a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py
+++ b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py
@@ -124,7 +124,8 @@ def __init__(self,
                collect_features: bool = False,
                channel: Optional[ChannelBase] = None,
                concurrency: int = 1,
-               device: Optional[torch.device] = None):
+               device: Optional[torch.device] = None,
+               seed: int = None):
     self.data = data
     self.num_neighbors = num_neighbors
     self.max_input_size = 0
@@ -136,6 +137,7 @@ def __init__(self,
     self.channel = channel
     self.concurrency = concurrency
     self.device = get_available_device(device)
+    self.seed = seed
 
     if isinstance(data, DistDataset):
       partition2workers = rpc_sync_data_partitions(
@@ -170,7 +172,8 @@ def __init__(self,
       self.sampler = NeighborSampler(
         self.dist_graph.local_graph, self.num_neighbors,
-        self.device, self.with_edge, self.with_neg, self.with_weight, self.edge_dir
+        self.device, self.with_edge, self.with_neg, self.with_weight,
+        self.edge_dir, seed=self.seed
       )
     self.inducer_pool = queue.Queue(maxsize=self.concurrency)
diff --git a/graphlearn_torch/python/distributed/dist_sampling_producer.py b/graphlearn_torch/python/distributed/dist_sampling_producer.py
index 956e2660..51e1463b 100644
--- a/graphlearn_torch/python/distributed/dist_sampling_producer.py
+++ b/graphlearn_torch/python/distributed/dist_sampling_producer.py
@@ -28,6 +28,7 @@
   NodeSamplerInput, EdgeSamplerInput, SamplingType, SamplingConfig
 )
 from ..utils import ensure_device
+from ..utils import seed_everything
 from ..distributed.dist_context import get_context
 from .dist_context import init_worker_group
 
@@ -88,11 +89,13 @@ def _sampling_worker_loop(rank,
       rpc_timeout=worker_options.rpc_timeout
     )
 
+    if sampling_config.seed is not None:
+      seed_everything(sampling_config.seed)
     dist_sampler = DistNeighborSampler(
       data, sampling_config.num_neighbors, sampling_config.with_edge,
       sampling_config.with_neg, sampling_config.with_weight,
       sampling_config.edge_dir, sampling_config.collect_features, channel,
-      worker_options.worker_concurrency, current_device
+      worker_options.worker_concurrency, current_device, seed=sampling_config.seed
     )
     dist_sampler.start_loop()
@@ -184,6 +187,8 @@ def __init__(self,
   def init(self):
     r""" Create the subprocess pool. Init samplers and rpc server.
     """
+    if self.sampling_config.seed is not None:
+      seed_everything(self.sampling_config.seed)
     if not self.sampling_config.shuffle:
       unshuffled_indexes = self._get_seeds_indexes()
     else:
@@ -326,7 +331,8 @@ def init(self):
       self.sampling_config.with_edge, self.sampling_config.with_neg,
       self.sampling_config.with_weight, self.sampling_config.edge_dir,
       self.sampling_config.collect_features,
-      channel=None, concurrency=1, device=self.device
+      channel=None, concurrency=1, device=self.device,
+      seed=self.sampling_config.seed
     )
     self._collocated_sampler.start_loop()
diff --git a/graphlearn_torch/python/loader/neighbor_loader.py b/graphlearn_torch/python/loader/neighbor_loader.py
index eaf68c93..6bbe6357 100644
--- a/graphlearn_torch/python/loader/neighbor_loader.py
+++ b/graphlearn_torch/python/loader/neighbor_loader.py
@@ -70,6 +70,7 @@ def __init__(
       strategy: str = 'random',
       device: torch.device = torch.device(0),
       as_pyg_v1: bool = False,
+      seed: int = None,
       **kwargs
   ):
     if neighbor_sampler is None:
@@ -81,6 +82,7 @@ def __init__(
         with_weight=with_weight,
         device=device,
         edge_dir=data.edge_dir,
+        seed=seed
       )
     self.as_pyg_v1 = as_pyg_v1
     self.edge_dir = data.edge_dir
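The producer changes show the seed's full path: it rides on SamplingConfig, and every sampling subprocess re-seeds itself in _sampling_worker_loop before constructing its DistNeighborSampler. A sketch of the config as the loaders above build it (positional field order follows the SamplingConfig dataclass in sampler/base.py below; the values are illustrative and the import path is an assumption):

from graphlearn_torch.sampler import SamplingConfig, SamplingType

config = SamplingConfig(
  SamplingType.NODE,  # sampling type
  [15, 10, 5],        # num_neighbors
  1024,               # batch_size
  True,               # shuffle
  False,              # drop_last
  False,              # with_edge
  True,               # collect_features
  with_neg=False,
  with_weight=False,
  edge_dir='out',
  seed=42,            # None preserves the old non-deterministic behavior
)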
#include "graphlearn_torch/csrc/cpu/random_sampler.h" #include "graphlearn_torch/csrc/cpu/weighted_sampler.h" #include "graphlearn_torch/csrc/cpu/subgraph_op.h" +#include "graphlearn_torch/include/common.h" #include "graphlearn_torch/include/graph.h" #include "graphlearn_torch/include/negative_sampler.h" #include "graphlearn_torch/include/sample_queue.h" @@ -79,6 +80,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def_readwrite("rows", &SubGraph::rows) .def_readwrite("cols", &SubGraph::cols) .def_readwrite("eids", &SubGraph::eids); + + py::class_(m, "RandomSeedManager") + .def_static("getInstance", &RandomSeedManager::getInstance, py::return_value_policy::reference) + .def("setSeed", &RandomSeedManager::setSeed, py::arg("seed")) + .def("getSeed", &RandomSeedManager::getSeed); py::class_(m, "CPURandomSampler") .def(py::init()) diff --git a/graphlearn_torch/python/sampler/base.py b/graphlearn_torch/python/sampler/base.py index 8fa54a32..5b831b3b 100644 --- a/graphlearn_torch/python/sampler/base.py +++ b/graphlearn_torch/python/sampler/base.py @@ -349,6 +349,7 @@ class SamplingConfig: with_neg: bool with_weight: bool edge_dir: Literal['in', 'out'] + seed: int class BaseSampler(ABC): diff --git a/graphlearn_torch/python/sampler/neighbor_sampler.py b/graphlearn_torch/python/sampler/neighbor_sampler.py index bbf91dcd..2cb07ec0 100644 --- a/graphlearn_torch/python/sampler/neighbor_sampler.py +++ b/graphlearn_torch/python/sampler/neighbor_sampler.py @@ -45,7 +45,8 @@ def __init__(self, with_neg: bool=False, with_weight: bool=False, strategy: str = 'random', - edge_dir: Literal['in', 'out'] = 'out'): + edge_dir: Literal['in', 'out'] = 'out', + seed: int = None): self.graph = graph self.num_neighbors = num_neighbors self.device = device @@ -58,6 +59,9 @@ def __init__(self, self._sampler = None self._neg_sampler = None self._inducer = None + + if seed is not None: + pywrap.RandomSeedManager.getInstance().setSeed(seed) if isinstance(self.graph, Graph): #homo self._g_cls = 'homo' if self.graph.mode == 'CPU': diff --git a/graphlearn_torch/python/utils/common.py b/graphlearn_torch/python/utils/common.py index 8f1a9054..e27b5ed1 100644 --- a/graphlearn_torch/python/utils/common.py +++ b/graphlearn_torch/python/utils/common.py @@ -19,6 +19,8 @@ from ..typing import reverse_edge_type from .tensor import id2idx +import numpy +import random import torch import pickle @@ -26,7 +28,18 @@ def ensure_dir(dir_path: str): if not os.path.exists(dir_path): os.makedirs(dir_path) - +def seed_everything(seed: int): + r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, + :obj:`numpy` and :python:`Python`. + + Args: + seed (int): The desired seed. + """ + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + def merge_dict(in_dict: Dict[Any, Any], out_dict: Dict[Any, Any]): for k, v in in_dict.items(): vals = out_dict.get(k, [])