diff --git a/examples/igbh/README.md b/examples/igbh/README.md index b1bbebd5..b4769517 100644 --- a/examples/igbh/README.md +++ b/examples/igbh/README.md @@ -26,21 +26,42 @@ bash download_igbh_large.sh ``` For the `tiny`, `small` or `medium` dataset, the download procedure is included -in the training script below. - -Note that in `dataset.py`, we have converted the graph into an undirected graph. +in the training script below. Note that in `dataset.py`, we have converted the graph i +nto an undirected graph. ## 2. Single node training: ``` python train_rgnn.py --model='rgat' --dataset_size='tiny' --num_classes=19 ``` -The script uses a single GPU, please add `--cpu_mode` if you want to use CPU only. +The script uses a single GPU, please add `--cpu_mode` if you want to use CPU only. +To save the memory costs while training large datasets, add `--use_fp16` to store +feature data in FP16 format. Option `--pin_feature` decides if the feature data will be +pinned in host memory, which enables zero-copy feature access from GPU but will +incur extra memory costs. + +To train the model using multiple GPUs using FP16 format wihtout pinning the feature: +``` +CUDA_VISIBLE_DEVICES=0,1 python train_rgnn_multi_gpu.py --model='rgat' --dataset_size='tiny' --num_classes=19 --use_fp16 +``` + +Note that the original graph is in COO fornat, the above scripts will transform +the graph from COO to CSC or CSR according to the edge direction of sampling. This process +is time consuming when the graph is large. We provide a script to convert and persist +the graph in CSC or CSR format: +``` +python compress_graph.py --dataset_size='tiny' --layout='CSC' +``` + +Once the CSC or CSR is persisted, train the model with `--cpu_mode='CSC'` +or `--cpu_mode='CSR'`. -To train the model using multiple GPUs: ``` -CUDA_VISIBLE_DEVICES=0,1 python train_rgnn_multi_gpu.py --model='rgat' --dataset_size='tiny' --num_classes=19 +CUDA_VISIBLE_DEVICES=0,1 python train_rgnn_multi_gpu.py --model='rgat' --dataset_size='tiny' --num_classes=19 --use_fp16 --layout='CSC' ``` +Note that, when the sampling edge direction is `in`, the layout should be `CSC`. When the sampling edge direction is `out`, the layout should be `CSR`. + + ## 3. Distributed (multi nodes) examples We use 2 nodes as an example. diff --git a/examples/igbh/compress_graph.py b/examples/igbh/compress_graph.py new file mode 100644 index 00000000..e1a588b6 --- /dev/null +++ b/examples/igbh/compress_graph.py @@ -0,0 +1,128 @@ +# Copyright 2023 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import argparse, datetime, os +import numpy as np +import torch +import os.path as osp + +import graphlearn_torch as glt + +from download import download_dataset +from torch_geometric.utils import add_self_loops, remove_self_loops +from typing import Literal + + +class IGBHeteroDatasetCompress(object): + def __init__(self, + path, + dataset_size, + layout: Literal['CSC', 'CSR'] = 'CSC',): + self.dir = path + self.dataset_size = dataset_size + self.layout = layout + + self.ntypes = ['paper', 'author', 'institute', 'fos'] + self.etypes = None + self.edge_dict = {} + self.paper_nodes_num = {'tiny':100000, 'small':1000000, 'medium':10000000, 'large':100000000, 'full':269346174} + self.author_nodes_num = {'tiny':357041, 'small':1926066, 'medium':15544654, 'large':116959896, 'full':277220883} + if not osp.exists(osp.join(path, self.dataset_size, 'processed')): + download_dataset(path, 'heterogeneous', dataset_size) + self.process() + + def process(self): + paper_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__cites__paper', 'edge_index.npy'))).t() + author_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__written_by__author', 'edge_index.npy'))).t() + affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'author__affiliated_to__institute', 'edge_index.npy'))).t() + paper_fos_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__topic__fos', 'edge_index.npy'))).t() + if self.dataset_size in ['large', 'full']: + paper_published_journal = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__published__journal', 'edge_index.npy'))).t() + paper_venue_conference = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__venue__conference', 'edge_index.npy'))).t() + + cites_edge = add_self_loops(remove_self_loops(paper_paper_edges)[0])[0] + self.edge_dict = { + ('paper', 'cites', 'paper'): (torch.cat([cites_edge[1, :], cites_edge[0, :]]), torch.cat([cites_edge[0, :], cites_edge[1, :]])), + ('paper', 'written_by', 'author'): author_paper_edges, + ('author', 'affiliated_to', 'institute'): affiliation_author_edges, + ('paper', 'topic', 'fos'): paper_fos_edges, + ('author', 'rev_written_by', 'paper'): (author_paper_edges[1, :], author_paper_edges[0, :]), + ('institute', 'rev_affiliated_to', 'author'): (affiliation_author_edges[1, :], affiliation_author_edges[0, :]), + ('fos', 'rev_topic', 'paper'): (paper_fos_edges[1, :], paper_fos_edges[0, :]) + } + if self.dataset_size in ['large', 'full']: + self.edge_dict[('paper', 'published', 'journal')] = paper_published_journal + self.edge_dict[('paper', 'venue', 'conference')] = paper_venue_conference + self.edge_dict[('journal', 'rev_published', 'paper')] = (paper_published_journal[1, :], paper_published_journal[0, :]) + self.edge_dict[('conference', 'rev_venue', 'paper')] = (paper_venue_conference[1, :], paper_venue_conference[0, :]) + self.etypes = list(self.edge_dict.keys()) + + # init graphlearn_torch Dataset. + edge_dir = 'out' if self.layout == 'CSR' else 'in' + glt_dataset = glt.data.Dataset(edge_dir=edge_dir) + glt_dataset.init_graph( + edge_index=self.edge_dict, + graph_mode='CPU', + ) + + # save the corresponding csr or csc file + compress_edge_dict = {} + compress_edge_dict[('paper', 'cites', 'paper')] = 'paper__cites__paper' + compress_edge_dict[('paper', 'written_by', 'author')] = 'paper__written_by__author' + compress_edge_dict[('author', 'affiliated_to', 'institute')] = 'author__affiliated_to__institute' + compress_edge_dict[('paper', 'topic', 'fos')] = 'paper__topic__fos' + compress_edge_dict[('author', 'rev_written_by', 'paper')] = 'author__rev_written_by__paper' + compress_edge_dict[('institute', 'rev_affiliated_to', 'author')] = 'institute__rev_affiliated_to__author' + compress_edge_dict[('fos', 'rev_topic', 'paper')] = 'fos__rev_topic__paper' + compress_edge_dict[('paper', 'published', 'journal')] = 'paper__published__journal' + compress_edge_dict[('paper', 'venue', 'conference')] = 'paper__venue__conference' + compress_edge_dict[('journal', 'rev_published', 'paper')] = 'journal__rev_published__paper' + compress_edge_dict[('conference', 'rev_venue', 'paper')] = 'conference__rev_venue__paper' + + for etype in self.etypes: + graph = glt_dataset.get_graph(etype) + indptr, indices = graph.export_topology() + path = os.path.join(self.dir, self.dataset_size, 'processed', self.layout, compress_edge_dict[etype]) + if not os.path.exists(path): + os.makedirs(path) + torch.save(indptr, os.path.join(path, 'indptr.pt')) + torch.save(indices, os.path.join(path, 'indices.pt')) + path = os.path.join(self.dir, self.dataset_size, 'processed', self.layout) + print(f"The {self.layout} graph has been persisted in path: {path}") + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh') + glt.utils.ensure_dir(root) + parser.add_argument('--path', type=str, default=root, + help='path containing the datasets') + parser.add_argument('--dataset_size', type=str, default='tiny', + choices=['tiny', 'small', 'medium', 'large', 'full'], + help='size of the datasets') + parser.add_argument("--layout", type=str, default='CSC') + args = parser.parse_args() + print(f"Start constructing the {args.layout} graph...") + igbh_dataset = IGBHeteroDatasetCompress(args.path, args.dataset_size, args.layout) + + + + diff --git a/examples/igbh/dataset.py b/examples/igbh/dataset.py index 692e5e8d..e1df3b4b 100644 --- a/examples/igbh/dataset.py +++ b/examples/igbh/dataset.py @@ -21,6 +21,7 @@ from torch_geometric.utils import add_self_loops, remove_self_loops from download import download_dataset +from typing import Literal class IGBHeteroDataset(object): def __init__(self, @@ -28,12 +29,14 @@ def __init__(self, dataset_size='tiny', in_memory=True, use_label_2K=False, - with_edges=True): + with_edges=True, + layout: Literal['CSC', 'CSR', 'COO'] = 'COO',): self.dir = path self.dataset_size = dataset_size self.in_memory = in_memory self.use_label_2K = use_label_2K self.with_edges = with_edges + self.layout = layout self.ntypes = ['paper', 'author', 'institute', 'fos'] self.etypes = None @@ -52,50 +55,85 @@ def __init__(self, def process(self): if self.with_edges: - if self.in_memory: - paper_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', - 'paper__cites__paper', 'edge_index.npy'))).t() - author_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', - 'paper__written_by__author', 'edge_index.npy'))).t() - affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', - 'author__affiliated_to__institute', 'edge_index.npy'))).t() - paper_fos_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', - 'paper__topic__fos', 'edge_index.npy'))).t() + if self.layout == 'COO': + if self.in_memory: + paper_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__cites__paper', 'edge_index.npy'))).t() + author_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__written_by__author', 'edge_index.npy'))).t() + affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'author__affiliated_to__institute', 'edge_index.npy'))).t() + paper_fos_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__topic__fos', 'edge_index.npy'))).t() + if self.dataset_size in ['large', 'full']: + paper_published_journal = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__published__journal', 'edge_index.npy'))).t() + paper_venue_conference = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__venue__conference', 'edge_index.npy'))).t() + else: + paper_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__cites__paper', 'edge_index.npy'), mmap_mode='r')).t() + author_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__written_by__author', 'edge_index.npy'), mmap_mode='r')).t() + affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'author__affiliated_to__institute', 'edge_index.npy'), mmap_mode='r')).t() + paper_fos_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__topic__fos', 'edge_index.npy'), mmap_mode='r')).t() + if self.dataset_size in ['large', 'full']: + paper_published_journal = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__published__journal', 'edge_index.npy'), mmap_mode='r')).t() + paper_venue_conference = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__venue__conference', 'edge_index.npy'), mmap_mode='r')).t() + + cites_edge = add_self_loops(remove_self_loops(paper_paper_edges)[0])[0] + self.edge_dict = { + ('paper', 'cites', 'paper'): (torch.cat([cites_edge[1, :], cites_edge[0, :]]), torch.cat([cites_edge[0, :], cites_edge[1, :]])), + ('paper', 'written_by', 'author'): author_paper_edges, + ('author', 'affiliated_to', 'institute'): affiliation_author_edges, + ('paper', 'topic', 'fos'): paper_fos_edges, + ('author', 'rev_written_by', 'paper'): (author_paper_edges[1, :], author_paper_edges[0, :]), + ('institute', 'rev_affiliated_to', 'author'): (affiliation_author_edges[1, :], affiliation_author_edges[0, :]), + ('fos', 'rev_topic', 'paper'): (paper_fos_edges[1, :], paper_fos_edges[0, :]) + } if self.dataset_size in ['large', 'full']: - paper_published_journal = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', - 'paper__published__journal', 'edge_index.npy'))).t() - paper_venue_conference = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', - 'paper__venue__conference', 'edge_index.npy'))).t() + self.edge_dict[('paper', 'published', 'journal')] = paper_published_journal + self.edge_dict[('paper', 'venue', 'conference')] = paper_venue_conference + self.edge_dict[('journal', 'rev_published', 'paper')] = (paper_published_journal[1, :], paper_published_journal[0, :]) + self.edge_dict[('conference', 'rev_venue', 'paper')] = (paper_venue_conference[1, :], paper_venue_conference[0, :]) + + # directly load from CSC or CSC files, which can be generated using compress_graph.py else: - paper_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', - 'paper__cites__paper', 'edge_index.npy'), mmap_mode='r')).t() - author_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', - 'paper__written_by__author', 'edge_index.npy'), mmap_mode='r')).t() - affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', - 'author__affiliated_to__institute', 'edge_index.npy'), mmap_mode='r')).t() - paper_fos_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', - 'paper__topic__fos', 'edge_index.npy'), mmap_mode='r')).t() + compress_edge_dict = {} + compress_edge_dict[('paper', 'cites', 'paper')] = 'paper__cites__paper' + compress_edge_dict[('paper', 'written_by', 'author')] = 'paper__written_by__author' + compress_edge_dict[('author', 'affiliated_to', 'institute')] = 'author__affiliated_to__institute' + compress_edge_dict[('paper', 'topic', 'fos')] = 'paper__topic__fos' + compress_edge_dict[('author', 'rev_written_by', 'paper')] = 'author__rev_written_by__paper' + compress_edge_dict[('institute', 'rev_affiliated_to', 'author')] = 'institute__rev_affiliated_to__author' + compress_edge_dict[('fos', 'rev_topic', 'paper')] = 'fos__rev_topic__paper' if self.dataset_size in ['large', 'full']: - paper_published_journal = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', - 'paper__published__journal', 'edge_index.npy'), mmap_mode='r')).t() - paper_venue_conference = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', - 'paper__venue__conference', 'edge_index.npy'), mmap_mode='r')).t() - - cites_edge = add_self_loops(remove_self_loops(paper_paper_edges)[0])[0] - self.edge_dict = { - ('paper', 'cites', 'paper'): (torch.cat([cites_edge[1, :], cites_edge[0, :]]), torch.cat([cites_edge[0, :], cites_edge[1, :]])), - ('paper', 'written_by', 'author'): author_paper_edges, - ('author', 'affiliated_to', 'institute'): affiliation_author_edges, - ('paper', 'topic', 'fos'): paper_fos_edges, - ('author', 'rev_written_by', 'paper'): (author_paper_edges[1, :], author_paper_edges[0, :]), - ('institute', 'rev_affiliated_to', 'author'): (affiliation_author_edges[1, :], affiliation_author_edges[0, :]), - ('fos', 'rev_topic', 'paper'): (paper_fos_edges[1, :], paper_fos_edges[0, :]) - } - if self.dataset_size in ['large', 'full']: - self.edge_dict[('paper', 'published', 'journal')] = paper_published_journal - self.edge_dict[('paper', 'venue', 'conference')] = paper_venue_conference - self.edge_dict[('journal', 'rev_published', 'paper')] = (paper_published_journal[1, :], paper_published_journal[0, :]) - self.edge_dict[('conference', 'rev_venue', 'paper')] = (paper_venue_conference[1, :], paper_venue_conference[0, :]) + compress_edge_dict[('paper', 'published', 'journal')] = 'paper__published__journal' + compress_edge_dict[('paper', 'venue', 'conference')] = 'paper__venue__conference' + compress_edge_dict[('journal', 'rev_published', 'paper')] = 'journal__rev_published__paper' + compress_edge_dict[('conference', 'rev_venue', 'paper')] = 'conference__rev_venue__paper' + + for etype in compress_edge_dict.keys(): + edge_path = osp.join(self.dir, self.dataset_size, 'processed', self.layout, compress_edge_dict[etype]) + try: + edge_path = osp.join(self.dir, self.dataset_size, 'processed', self.layout, compress_edge_dict[etype]) + indptr = torch.load(osp.join(edge_path, 'indptr.pt')) + indices = torch.load(osp.join(edge_path, 'indices.pt')) + if self.layout == 'CSC': + self.edge_dict[etype] = (indices, indptr) + else: + self.edge_dict[etype] = (indptr, indices) + except FileNotFoundError: + print(f"FileNotFound: {file_path}") + exit() + except Exception as e: + print(f"Exception: {e}") + exit() + self.etypes = list(self.edge_dict.keys()) label_file = 'node_label_19.npy' if not self.use_label_2K else 'node_label_2K.npy' diff --git a/examples/igbh/dist_train_rgnn.py b/examples/igbh/dist_train_rgnn.py index 12acdac1..06f03e26 100644 --- a/examples/igbh/dist_train_rgnn.py +++ b/examples/igbh/dist_train_rgnn.py @@ -136,30 +136,6 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs, ) ) - # Create distributed neighbor loader for testing. - test_idx = test_idx.split(test_idx.size(0) // num_training_procs)[local_proc_rank] - test_loader = glt.distributed.DistNeighborLoader( - data=dataset, - num_neighbors=[int(fanout) for fanout in fan_out.split(',')], - input_nodes=('paper', test_idx), - batch_size=batch_size, - shuffle=False, - edge_dir=edge_dir, - collect_features=True, - to_device=current_device, - worker_options = glt.distributed.MpDistSamplingWorkerOptions( - num_workers=1, - worker_devices=current_device, - worker_concurrency=4, - master_addr=master_addr, - master_port=test_loader_master_port, - channel_size='16GB', - pin_memory=True, - rpc_timeout=rpc_timeout, - num_rpc_threads=2 - ) - ) - # Define model and optimizer. if with_gpu: torch.cuda.set_device(current_device) @@ -224,7 +200,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs, gpu_mem_alloc /= idx if with_gpu: torch.cuda.synchronize() - torch.distributed.barrier() + torch.distributed.barrier() if epoch%log_every == 0: model.eval() val_acc = evaluate(model, val_loader).item()*100 @@ -232,7 +208,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs, best_accuracy = val_acc if with_gpu: torch.cuda.synchronize() - torch.distributed.barrier() + torch.distributed.barrier() tqdm.tqdm.write( "Rank{:02d} | Epoch {:03d} | Loss {:.4f} | Train Acc {:.2f} | Val Acc {:.2f} | Time {} | GPU {:.1f} MB".format( current_ctx.rank, @@ -244,13 +220,6 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs, gpu_mem_alloc ) ) - - model.eval() - test_acc = evaluate(model, test_loader).item()*100 - print("Rank {:02d} Test Acc {:.2f}%".format(current_ctx.rank, test_acc)) - if with_gpu: - torch.cuda.synchronize() - torch.distributed.barrier() print("Total time taken " + str(datetime.timedelta(seconds = int(time.time() - training_start)))) @@ -274,7 +243,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs, parser.add_argument('--fan_out', type=str, default='15,10,5') parser.add_argument('--batch_size', type=int, default=512) parser.add_argument('--hidden_channels', type=int, default=128) - parser.add_argument('--learning_rate', type=int, default=0.001) + parser.add_argument('--learning_rate', type=float, default=0.001) parser.add_argument('--epochs', type=int, default=20) parser.add_argument('--num_layers', type=int, default=3) parser.add_argument('--num_heads', type=int, default=4) @@ -330,12 +299,8 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs, val_idx = torch.load( osp.join(args.path, f'{args.dataset_size}-val-partitions', f'partition{data_pidx}.pt') ) - test_idx = torch.load( - osp.join(args.path, f'{args.dataset_size}-test-partitions', f'partition{data_pidx}.pt') - ) train_idx.share_memory_() val_idx.share_memory_() - test_idx.share_memory_() print('--- Launching training processes ...\n') torch.multiprocessing.spawn( diff --git a/examples/igbh/train_rgnn_multi_gpu.py b/examples/igbh/train_rgnn_multi_gpu.py index 9e4a7fc5..5514ec7b 100644 --- a/examples/igbh/train_rgnn_multi_gpu.py +++ b/examples/igbh/train_rgnn_multi_gpu.py @@ -37,7 +37,13 @@ def evaluate(model, dataloader): with torch.no_grad(): for batch in dataloader: batch_size = batch['paper'].batch_size - out = model(batch.x_dict, batch.edge_index_dict)[:batch_size] + out = model( + { + node_name: node_feat.to(torch.float32) + for node_name, node_feat in batch.x_dict.items() + }, + batch.edge_index_dict + )[:batch_size] labels.append(batch['paper'].y[:batch_size].cpu().numpy()) predictions.append(out.argmax(1).cpu().numpy()) @@ -49,7 +55,7 @@ def evaluate(model, dataloader): def run_training_proc(rank, world_size, hidden_channels, num_classes, num_layers, model_type, num_heads, fan_out, epochs, batch_size, learning_rate, log_every, - dataset, train_idx, val_idx, test_idx, with_gpu): + dataset, train_idx, val_idx, with_gpu): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' @@ -83,18 +89,6 @@ def run_training_proc(rank, world_size, device=current_device ) - # Create rank neighbor loader for testing. - test_idx = test_idx.split(test_idx.size(0) // world_size)[rank] - test_loader = glt.loader.NeighborLoader( - data=dataset, - num_neighbors=[int(fanout) for fanout in fan_out.split(',')], - input_nodes=('paper', test_idx), - batch_size=batch_size, - shuffle=True, - drop_last=False, - device=current_device - ) - # Define model and optimizer. model = RGNN(dataset.get_edge_types(), dataset.node_features['paper'].shape[1], @@ -134,7 +128,13 @@ def run_training_proc(rank, world_size, for batch in train_loader: idx += 1 batch_size = batch['paper'].batch_size - out = model(batch.x_dict, batch.edge_index_dict)[:batch_size] + out = model( + { + node_name: node_feat.to(torch.float32) + for node_name, node_feat in batch.x_dict.items() + }, + batch.edge_index_dict + )[:batch_size] y = batch['paper'].y[:batch_size] loss = loss_fcn(out, y) optimizer.zero_grad() @@ -152,7 +152,7 @@ def run_training_proc(rank, world_size, gpu_mem_alloc /= idx if with_gpu: torch.cuda.synchronize() - dist.barrier() + dist.barrier() if epoch%log_every == 0: model.eval() val_acc = evaluate(model, val_loader).item()*100 @@ -160,7 +160,7 @@ def run_training_proc(rank, world_size, best_accuracy = val_acc if with_gpu: torch.cuda.synchronize() - dist.barrier() + dist.barrier() tqdm.tqdm.write( "Rank{:02d} | Epoch {:03d} | Loss {:.4f} | Train Acc {:.2f} | Val Acc {:.2f} | Time {} | GPU {:.1f} MB".format( rank, @@ -172,10 +172,6 @@ def run_training_proc(rank, world_size, gpu_mem_alloc ) ) - - model.eval() - test_acc = evaluate(model, test_loader).item()*100 - print("Rank {:02d} Test Acc {:.2f}%".format(rank, test_acc)) print("Total time taken " + str(datetime.timedelta(seconds = int(time.time() - training_start)))) @@ -196,38 +192,51 @@ def run_training_proc(rank, world_size, parser.add_argument('--model', type=str, default='rgat', choices=['rgat', 'rsage']) # Model parameters - parser.add_argument('--fan_out', type=str, default='15,10,5') - parser.add_argument('--batch_size', type=int, default=5120) + parser.add_argument('--fan_out', type=str, default='15,10') + parser.add_argument('--batch_size', type=int, default=1024) parser.add_argument('--hidden_channels', type=int, default=128) - parser.add_argument('--learning_rate', type=int, default=0.01) + parser.add_argument('--learning_rate', type=float, default=0.01) parser.add_argument('--epochs', type=int, default=20) - parser.add_argument('--num_layers', type=int, default=3) + parser.add_argument('--num_layers', type=int, default=2) parser.add_argument('--num_heads', type=int, default=4) parser.add_argument('--log_every', type=int, default=5) parser.add_argument("--cpu_mode", action="store_true", help="Only use CPU for sampling and training, default is False.") parser.add_argument("--edge_dir", type=str, default='in') + parser.add_argument('--layout', type=str, default='COO', + help="Layout of input graph. Default is COO.") + parser.add_argument("--pin_feature", action="store_true", + help="Pin the feature in host memory. Default is False.") + parser.add_argument("--use_fp16", action="store_true", + help="To use FP16 for loading the features. Default is False.") args = parser.parse_args() args.with_gpu = (not args.cpu_mode) and torch.cuda.is_available() - + assert args.layout in ['COO', 'CSC', 'CSR'] igbh_dataset = IGBHeteroDataset(args.path, args.dataset_size, args.in_memory, - args.num_classes==2983) + args.num_classes==2983, True, args.layout) + if args.use_fp16: + for node_name, node_feat in igbh_dataset.feat_dict.items(): + igbh_dataset.feat_dict[node_name] = node_feat.half() + # init graphlearn_torch Dataset. glt_dataset = glt.data.Dataset(edge_dir=args.edge_dir) + + glt_dataset.init_node_features( + node_feature_data=igbh_dataset.feat_dict, + with_gpu=args.with_gpu and args.pin_feature + ) + glt_dataset.init_graph( edge_index=igbh_dataset.edge_dict, + layout = args.layout, graph_mode='ZERO_COPY' if args.with_gpu else 'CPU', ) + - glt_dataset.init_node_features( - node_feature_data=igbh_dataset.feat_dict, - with_gpu=True - ) glt_dataset.init_node_labels(node_label_data={'paper': igbh_dataset.label}) - train_idx = igbh_dataset.train_idx.share_memory_() - val_idx = igbh_dataset.val_idx.share_memory_() - test_idx = igbh_dataset.test_idx.share_memory_() + train_idx = igbh_dataset.train_idx.clone().share_memory_() + val_idx = igbh_dataset.val_idx.clone().share_memory_() print('--- Launching training processes ...\n') world_size = torch.cuda.device_count() @@ -236,7 +245,7 @@ def run_training_proc(rank, world_size, args=(world_size, args.hidden_channels, args.num_classes, args.num_layers, args.model, args.num_heads, args.fan_out, args.epochs, args.batch_size, args.learning_rate, args.log_every, - glt_dataset, train_idx, val_idx, test_idx, args.with_gpu), + glt_dataset, train_idx, val_idx, args.with_gpu), nprocs=world_size, join=True ) diff --git a/graphlearn_torch/python/data/graph.py b/graphlearn_torch/python/data/graph.py index c57a5376..6092f63f 100644 --- a/graphlearn_torch/python/data/graph.py +++ b/graphlearn_torch/python/data/graph.py @@ -251,6 +251,9 @@ def lazy_init(self): raise ValueError(f"'{self.__class__.__name__}': " f"invalid mode {self.mode}") + def export_topology(self): + return self.topo.indptr, self.topo.indices + def share_ipc(self): r""" Create ipc handle for multiprocessing.