Skip to content

Commit

Permalink
Add gpt example (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
husimplicity authored Jan 8, 2024
1 parent 4a760a5 commit 265b714
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 4 deletions.
33 changes: 33 additions & 0 deletions examples/gpt/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Using GPT to Reason on Graphs

This simple example shows how to leverage GPT to make inference on large graphs.

### 1. Prepare the graph dataset & environment
In this example, we use the dataset
[arxiv_2023](https://github.com/TRAIS-Lab/LLM-Structured-Data/tree/main/dataset/arxiv_2023)
and download it to the path `../data/arxiv_2023`.

Then, export your `OPENAI_API_KEY` as an environment variable in your shell:

```bash
export OPENAI_API_KEY='YOUR_API_KEY'
```

### 2. Run the code
Configure the parameters in the data loader and run the code.
```bash
python arxiv.py
```
This example tests the inference performance of GPT on a large graph with the link prediction task as the default task.

First, we sample a 2-hop ego-subgraph from the original graph. The subgraph is sampled by the `LinkNeighborLoader`, whose mini-batch sampler samples a fixed number of neighbors for each edge, and it is represented in PyG's `edge_index` format.

Then, the sampled subgraph along with node features (e.g. in this example the title for the paper node) is fed into GPT to infer whether a requested edge is in the original graph or not.

**Note**: GPT has a limitation on its context size, and thus limits the size of the sampled subgraph, which is determined by the parameter `num_neighbors` in the data loader. If the sampled subgraph is too large, please try to decrease the `num_neighbors` to reduce the size of the subgraph.

### Appendix:
1. **Dataset**: You can also use other datasets and modify the preprocessing code, but don't forget to transform the graph format into PyG's `edge_index`.
2. **Prompts**: Use the boolean parameter `reason` to decide whether GPT should show its reasoning process. You can also design your own prompts to perform inference on graphs instead of using our template.
3. **Node classification**: We also provide a template prompt for the node classification task; you can design your own method to leverage the label information.
**Note**: We've tried directly passing the labels for nodes in an ego-subgraph to predict the label of the center node, and GPT's prediction behavior in this case is close to voting via neighbors.
76 changes: 76 additions & 0 deletions examples/gpt/arxiv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import time
import torch

from tqdm import tqdm

import graphlearn_torch as glt
from utils import get_gpt_response, link_prediction


def run(glt_ds, raw_text, reason, num_neighbors=None):
    """Sample link-prediction mini-batches from ``glt_ds`` and query GPT on each.

    Every sampled batch is turned into a prompt by ``link_prediction`` and sent
    to ``get_gpt_response``; the reply is printed. Nothing is returned.

    Args:
      glt_ds: Dataset handed to ``glt.loader.LinkNeighborLoader``.
      raw_text: Array-like indexed with ``batch.node`` to obtain the titles of
        the nodes in the sampled subgraph.
      reason: Forwarded to ``link_prediction`` to select whether the prompt
        asks GPT to show its reasoning.
      num_neighbors: Per-hop fan-out passed to the loader. Defaults to
        ``[12, 6]``. Lower it when the sampled subgraph is too large for the
        model's context window.
    """
    if num_neighbors is None:
        num_neighbors = [12, 6]

    neg_sampling = glt.sampler.NegativeSampling('binary')
    # batch_size stays at 2: link_prediction reads entries 0..3 of
    # batch.edge_label_index, which it documents as two positive and two
    # negative samples.
    loader = glt.loader.LinkNeighborLoader(
        glt_ds,
        num_neighbors,
        neg_sampling=neg_sampling,
        batch_size=2,
        drop_last=True,
        shuffle=True,
        device=torch.device('cpu'),
    )
    print('Building graphlearn_torch NeighborLoader Done.')

    for batch in tqdm(loader):
        # Skip subgraphs with fewer than five edges before doing any further
        # work on the batch.
        if batch.edge_index.shape[1] < 5:
            continue

        batch_titles = raw_text[batch.node]
        message = link_prediction(batch, batch_titles, reason=reason)
        response = get_gpt_response(message=message)

        print(f"response: {response}")


if __name__ == '__main__':
    import pandas as pd

    # Raw arxiv_2023 dumps; the path is relative to the working directory.
    root = '../data/arxiv_2023/raw/'
    titles = pd.read_csv(f'{root}titles.csv.gz').to_numpy()
    ids_array = pd.read_csv(f'{root}ids.csv.gz').to_numpy()
    ids = torch.from_numpy(ids_array)
    edges_array = pd.read_csv(f'{root}edges.csv.gz').to_numpy()
    edge_index = torch.from_numpy(edges_array)

    print('Build graphlearn_torch dataset...')
    start = time.time()
    glt_dataset = glt.data.Dataset()

    # The edge table is transposed before being handed to init_graph
    # (presumably from one-edge-per-row to a 2 x num_edges layout; confirm
    # against the CSV).
    graph_options = dict(
        edge_index=edge_index.T,
        graph_mode='CPU',
        directed=True,
    )
    glt_dataset.init_graph(**graph_options)

    # The id table itself is registered as the node feature data.
    feature_options = dict(
        node_feature_data=ids,
        sort_func=glt.data.sort_by_in_degree,
        split_ratio=0,
    )
    glt_dataset.init_node_features(**feature_options)

    print(f'Build graphlearn_torch csr_topo and feature cost {time.time() - start} s.')

    run(glt_dataset, titles, reason=False)
67 changes: 67 additions & 0 deletions examples/gpt/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from openai import OpenAI


def get_gpt_response(message, model="gpt-4-1106-preview"):
    """Send ``message`` as a single user turn and return the reply text.

    A fresh ``OpenAI`` client is constructed with no arguments on every call,
    so credentials are expected to come from the environment (the example's
    README exports ``OPENAI_API_KEY``).

    Args:
      message: Prompt text sent as the content of the only chat message.
      model: Name of the chat model to query.

    Returns:
      The ``content`` of the first choice in the chat completion.
    """
    user_turn = {"role": "user", "content": message}
    completion = OpenAI().chat.completions.create(
        messages=[user_turn],
        model=model,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


def node_classification(batch):
    """Build a node-classification prompt for the subgraph in ``batch``.

    The prompt lists the feature vector and label of every node except node 0,
    then the edge list, and finally asks for the label of node 0 given only
    its feature vector.

    Args:
      batch: Object exposing ``x`` (per-node features), ``y`` (per-node
        labels) and ``edge_index`` (2 x num_edges) tensors.

    Returns:
      The prompt as a single string.
    """
    node_count = batch.x.shape[0]

    def render(vector):
        # Three decimals, comma separated, no spaces.
        return ','.join(f'{value:.3f}' for value in vector.tolist())

    parts = [
        "This is a directed subgraph of arxiv citation network with "
        f"{node_count} nodes numbered from 0 to {node_count - 1}.\n",
        f"The subgraph has {batch.edge_index.shape[1]} edges.\n",
    ]

    # Node 0 is the prediction target, so only nodes 1.. expose their label.
    parts.extend(
        f"The feature of node {idx} is [{render(batch.x[idx])}] "
        f"and the node label is {batch.y[idx].item()}.\n"
        for idx in range(1, node_count)
    )

    edge_pairs = batch.edge_index.T.tolist()
    parts.append(
        f"The edges of the subgraph are {edge_pairs}"
        ' where the first number indicates source node and the second destination node.\n'
    )
    parts.append(
        "Question: predict the label for node 0, whose feature is ["
        f"{render(batch.x[0])}"
        "]. Give the label only and don't show any reasoning process.\n\n"
    )
    return ''.join(parts)


def link_prediction(batch, titles, reason=False):
    """Build a link-prediction prompt for the subgraph in ``batch``.

    The prompt lists every node's title, the de-duplicated edge list of the
    sampled subgraph, and four yes/no questions about candidate edges taken
    from ``batch.edge_label_index``.

    Args:
      batch: Object exposing ``x`` (one row per node), ``edge_index``
        (2 x num_edges) and ``edge_label_index`` (2 x num_candidates, with at
        least four candidates) tensors.
      titles: Per-node sequence aligned with the local node numbering;
        ``titles[i][0]`` is the title of node ``i``.
      reason: When true, ask GPT to show its reasoning; otherwise ask for
        bare yes/no answers.

    Returns:
      The prompt as a single string.

    Raises:
      IndexError: If ``batch.edge_label_index`` holds fewer than four edges.
    """
    num_nodes = batch.x.shape[0]
    # unique(dim=0) drops repeated (src, dst) rows and returns them sorted.
    graph = batch.edge_index.T.unique(dim=0).tolist()
    # Materialise the candidate edges once instead of once per question.
    candidates = batch.edge_label_index.T.tolist()

    parts = [
        "This is a directed subgraph of arxiv citation network with "
        f"{num_nodes} nodes numbered from 0 to {num_nodes - 1}.\n",
        "The titles of each paper:\n",
    ]
    # Format through an f-string so a non-string title (e.g. the NaN pandas
    # yields for an empty CSV cell) is rendered instead of raising TypeError.
    parts.extend(
        f"node {i} is '{titles[i][0]}'\n" for i in range(num_nodes)
    )
    parts.append(
        f"The sampled subgraph of the network is {graph}"
        ' where the first number indicates source node and the second destination node.\n'
    )
    parts.append(
        "Hint: the direction of the edge can indicate some information of temporality.\n"
    )
    parts.append(
        "\nAccording to principles of citation network construction and the "
        "given subgraph structure, answer the following questions:\n"
    )

    # In `candidates`, index 0 and 1 are positive samples and index 2 and 3
    # are negative samples; the questions use a fixed interleaved order
    # (presumably so the position does not reveal the answer).
    for number, position in enumerate((1, 3, 2, 0), start=1):
        parts.append(
            f"Question {number}: predict whether there tends to form an edge "
            f"{candidates[position]}.\n"
        )

    if reason:
        parts.append("Answer yes or no and show reasoning process.\n\n")
    else:
        parts.append("Answer yes or no and don't show any reasoning process.\n\n")

    return ''.join(parts)
7 changes: 4 additions & 3 deletions graphlearn_torch/python/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(self,
self._inducer = None
self._sampler_lock = threading.Lock()
self.is_sampler_initialized = False
self.is_neg_sampler_initialized = False

if seed is not None:
pywrap.RandomSeedManager.getInstance().setSeed(seed)
Expand Down Expand Up @@ -113,7 +114,7 @@ def lazy_init_sampler(self):


def lazy_init_neg_sampler(self):
if not self.is_sampler_initialized and self.with_neg:
if not self.is_neg_sampler_initialized and self.with_neg:
with self._sampler_lock:
if self._neg_sampler is None:
if self._g_cls == 'homo':
Expand All @@ -122,7 +123,7 @@ def lazy_init_neg_sampler(self):
mode=self.device.type.upper(),
edge_dir=self.edge_dir
)
self.is_sampler_initialized = True
self.is_neg_sampler_initialized = True
else: # hetero
self._neg_sampler = {}
for etype, g in self.graph.items():
Expand All @@ -131,7 +132,7 @@ def lazy_init_neg_sampler(self):
mode=self.device.type.upper(),
edge_dir=self.edge_dir
)
self.is_sampler_initialized = True
self.is_neg_sampler_initialized = True

def lazy_init_subgraph_op(self):
if self._subgraph_op is None:
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
WITH_CUDA = os.getenv('WITH_CUDA', 'ON')

sys.path.append(os.path.join(ROOT_PATH, 'graphlearn_torch', 'python', 'utils'))
from build import glt_ext_module, glt_v6d_ext_module
from build_glt import glt_ext_module, glt_v6d_ext_module

GLT_V6D_EXT_NAME = "py_graphlearn_torch_vineyard"
GLT_EXT_NAME = "py_graphlearn_torch"
Expand Down

0 comments on commit 265b714

Please sign in to comment.