IGBH updates (#98)
1. Remove the test process
2. Add `compress_graph.py` to persist the CSR/CSC graph
3. Add flags to enable loading the persisted CSR/CSC graph in `train_rgnn_multi_gpu.py`
4. Add flags to control whether the feature data is pinned
5. Add flags to control whether FP16 is used to store the feature data
LiSu authored Nov 16, 2023
1 parent 40f4a55 commit 0cf542a
Showing 6 changed files with 285 additions and 121 deletions.
33 changes: 27 additions & 6 deletions examples/igbh/README.md
@@ -26,21 +26,42 @@ bash download_igbh_large.sh
```

For the `tiny`, `small` or `medium` dataset, the download procedure is included
in the training script below. Note that in `dataset.py`, we have converted the graph
into an undirected graph.

## 2. Single node training:
```
python train_rgnn.py --model='rgat' --dataset_size='tiny' --num_classes=19
```
The script uses a single GPU; add `--cpu_mode` if you want to use CPU only.
To save memory when training large datasets, add `--use_fp16` to store
feature data in FP16 format. The `--pin_feature` option decides whether the feature
data will be pinned in host memory, which enables zero-copy feature access from the
GPU but incurs extra memory cost.
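
For example, a single-GPU run that stores features in FP16 and keeps them pinned
might look like this (a minimal sketch, assuming `train_rgnn.py` accepts the two
flags described above):
```
python train_rgnn.py --model='rgat' --dataset_size='tiny' --num_classes=19 --use_fp16 --pin_feature
```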

To train the model using multiple GPUs with FP16 feature storage, without pinning the feature data:
```
CUDA_VISIBLE_DEVICES=0,1 python train_rgnn_multi_gpu.py --model='rgat' --dataset_size='tiny' --num_classes=19 --use_fp16
```

Note that the original graph is in COO format; the above scripts transform
the graph from COO to CSC or CSR according to the edge direction of sampling. This process
is time-consuming when the graph is large, so we provide a script to convert and persist
the graph in CSC or CSR format:
```
python compress_graph.py --dataset_size='tiny' --layout='CSC'
```
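
To sample with the `out` edge direction instead, persist the CSR layout (see the
note on edge direction at the end of this section):
```
python compress_graph.py --dataset_size='tiny' --layout='CSR'
```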

Once the CSC or CSR graph is persisted, train the model with `--layout='CSC'`
or `--layout='CSR'`.

To train the model using multiple GPUs with the persisted graph:
```
CUDA_VISIBLE_DEVICES=0,1 python train_rgnn_multi_gpu.py --model='rgat' --dataset_size='tiny' --num_classes=19 --use_fp16 --layout='CSC'
```

Note that when the sampling edge direction is `in`, the layout should be `CSC`; when the sampling edge direction is `out`, the layout should be `CSR`.
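
To make the pairing concrete, here is a toy sketch (hypothetical values, unrelated
to IGBH) of how a CSC layout serves the in-neighbor lookups performed by
`in`-direction sampling; CSR is the same structure indexed by source node instead:
```
import torch

# Toy 3-node graph with edges 1->0, 2->0, 0->1, in CSC form:
# indptr is indexed by destination node; indices holds source nodes.
indptr = torch.tensor([0, 2, 3, 3])
indices = torch.tensor([1, 2, 0])

dst = 0
start, end = indptr[dst].item(), indptr[dst + 1].item()
print(indices[start:end])  # in-neighbors of node 0 -> tensor([1, 2])
```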


## 3. Distributed (multi nodes) examples

We use 2 nodes as an example.
128 changes: 128 additions & 0 deletions examples/igbh/compress_graph.py
@@ -0,0 +1,128 @@
# Copyright 2023 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse, datetime, os
import numpy as np
import torch
import os.path as osp

import graphlearn_torch as glt

from download import download_dataset
from torch_geometric.utils import add_self_loops, remove_self_loops
from typing import Literal


class IGBHeteroDatasetCompress(object):
  def __init__(self,
               path,
               dataset_size,
               layout: Literal['CSC', 'CSR'] = 'CSC',):
    self.dir = path
    self.dataset_size = dataset_size
    self.layout = layout

    self.ntypes = ['paper', 'author', 'institute', 'fos']
    self.etypes = None
    self.edge_dict = {}
    self.paper_nodes_num = {'tiny':100000, 'small':1000000, 'medium':10000000, 'large':100000000, 'full':269346174}
    self.author_nodes_num = {'tiny':357041, 'small':1926066, 'medium':15544654, 'large':116959896, 'full':277220883}
    if not osp.exists(osp.join(path, self.dataset_size, 'processed')):
      download_dataset(path, 'heterogeneous', dataset_size)
    self.process()

  def process(self):
    # Load the raw COO edge indices for each edge type.
    paper_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
      'paper__cites__paper', 'edge_index.npy'))).t()
    author_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
      'paper__written_by__author', 'edge_index.npy'))).t()
    affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
      'author__affiliated_to__institute', 'edge_index.npy'))).t()
    paper_fos_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
      'paper__topic__fos', 'edge_index.npy'))).t()
    if self.dataset_size in ['large', 'full']:
      paper_published_journal = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
        'paper__published__journal', 'edge_index.npy'))).t()
      paper_venue_conference = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
        'paper__venue__conference', 'edge_index.npy'))).t()

    # Make the citation graph undirected and add reverse edge types.
    cites_edge = add_self_loops(remove_self_loops(paper_paper_edges)[0])[0]
    self.edge_dict = {
      ('paper', 'cites', 'paper'): (torch.cat([cites_edge[1, :], cites_edge[0, :]]), torch.cat([cites_edge[0, :], cites_edge[1, :]])),
      ('paper', 'written_by', 'author'): author_paper_edges,
      ('author', 'affiliated_to', 'institute'): affiliation_author_edges,
      ('paper', 'topic', 'fos'): paper_fos_edges,
      ('author', 'rev_written_by', 'paper'): (author_paper_edges[1, :], author_paper_edges[0, :]),
      ('institute', 'rev_affiliated_to', 'author'): (affiliation_author_edges[1, :], affiliation_author_edges[0, :]),
      ('fos', 'rev_topic', 'paper'): (paper_fos_edges[1, :], paper_fos_edges[0, :])
    }
    if self.dataset_size in ['large', 'full']:
      self.edge_dict[('paper', 'published', 'journal')] = paper_published_journal
      self.edge_dict[('paper', 'venue', 'conference')] = paper_venue_conference
      self.edge_dict[('journal', 'rev_published', 'paper')] = (paper_published_journal[1, :], paper_published_journal[0, :])
      self.edge_dict[('conference', 'rev_venue', 'paper')] = (paper_venue_conference[1, :], paper_venue_conference[0, :])
    self.etypes = list(self.edge_dict.keys())

    # init graphlearn_torch Dataset: CSR is built for 'out' sampling, CSC for 'in'.
    edge_dir = 'out' if self.layout == 'CSR' else 'in'
    glt_dataset = glt.data.Dataset(edge_dir=edge_dir)
    glt_dataset.init_graph(
      edge_index=self.edge_dict,
      graph_mode='CPU',
    )

    # save the corresponding csr or csc file
    compress_edge_dict = {}
    compress_edge_dict[('paper', 'cites', 'paper')] = 'paper__cites__paper'
    compress_edge_dict[('paper', 'written_by', 'author')] = 'paper__written_by__author'
    compress_edge_dict[('author', 'affiliated_to', 'institute')] = 'author__affiliated_to__institute'
    compress_edge_dict[('paper', 'topic', 'fos')] = 'paper__topic__fos'
    compress_edge_dict[('author', 'rev_written_by', 'paper')] = 'author__rev_written_by__paper'
    compress_edge_dict[('institute', 'rev_affiliated_to', 'author')] = 'institute__rev_affiliated_to__author'
    compress_edge_dict[('fos', 'rev_topic', 'paper')] = 'fos__rev_topic__paper'
    compress_edge_dict[('paper', 'published', 'journal')] = 'paper__published__journal'
    compress_edge_dict[('paper', 'venue', 'conference')] = 'paper__venue__conference'
    compress_edge_dict[('journal', 'rev_published', 'paper')] = 'journal__rev_published__paper'
    compress_edge_dict[('conference', 'rev_venue', 'paper')] = 'conference__rev_venue__paper'

    # Export and persist indptr/indices for every edge type in the dataset.
    for etype in self.etypes:
      graph = glt_dataset.get_graph(etype)
      indptr, indices = graph.export_topology()
      path = os.path.join(self.dir, self.dataset_size, 'processed', self.layout, compress_edge_dict[etype])
      if not os.path.exists(path):
        os.makedirs(path)
      torch.save(indptr, os.path.join(path, 'indptr.pt'))
      torch.save(indices, os.path.join(path, 'indices.pt'))
    path = os.path.join(self.dir, self.dataset_size, 'processed', self.layout)
    print(f"The {self.layout} graph has been persisted in path: {path}")


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh')
  glt.utils.ensure_dir(root)
  parser.add_argument('--path', type=str, default=root,
                      help='path containing the datasets')
  parser.add_argument('--dataset_size', type=str, default='tiny',
                      choices=['tiny', 'small', 'medium', 'large', 'full'],
                      help='size of the datasets')
  parser.add_argument("--layout", type=str, default='CSC')
  args = parser.parse_args()
  print(f"Start constructing the {args.layout} graph...")
  igbh_dataset = IGBHeteroDatasetCompress(args.path, args.dataset_size, args.layout)
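
As a quick sanity check, the persisted tensors can be loaded back directly. A
minimal sketch, assuming the default data root and the `tiny`/`CSC` invocation
from the README (the directory layout follows the `torch.save` calls in
`process()` above):
```
import os.path as osp
import torch

# Layout written by compress_graph.py:
#   <root>/<dataset_size>/processed/<layout>/<edge_type_dir>/{indptr.pt, indices.pt}
root = 'data/igbh'  # assumed default root
edge_dir = osp.join(root, 'tiny', 'processed', 'CSC', 'paper__cites__paper')

indptr = torch.load(osp.join(edge_dir, 'indptr.pt'))
indices = torch.load(osp.join(edge_dir, 'indices.pt'))

# For CSC, indptr has one entry per destination node plus one;
# indices stores the source endpoint of every edge.
print(f'{indptr.numel() - 1} destination nodes, {indices.numel()} edges')
```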




122 changes: 80 additions & 42 deletions examples/igbh/dataset.py
@@ -21,19 +21,22 @@

from torch_geometric.utils import add_self_loops, remove_self_loops
from download import download_dataset
from typing import Literal

class IGBHeteroDataset(object):
  def __init__(self,
               path,
               dataset_size='tiny',
               in_memory=True,
               use_label_2K=False,
               with_edges=True,
               layout: Literal['CSC', 'CSR', 'COO'] = 'COO',):
    self.dir = path
    self.dataset_size = dataset_size
    self.in_memory = in_memory
    self.use_label_2K = use_label_2K
    self.with_edges = with_edges
    self.layout = layout

    self.ntypes = ['paper', 'author', 'institute', 'fos']
    self.etypes = None
@@ -52,50 +52,85 @@ def __init__(self,

  def process(self):
    if self.with_edges:
      if self.layout == 'COO':
        if self.in_memory:
          paper_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
            'paper__cites__paper', 'edge_index.npy'))).t()
          author_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
            'paper__written_by__author', 'edge_index.npy'))).t()
          affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
            'author__affiliated_to__institute', 'edge_index.npy'))).t()
          paper_fos_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
            'paper__topic__fos', 'edge_index.npy'))).t()
          if self.dataset_size in ['large', 'full']:
            paper_published_journal = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
              'paper__published__journal', 'edge_index.npy'))).t()
            paper_venue_conference = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
              'paper__venue__conference', 'edge_index.npy'))).t()
        else:
          # Memory-mapped loads avoid materializing the full edge arrays in RAM.
          paper_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
            'paper__cites__paper', 'edge_index.npy'), mmap_mode='r')).t()
          author_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
            'paper__written_by__author', 'edge_index.npy'), mmap_mode='r')).t()
          affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
            'author__affiliated_to__institute', 'edge_index.npy'), mmap_mode='r')).t()
          paper_fos_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
            'paper__topic__fos', 'edge_index.npy'), mmap_mode='r')).t()
          if self.dataset_size in ['large', 'full']:
            paper_published_journal = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
              'paper__published__journal', 'edge_index.npy'), mmap_mode='r')).t()
            paper_venue_conference = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
              'paper__venue__conference', 'edge_index.npy'), mmap_mode='r')).t()

        cites_edge = add_self_loops(remove_self_loops(paper_paper_edges)[0])[0]
        self.edge_dict = {
          ('paper', 'cites', 'paper'): (torch.cat([cites_edge[1, :], cites_edge[0, :]]), torch.cat([cites_edge[0, :], cites_edge[1, :]])),
          ('paper', 'written_by', 'author'): author_paper_edges,
          ('author', 'affiliated_to', 'institute'): affiliation_author_edges,
          ('paper', 'topic', 'fos'): paper_fos_edges,
          ('author', 'rev_written_by', 'paper'): (author_paper_edges[1, :], author_paper_edges[0, :]),
          ('institute', 'rev_affiliated_to', 'author'): (affiliation_author_edges[1, :], affiliation_author_edges[0, :]),
          ('fos', 'rev_topic', 'paper'): (paper_fos_edges[1, :], paper_fos_edges[0, :])
        }
        if self.dataset_size in ['large', 'full']:
          self.edge_dict[('paper', 'published', 'journal')] = paper_published_journal
          self.edge_dict[('paper', 'venue', 'conference')] = paper_venue_conference
          self.edge_dict[('journal', 'rev_published', 'paper')] = (paper_published_journal[1, :], paper_published_journal[0, :])
          self.edge_dict[('conference', 'rev_venue', 'paper')] = (paper_venue_conference[1, :], paper_venue_conference[0, :])

      # directly load from CSC or CSR files, which can be generated using compress_graph.py
      else:
        compress_edge_dict = {}
        compress_edge_dict[('paper', 'cites', 'paper')] = 'paper__cites__paper'
        compress_edge_dict[('paper', 'written_by', 'author')] = 'paper__written_by__author'
        compress_edge_dict[('author', 'affiliated_to', 'institute')] = 'author__affiliated_to__institute'
        compress_edge_dict[('paper', 'topic', 'fos')] = 'paper__topic__fos'
        compress_edge_dict[('author', 'rev_written_by', 'paper')] = 'author__rev_written_by__paper'
        compress_edge_dict[('institute', 'rev_affiliated_to', 'author')] = 'institute__rev_affiliated_to__author'
        compress_edge_dict[('fos', 'rev_topic', 'paper')] = 'fos__rev_topic__paper'
        if self.dataset_size in ['large', 'full']:
          compress_edge_dict[('paper', 'published', 'journal')] = 'paper__published__journal'
          compress_edge_dict[('paper', 'venue', 'conference')] = 'paper__venue__conference'
          compress_edge_dict[('journal', 'rev_published', 'paper')] = 'journal__rev_published__paper'
          compress_edge_dict[('conference', 'rev_venue', 'paper')] = 'conference__rev_venue__paper'

        for etype in compress_edge_dict.keys():
          edge_path = osp.join(self.dir, self.dataset_size, 'processed', self.layout, compress_edge_dict[etype])
          try:
            indptr = torch.load(osp.join(edge_path, 'indptr.pt'))
            indices = torch.load(osp.join(edge_path, 'indices.pt'))
            # CSC topology is stored as (indices, indptr); CSR as (indptr, indices).
            if self.layout == 'CSC':
              self.edge_dict[etype] = (indices, indptr)
            else:
              self.edge_dict[etype] = (indptr, indices)
          except FileNotFoundError:
            print(f"FileNotFound: {edge_path}")
            exit()
          except Exception as e:
            print(f"Exception: {e}")
            exit()

      self.etypes = list(self.edge_dict.keys())

    label_file = 'node_label_19.npy' if not self.use_label_2K else 'node_label_2K.npy'
