Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gpt example #115

Merged
merged 10 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions examples/gpt/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Using GPT to Reason on Graphs

This simple example shows how to leverage GPT to make inference on large graphs.

### 1. Prepare the graph dataset & environment
In this example, we use the dataset
[arxiv_2023](https://github.com/TRAIS-Lab/LLM-Structured-Data/tree/main/dataset/arxiv_2023)
and download it to the path `../data/arxiv_2023`.

Then, export your OPENAI_API_KEY as the environment variable in your shell:

```bash
export OPENAI_API_KEY='YOUR_API_KEY'
```

### 2. Run the code
Configure the parameters in the data loader and run the code.
```bash
python arxiv.py
```
This example tests the inference performance of GPT on a large graph with the link prediction task as the default task.

First, we sample a 2-hop ego-subgraph from the original graph. The subgraph is sampled by the `LinkNeighborLoader` with a mini-batch sampler that samples a fixed number of neighbors for each edge and is formed as PyG's `edge_index`.

Then, the sampled subgraph along with node features (e.g. in this example the title for the paper node) is fed into GPT to infer whether a requested edge is in the original graph or not.

**Note**: GPT has a limitation on its context size, and thus limits the size of the sampled subgraph, which is determined by the parameter `num_neighbors` in the data loader. If the sampled subgraph is too large, please try to decrease the `num_neighbors` to reduce the size of the subgraph.

### Appendix:
1. **Dataset**: You can also use other datasets and modify the preprocessing code, but don't forget to transform the graph format into PyG's `edge_index`.
2. **Prompts**: Use the parameter `reason: Bool` to decide whether to see the reasoning process of GPT. You can also design your own prompts to make inference on graphs instead of using our template.
3. **Node classification**: We also provide a template prompt for node classification task, and design your method to leverage the label informatiion.
**Note**: We've tried directly passing the labels for nodes in an ego-subgraph to predict the label of the center node, and GPT's prediction behavior in this case is close to voting via neighbors.
76 changes: 76 additions & 0 deletions examples/gpt/arxiv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import time
import torch

from tqdm import tqdm

import graphlearn_torch as glt
from utils import get_gpt_response, link_prediction


def run(glt_ds, raw_text, reason):
neg_sampling = glt.sampler.NegativeSampling('binary')
train_loader = glt.loader.LinkNeighborLoader(glt_ds,
[12, 6],
neg_sampling=neg_sampling,
batch_size=2,
drop_last=True,
shuffle=True,
device=torch.device('cpu'))
print(f'Building graphlearn_torch NeighborLoader Done.')

for batch in tqdm(train_loader):
batch_titles = raw_text[batch.node]
if batch.edge_index.shape[1] < 5:
continue

# print(batch)
# print(batch.edge_label_index)
message = link_prediction(batch, batch_titles, reason=reason)

# print(message)
response = get_gpt_response(
message=message
)

print(f"response: {response}")


if __name__ == '__main__':
import pandas as pd
root = '../data/arxiv_2023/raw/'
titles = pd.read_csv(root + "titles.csv.gz").to_numpy()
ids = torch.from_numpy(pd.read_csv(root + "ids.csv.gz").to_numpy())
edge_index = torch.from_numpy(pd.read_csv(root + "edges.csv.gz").to_numpy())

print('Build graphlearn_torch dataset...')
start = time.time()
glt_dataset = glt.data.Dataset()
glt_dataset.init_graph(
edge_index=edge_index.T,
graph_mode='CPU',
husimplicity marked this conversation as resolved.
Show resolved Hide resolved
directed=True
)
glt_dataset.init_node_features(
node_feature_data=ids,
sort_func=glt.data.sort_by_in_degree,
split_ratio=0
)

print(f'Build graphlearn_torch csr_topo and feature cost {time.time() - start} s.')

run(glt_dataset, titles, reason=False)
67 changes: 67 additions & 0 deletions examples/gpt/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from openai import OpenAI


def get_gpt_response(message, model="gpt-4-1106-preview"):
client = OpenAI()
chat_completion = client.chat.completions.create(
messages=[
{
"role" : "user",
"content": message,
}
],
model=model,
)
return chat_completion.choices[0].message.content


def node_classification(batch):
message = "This is a directed subgraph of arxiv citation network with " + str(batch.x.shape[0]) + " nodes numbered from 0 to " + str(batch.x.shape[0]-1) + ".\n"
message += "The subgraph has " + str(batch.edge_index.shape[1]) + " edges.\n"
for i in range(1, batch.x.shape[0]):
feature_str = ','.join(f'{it:.3f}' for it in batch.x[i].tolist())
message += "The feature of node " + str(i) + " is [" + feature_str + "] "
message += "and the node label is " + str(batch.y[i].item()) + ".\n"
message += "The edges of the subgraph are " + str(batch.edge_index.T.tolist()) + ' where the first number indicates source node and the second destination node.\n'
message += "Question: predict the label for node 0, whose feature is [" + ','.join(f'{it:.3f}' for it in batch.x[0].tolist()) + "]. Give the label only and don't show any reasoning process.\n\n"

return message


def link_prediction(batch, titles, reason=False):
message = "This is a directed subgraph of arxiv citation network with " + str(batch.x.shape[0]) + " nodes numbered from 0 to " + str(batch.x.shape[0]-1) + ".\n"
graph = batch.edge_index.T.unique(dim=0).tolist()
message += "The titles of each paper:\n"
for i in range(batch.x.shape[0]):
message += "node " + str(i) + " is '" + titles[i][0] + "'\n"
message += "The sampled subgraph of the network is " + str(graph) + ' where the first number indicates source node and the second destination node.\n'
message += "Hint: the direction of the edge can indicate some information of temporality.\n"
message += "\nAccording to principles of citation network construction and the given subgraph structure, answer the following questions:\n"

# In batch.edge_label_index.T.tolist(), index 0 and 1 are positive samples,
# index 2 and 3 are negative samples.
message += "Question 1: predict whether there tends to form an edge "+str(batch.edge_label_index.T.tolist()[1])+".\n"
husimplicity marked this conversation as resolved.
Show resolved Hide resolved
message += "Question 2: predict whether there tends to form an edge "+str(batch.edge_label_index.T.tolist()[3])+".\n"
message += "Question 3: predict whether there tends to form an edge "+str(batch.edge_label_index.T.tolist()[2])+".\n"
message += "Question 4: predict whether there tends to form an edge "+str(batch.edge_label_index.T.tolist()[0])+".\n"
if reason:
message += "Answer yes or no and show reasoning process.\n\n"
else:
message += "Answer yes or no and don't show any reasoning process.\n\n"

return message
7 changes: 4 additions & 3 deletions graphlearn_torch/python/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(self,
self._inducer = None
self._sampler_lock = threading.Lock()
self.is_sampler_initialized = False
self.is_neg_sampler_initialized = False

if seed is not None:
pywrap.RandomSeedManager.getInstance().setSeed(seed)
Expand Down Expand Up @@ -113,7 +114,7 @@ def lazy_init_sampler(self):


def lazy_init_neg_sampler(self):
if not self.is_sampler_initialized and self.with_neg:
if not self.is_neg_sampler_initialized and self.with_neg:
with self._sampler_lock:
if self._neg_sampler is None:
if self._g_cls == 'homo':
Expand All @@ -122,7 +123,7 @@ def lazy_init_neg_sampler(self):
mode=self.device.type.upper(),
edge_dir=self.edge_dir
)
self.is_sampler_initialized = True
self.is_neg_sampler_initialized = True
else: # hetero
self._neg_sampler = {}
for etype, g in self.graph.items():
Expand All @@ -131,7 +132,7 @@ def lazy_init_neg_sampler(self):
mode=self.device.type.upper(),
edge_dir=self.edge_dir
)
self.is_sampler_initialized = True
self.is_neg_sampler_initialized = True

def lazy_init_subgraph_op(self):
if self._subgraph_op is None:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
WITH_CUDA = os.getenv('WITH_CUDA', 'ON')

sys.path.append(os.path.join(ROOT_PATH, 'graphlearn_torch', 'python', 'utils'))
from build import glt_ext_module, glt_v6d_ext_module
from build_glt import glt_ext_module, glt_v6d_ext_module

GLT_V6D_EXT_NAME = "py_graphlearn_torch_vineyard"
GLT_EXT_NAME = "py_graphlearn_torch"
Expand Down
Loading