Expose random seed configuration for single-node and distributed training (#106)
LiSu authored Dec 12, 2023
1 parent 6e965af commit 3f7731d
Showing 17 changed files with 115 additions and 43 deletions.
2 changes: 0 additions & 2 deletions examples/igbh/dataset.py
@@ -17,8 +17,6 @@
 import torch
 import os.path as osp
 
-import graphlearn_torch as glt
-
 from torch_geometric.utils import add_self_loops, remove_self_loops
 from download import download_dataset
 from typing import Literal
12 changes: 7 additions & 5 deletions examples/igbh/dist_train_rgnn.py
@@ -29,9 +29,6 @@
 from rgnn import RGNN
 
 
-torch.manual_seed(42)
-
-
 def evaluate(model, dataloader, current_device, use_fp16):
   predictions = []
   labels = []
@@ -59,14 +56,15 @@ def evaluate(model, dataloader, current_device, use_fp16):

 def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
     split_training_sampling, hidden_channels, num_classes, num_layers, model_type, num_heads, fan_out,
-    epochs, batch_size, learning_rate, log_every,
+    epochs, batch_size, learning_rate, log_every, random_seed,
     dataset, train_idx, val_idx,
     master_addr,
     training_pg_master_port,
     train_loader_master_port,
     val_loader_master_port,
     with_gpu, trim_to_layer, use_fp16,
     edge_dir, rpc_timeout):
+  glt.utils.common.seed_everything(random_seed)
   # Initialize graphlearn_torch distributed worker group context.
   glt.distributed.init_worker_group(
     world_size=num_nodes*num_training_procs,
@@ -107,6 +105,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
     edge_dir=edge_dir,
     collect_features=True,
     to_device=current_device,
+    random_seed=random_seed,
     worker_options = glt.distributed.MpDistSamplingWorkerOptions(
       num_workers=1,
       worker_devices=sampling_device,
@@ -130,6 +129,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
     edge_dir=edge_dir,
     collect_features=True,
     to_device=current_device,
+    random_seed=random_seed,
     worker_options = glt.distributed.MpDistSamplingWorkerOptions(
       num_workers=1,
       worker_devices=sampling_device,
@@ -261,6 +261,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
   parser.add_argument('--num_layers', type=int, default=3)
   parser.add_argument('--num_heads', type=int, default=4)
   parser.add_argument('--log_every', type=int, default=2)
+  parser.add_argument('--random_seed', type=int, default=42)
   # Distributed settings.
   parser.add_argument("--num_nodes", type=int, default=2,
     help="Number of distributed nodes.")
@@ -292,6 +293,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
help="load node/edge feature using fp16 format to reduce memory usage")
args = parser.parse_args()
assert args.layout in ['COO', 'CSC', 'CSR']
glt.utils.common.seed_everything(args.random_seed)
# when set --cpu_mode or GPU is not available, use cpu only mode.
args.with_gpu = (not args.cpu_mode) and torch.cuda.is_available()
if args.with_gpu:
@@ -324,7 +326,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
     run_training_proc,
     args=(args.num_nodes, args.node_rank, args.num_training_procs, args.split_training_sampling,
       args.hidden_channels, args.num_classes, args.num_layers, args.model, args.num_heads, args.fan_out,
-      args.epochs, args.batch_size, args.learning_rate, args.log_every,
+      args.epochs, args.batch_size, args.learning_rate, args.log_every, args.random_seed,
       dataset, train_idx, val_idx,
       args.master_addr,
       args.training_pg_master_port,
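Both the spawning process and each training process now call glt.utils.common.seed_everything. Its implementation is not part of this excerpt; the sketch below is an assumption about what such a helper typically does (seed every RNG the training job touches), not the library's actual code:

import random

import numpy as np
import torch

def seed_everything(seed):
    # Hypothetical stand-in for glt.utils.common.seed_everything.
    if seed is None:
        return  # keep non-deterministic behavior when no seed is given
    random.seed(seed)        # Python stdlib RNG
    np.random.seed(seed)     # NumPy global RNG
    torch.manual_seed(seed)  # PyTorch CPU and CUDA RNGs
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # all visible GPUs explicitly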
19 changes: 11 additions & 8 deletions examples/igbh/train_rgnn_multi_gpu.py
@@ -28,7 +28,6 @@
 from dataset import IGBHeteroDataset
 from rgnn import RGNN
 
-torch.manual_seed(42)
 warnings.filterwarnings("ignore")
 
 def evaluate(model, dataloader, current_device):
@@ -54,16 +53,16 @@ def evaluate(model, dataloader, current_device):

 def run_training_proc(rank, world_size,
     hidden_channels, num_classes, num_layers, model_type, num_heads, fan_out,
-    epochs, batch_size, learning_rate, log_every,
+    epochs, batch_size, learning_rate, log_every, random_seed,
     dataset, train_idx, val_idx, with_gpu):
 
   os.environ['MASTER_ADDR'] = 'localhost'
   os.environ['MASTER_PORT'] = '12355'
   dist.init_process_group('nccl', rank=rank, world_size=world_size)
   torch.cuda.set_device(rank)
-  torch.manual_seed(42)
+  glt.utils.common.seed_everything(random_seed)
   current_device = torch.device(rank)
 
   print(f'Rank {rank} init graphlearn_torch NeighborLoader...')
   # Create rank neighbor loader for training
   train_idx = train_idx.split(train_idx.size(0) // world_size)[rank]
@@ -74,7 +73,8 @@ def run_training_proc(rank, world_size,
     batch_size=batch_size,
     shuffle=True,
     drop_last=False,
-    device=current_device
+    device=current_device,
+    seed=random_seed
   )
 
   # Create rank neighbor loader for validation.
@@ -86,7 +86,8 @@ def run_training_proc(rank, world_size,
     batch_size=batch_size,
     shuffle=True,
     drop_last=False,
-    device=current_device
+    device=current_device,
+    seed=random_seed
   )
 
   # Define model and optimizer.
@@ -192,14 +193,15 @@ def run_training_proc(rank, world_size,
   parser.add_argument('--model', type=str, default='rgat',
     choices=['rgat', 'rsage'])
   # Model parameters
-  parser.add_argument('--fan_out', type=str, default='15,10')
+  parser.add_argument('--fan_out', type=str, default='15,10,5')
   parser.add_argument('--batch_size', type=int, default=1024)
   parser.add_argument('--hidden_channels', type=int, default=128)
   parser.add_argument('--learning_rate', type=float, default=0.01)
   parser.add_argument('--epochs', type=int, default=1)
   parser.add_argument('--num_layers', type=int, default=2)
   parser.add_argument('--num_heads', type=int, default=4)
   parser.add_argument('--log_every', type=int, default=5)
+  parser.add_argument('--random_seed', type=int, default=42)
   parser.add_argument("--cpu_mode", action="store_true",
     help="Only use CPU for sampling and training, default is False.")
   parser.add_argument("--edge_dir", type=str, default='in')
@@ -212,6 +214,7 @@ def run_training_proc(rank, world_size,
   args = parser.parse_args()
   args.with_gpu = (not args.cpu_mode) and torch.cuda.is_available()
   assert args.layout in ['COO', 'CSC', 'CSR']
+  glt.utils.common.seed_everything(args.random_seed)
   igbh_dataset = IGBHeteroDataset(args.path, args.dataset_size, args.in_memory,
     args.num_classes==2983, True, args.layout, args.use_fp16)
 
@@ -240,7 +243,7 @@ def run_training_proc(rank, world_size,
     run_training_proc,
     args=(world_size, args.hidden_channels, args.num_classes, args.num_layers,
       args.model, args.num_heads, args.fan_out, args.epochs, args.batch_size,
-      args.learning_rate, args.log_every,
+      args.learning_rate, args.log_every, args.random_seed,
       glt_dataset, train_idx, val_idx, args.with_gpu),
     nprocs=world_size,
     join=True
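In the multi-GPU script the seed rides along through mp.spawn, and every rank seeds with the same value: shuffling inside each rank's loader becomes reproducible, while the ranks still work on disjoint splits of train_idx taken before any shuffling. A minimal standalone sketch of that pattern (using the gloo backend so it runs without GPUs):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size, seed):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    torch.manual_seed(seed)  # same seed on every rank, as in the diff
    # Disjoint per-rank shard, split before shuffling.
    train_idx = torch.arange(8)
    shard = train_idx.split(train_idx.size(0) // world_size)[rank]
    # The rank-local shuffle is now reproducible run to run.
    perm = shard[torch.randperm(shard.size(0))]
    print(f'rank {rank}: {perm.tolist()}')
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(worker, args=(2, 42), nprocs=2, join=True)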
6 changes: 3 additions & 3 deletions graphlearn_torch/csrc/cpu/random_negative_sampler.cc
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "graphlearn_torch/include/common.h"
 #include "graphlearn_torch/csrc/cpu/random_negative_sampler.h"
 
 #include <algorithm>
-#include <random>
 
 
 namespace graphlearn_torch {
@@ -27,8 +27,8 @@ std::tuple<torch::Tensor, torch::Tensor> CPURandomNegativeSampler::Sample(
   const int64_t* col_idx = graph_->GetColIdx();
   int64_t row_num = graph_->GetRowCount();
   int64_t col_num = graph_->GetColCount();
-  thread_local static std::random_device rd;
-  thread_local static std::mt19937 engine(rd());
+  uint32_t seed = RandomSeedManager::getInstance().getSeed();
+  thread_local static std::mt19937 engine(seed);
   std::uniform_int_distribution<int64_t> row_dist(0, row_num - 1);
   std::uniform_int_distribution<int64_t> col_dist(0, col_num - 1);
   int64_t row_data[req_num];
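Each sampler backend keeps its thread_local engine and only swaps the seed source: RandomSeedManager::getInstance().getSeed() instead of std::random_device. The generator is still created once per sampling thread and then reused; note that with a fixed seed every thread's engine starts from identical state, so exact reproducibility also depends on a stable thread-to-work assignment. The same per-thread-generator pattern in Python, for illustration only:

import random
import threading

_tls = threading.local()

def thread_rng(seed):
    # Analogue of `thread_local static std::mt19937 engine(seed)`:
    # one generator per thread, seeded once on first use.
    if not hasattr(_tls, 'rng'):
        _tls.rng = random.Random(seed)
    return _tls.rng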
10 changes: 5 additions & 5 deletions graphlearn_torch/csrc/cpu/random_sampler.cc
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "graphlearn_torch/include/common.h"
 #include "graphlearn_torch/csrc/cpu/random_sampler.h"
 
 #include <cstdint>
-#include <random>
 #include <cassert>
 
 namespace graphlearn_torch {
@@ -141,8 +141,8 @@ void CPURandomSampler::UniformSample(const int64_t* col_begin,
   // with replacement
   const auto cap = col_end - col_begin;
   if (req_num < cap) {
-    thread_local static std::random_device rd;
-    thread_local static std::mt19937 engine(rd());
+    uint32_t seed = RandomSeedManager::getInstance().getSeed();
+    thread_local static std::mt19937 engine(seed);
     std::uniform_int_distribution<> dist(0, cap-1);
     for (int32_t i = 0; i < req_num; ++i) {
       out_nbrs[i] = col_begin[dist(engine)];
@@ -162,8 +162,8 @@ void CPURandomSampler::UniformSample(const int64_t* col_begin,
   // with replacement
   const auto cap = col_end - col_begin;
   if (req_num < cap) {
-    thread_local static std::random_device rd;
-    thread_local static std::mt19937 engine(rd());
+    uint32_t seed = RandomSeedManager::getInstance().getSeed();
+    thread_local static std::mt19937 engine(seed);
     std::uniform_int_distribution<> dist(0, cap-1);
     for (int32_t i = 0; i < req_num; ++i) {
       auto idx = dist(engine);
10 changes: 5 additions & 5 deletions graphlearn_torch/csrc/cpu/weighted_sampler.cc
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "graphlearn_torch/include/common.h"
 #include "graphlearn_torch/csrc/cpu/weighted_sampler.h"
 
 #include <cstdint>
-#include <random>
 #include <cassert>
 
 
@@ -153,8 +153,8 @@ void CPUWeightedSampler::WeightedSample(const int64_t* col_begin,
   // with replacement
   const auto cap = col_end - col_begin;
   if (req_num < cap) {
-    thread_local static std::random_device rd;
-    thread_local static std::mt19937 engine(rd());
+    uint32_t seed = RandomSeedManager::getInstance().getSeed();
+    thread_local static std::mt19937 engine(seed);
     std::discrete_distribution<> dist(prob_begin, prob_end);
     for (int32_t i = 0; i < req_num; ++i) {
       out_nbrs[i] = col_begin[dist(engine)];
@@ -176,8 +176,8 @@ void CPUWeightedSampler::WeightedSample(const int64_t* col_begin,
   // with replacement
   const auto cap = col_end - col_begin;
   if (req_num < cap) {
-    thread_local static std::random_device rd;
-    thread_local static std::mt19937 engine(rd());
+    uint32_t seed = RandomSeedManager::getInstance().getSeed();
+    thread_local static std::mt19937 engine(seed);
     std::discrete_distribution<> dist(prob_begin, prob_end);
     for (int32_t i = 0; i < req_num; ++i) {
       auto idx = dist(engine);
6 changes: 3 additions & 3 deletions graphlearn_torch/csrc/cuda/random_negative_sampler.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "graphlearn_torch/include/common.h"
 #include "graphlearn_torch/csrc/cuda/random_negative_sampler.cuh"
 
 #include <ATen/cuda/CUDAContext.h>
@@ -21,7 +22,6 @@ limitations under the License.
 #include <c10/cuda/CUDAStream.h>
 #include <curand.h>
 #include <curand_kernel.h>
-#include <random>
 #include <thrust/copy.h>
 #include <thrust/gather.h>
 #include <thrust/iterator/counting_iterator.h>
@@ -135,8 +135,8 @@ CUDARandomNegativeSampler::Sample(int32_t req_num,

   int block_size = 0;
   int grid_size = 0;
-  thread_local static std::random_device rd;
-  thread_local static std::mt19937 engine(rd());
+  uint32_t seed = RandomSeedManager::getInstance().getSeed();
+  thread_local static std::mt19937 engine(seed);
   std::uniform_int_distribution<int64_t> dist(0, 1e10);
   cudaOccupancyMaxPotentialBlockSize(
     &grid_size, &block_size, RandomNegativeSampleKernel);
10 changes: 5 additions & 5 deletions graphlearn_torch/csrc/cuda/random_sampler.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "graphlearn_torch/include/common.h"
 #include "graphlearn_torch/csrc/cuda/random_sampler.cuh"
 
 #include <c10/cuda/CUDAGuard.h>
@@ -21,7 +22,6 @@ limitations under the License.
 #include <cub/cub.cuh>
 #include <curand.h>
 #include <curand_kernel.h>
-#include <random>
 
 #include "graphlearn_torch/include/common.cuh"
 
@@ -233,8 +233,8 @@ void CSRRowWiseSample(cudaStream_t stream,
                       int64_t* out_nbrs) {
   const dim3 block(WARP_SIZE, BLOCK_WARPS);
   const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
-  thread_local static std::random_device rd;
-  thread_local static std::mt19937 engine(rd());
+  uint32_t seed = RandomSeedManager::getInstance().getSeed();
+  thread_local static std::mt19937 engine(seed);
   std::uniform_int_distribution<int64_t> dist(0, 1e10);
   CSRRowWiseSampleKernel<<<grid, block, 0, stream>>>(
     dist(engine), req_num, bs, row_count,
@@ -255,8 +255,8 @@ void CSRRowWiseSample(cudaStream_t stream,
                       int64_t* out_eid) {
   const dim3 block(WARP_SIZE, BLOCK_WARPS);
   const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
-  thread_local static std::random_device rd;
-  thread_local static std::mt19937 engine(rd());
+  uint32_t seed = RandomSeedManager::getInstance().getSeed();
+  thread_local static std::mt19937 engine(seed);
   std::uniform_int_distribution<int64_t> dist(0, 1e10);
   CSRRowWiseSampleKernel<<<grid, block, 0, stream>>>(
     dist(engine), req_num, bs, row_count, nodes, row_ptr, col_idx, edge_ids,
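In the CUDA samplers the host-side engine never draws neighbors itself: it only produces one value per launch, dist(engine), which is passed to the kernel as its seed, so the device-side random sequences become reproducible as well. The same host-seeds-device idea in Python, as an illustration:

import torch

master = torch.Generator().manual_seed(42)  # plays the role of `engine`

def next_launch_seed():
    # Analogue of `dist(engine)` feeding CSRRowWiseSampleKernel: draw
    # one sub-seed per launch from the deterministic master generator.
    return int(torch.randint(0, 10**10, (1,), generator=master).item())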
33 changes: 33 additions & 0 deletions graphlearn_torch/include/common.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef GRAPHLEARN_TORCH_INCLUDE_COMMON_H_
 #define GRAPHLEARN_TORCH_INCLUDE_COMMON_H_
 
+#include <random>
 #include <stdexcept>
 
 namespace graphlearn_torch
@@ -32,6 +33,38 @@ inline void CheckEq(const T &x, const T &y) {
     throw std::runtime_error(std::string("CheckEq failed"));
 
 
+class RandomSeedManager {
+ public:
+  static RandomSeedManager& getInstance() {
+    static RandomSeedManager instance;
+    return instance;
+  }
+
+  void setSeed(uint32_t seed) {
+    this->is_set = true;
+    this->seed = seed;
+  }
+
+  uint32_t getSeed() const {
+    if (this->is_set) {
+      return seed;
+    }
+    else {
+      std::random_device rd;
+      return rd();
+    }
+  }
+
+ private:
+  RandomSeedManager() {}  // Constructor is private
+  RandomSeedManager(RandomSeedManager const&) = delete;  // Prevent copies
+  void operator=(RandomSeedManager const&) = delete;     // Prevent assignments
+
+  uint32_t seed;
+  bool is_set = false;
+};
+
+
 } // namespace graphlearn_torch
 
 #endif // GRAPHLEARN_TORCH_INCLUDE_COMMON_H_
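RandomSeedManager is a Meyers singleton with a deliberate fallback: until setSeed is called, getSeed hands out fresh entropy from std::random_device, so unseeded runs keep their previous non-deterministic behavior. A Python analogue of the same design, for illustration only:

import random

class SeedManager:
    # Illustrative analogue of the C++ RandomSeedManager above.
    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        self._seed = None  # unset until set_seed() is called

    def set_seed(self, seed):
        self._seed = seed

    def get_seed(self):
        if self._seed is not None:
            return self._seed
        # Mirrors the std::random_device branch: fresh entropy on
        # every call when no seed was configured.
        return random.SystemRandom().getrandbits(32)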
3 changes: 2 additions & 1 deletion graphlearn_torch/python/distributed/dist_neighbor_loader.py
@@ -83,6 +83,7 @@ def __init__(self,
                edge_dir: Literal['in', 'out'] = 'out',
                collect_features: bool = False,
                to_device: Optional[torch.device] = None,
+               random_seed: int = None,
                worker_options: Optional[AllDistSamplingWorkerOptions] = None):
 
     if isinstance(input_nodes, tuple):
@@ -109,7 +110,7 @@ def __init__(self,
     sampling_config = SamplingConfig(
       SamplingType.NODE, num_neighbors, batch_size, shuffle,
       drop_last, with_edge, collect_features, with_neg=False,
-      with_weight=with_weight, edge_dir=edge_dir
+      with_weight=with_weight, edge_dir=edge_dir, seed=random_seed
     )
 
     super().__init__(
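On the Python side the new random_seed argument is simply forwarded into SamplingConfig as its seed field. A hedged usage sketch, modeled on the dist_train_rgnn.py hunks above and assuming the loader is exposed as glt.distributed.DistNeighborLoader (the dataset, fan-out, and device values here are placeholders, and worker_options is trimmed to the one field shown in the diff):

train_loader = glt.distributed.DistNeighborLoader(
    data=dataset,               # placeholder: a glt distributed dataset
    num_neighbors=[15, 10, 5],  # placeholder fan-out
    input_nodes=train_idx,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    edge_dir='in',
    collect_features=True,
    to_device=current_device,
    random_seed=42,             # new in this commit
    worker_options=glt.distributed.MpDistSamplingWorkerOptions(
        num_workers=1,
    ),
)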
(Diffs for the remaining 7 changed files are not expanded in this view.)