-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4a760a5
commit 265b714
Showing
6 changed files
with
181 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Using GPT to Reason on Graphs | ||
|
||
This simple example shows how to leverage GPT to make inference on large graphs. | ||
|
||
### 1. Prepare the graph dataset & environment | ||
In this example, we use the dataset | ||
[arxiv_2023](https://github.com/TRAIS-Lab/LLM-Structured-Data/tree/main/dataset/arxiv_2023) | ||
and download it to the path `../data/arxiv_2023`. | ||
|
||
Then, export your OPENAI_API_KEY as the environment variable in your shell: | ||
|
||
```bash | ||
export OPENAI_API_KEY='YOUR_API_KEY' | ||
``` | ||
|
||
### 2. Run the code | ||
Configure the parameters in the data loader and run the code. | ||
```bash | ||
python arxiv.py | ||
``` | ||
This example tests the inference performance of GPT on a large graph with the link prediction task as the default task. | ||
|
||
First, we sample a 2-hop ego-subgraph from the original graph. The subgraph is sampled by the `LinkNeighborLoader` with a mini-batch sampler that samples a fixed number of neighbors for each edge and is formed as PyG's `edge_index`. | ||
|
||
Then, the sampled subgraph along with node features (e.g. in this example the title for the paper node) is fed into GPT to infer whether a requested edge is in the original graph or not. | ||
|
||
**Note**: GPT has a limitation on its context size, and thus limits the size of the sampled subgraph, which is determined by the parameter `num_neighbors` in the data loader. If the sampled subgraph is too large, please try to decrease the `num_neighbors` to reduce the size of the subgraph. | ||
|
||
### Appendix: | ||
1. **Dataset**: You can also use other datasets and modify the preprocessing code, but don't forget to transform the graph format into PyG's `edge_index`. | ||
2. **Prompts**: Use the parameter `reason: Bool` to decide whether to see the reasoning process of GPT. You can also design your own prompts to make inference on graphs instead of using our template. | ||
3. **Node classification**: We also provide a template prompt for node classification task, and design your method to leverage the label informatiion. | ||
**Note**: We've tried directly passing the labels for nodes in an ego-subgraph to predict the label of the center node, and GPT's prediction behavior in this case is close to voting via neighbors. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
import time | ||
import torch | ||
|
||
from tqdm import tqdm | ||
|
||
import graphlearn_torch as glt | ||
from utils import get_gpt_response, link_prediction | ||
|
||
|
||
def run(glt_ds, raw_text, reason): | ||
neg_sampling = glt.sampler.NegativeSampling('binary') | ||
train_loader = glt.loader.LinkNeighborLoader(glt_ds, | ||
[12, 6], | ||
neg_sampling=neg_sampling, | ||
batch_size=2, | ||
drop_last=True, | ||
shuffle=True, | ||
device=torch.device('cpu')) | ||
print(f'Building graphlearn_torch NeighborLoader Done.') | ||
|
||
for batch in tqdm(train_loader): | ||
batch_titles = raw_text[batch.node] | ||
if batch.edge_index.shape[1] < 5: | ||
continue | ||
|
||
# print(batch) | ||
# print(batch.edge_label_index) | ||
message = link_prediction(batch, batch_titles, reason=reason) | ||
|
||
# print(message) | ||
response = get_gpt_response( | ||
message=message | ||
) | ||
|
||
print(f"response: {response}") | ||
|
||
|
||
if __name__ == '__main__': | ||
import pandas as pd | ||
root = '../data/arxiv_2023/raw/' | ||
titles = pd.read_csv(root + "titles.csv.gz").to_numpy() | ||
ids = torch.from_numpy(pd.read_csv(root + "ids.csv.gz").to_numpy()) | ||
edge_index = torch.from_numpy(pd.read_csv(root + "edges.csv.gz").to_numpy()) | ||
|
||
print('Build graphlearn_torch dataset...') | ||
start = time.time() | ||
glt_dataset = glt.data.Dataset() | ||
glt_dataset.init_graph( | ||
edge_index=edge_index.T, | ||
graph_mode='CPU', | ||
directed=True | ||
) | ||
glt_dataset.init_node_features( | ||
node_feature_data=ids, | ||
sort_func=glt.data.sort_by_in_degree, | ||
split_ratio=0 | ||
) | ||
|
||
print(f'Build graphlearn_torch csr_topo and feature cost {time.time() - start} s.') | ||
|
||
run(glt_dataset, titles, reason=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from openai import OpenAI | ||
|
||
|
||
def get_gpt_response(message, model="gpt-4-1106-preview"): | ||
client = OpenAI() | ||
chat_completion = client.chat.completions.create( | ||
messages=[ | ||
{ | ||
"role" : "user", | ||
"content": message, | ||
} | ||
], | ||
model=model, | ||
) | ||
return chat_completion.choices[0].message.content | ||
|
||
|
||
def node_classification(batch): | ||
message = "This is a directed subgraph of arxiv citation network with " + str(batch.x.shape[0]) + " nodes numbered from 0 to " + str(batch.x.shape[0]-1) + ".\n" | ||
message += "The subgraph has " + str(batch.edge_index.shape[1]) + " edges.\n" | ||
for i in range(1, batch.x.shape[0]): | ||
feature_str = ','.join(f'{it:.3f}' for it in batch.x[i].tolist()) | ||
message += "The feature of node " + str(i) + " is [" + feature_str + "] " | ||
message += "and the node label is " + str(batch.y[i].item()) + ".\n" | ||
message += "The edges of the subgraph are " + str(batch.edge_index.T.tolist()) + ' where the first number indicates source node and the second destination node.\n' | ||
message += "Question: predict the label for node 0, whose feature is [" + ','.join(f'{it:.3f}' for it in batch.x[0].tolist()) + "]. Give the label only and don't show any reasoning process.\n\n" | ||
|
||
return message | ||
|
||
|
||
def link_prediction(batch, titles, reason=False): | ||
message = "This is a directed subgraph of arxiv citation network with " + str(batch.x.shape[0]) + " nodes numbered from 0 to " + str(batch.x.shape[0]-1) + ".\n" | ||
graph = batch.edge_index.T.unique(dim=0).tolist() | ||
message += "The titles of each paper:\n" | ||
for i in range(batch.x.shape[0]): | ||
message += "node " + str(i) + " is '" + titles[i][0] + "'\n" | ||
message += "The sampled subgraph of the network is " + str(graph) + ' where the first number indicates source node and the second destination node.\n' | ||
message += "Hint: the direction of the edge can indicate some information of temporality.\n" | ||
message += "\nAccording to principles of citation network construction and the given subgraph structure, answer the following questions:\n" | ||
|
||
# In batch.edge_label_index.T.tolist(), index 0 and 1 are positive samples, | ||
# index 2 and 3 are negative samples. | ||
message += "Question 1: predict whether there tends to form an edge "+str(batch.edge_label_index.T.tolist()[1])+".\n" | ||
message += "Question 2: predict whether there tends to form an edge "+str(batch.edge_label_index.T.tolist()[3])+".\n" | ||
message += "Question 3: predict whether there tends to form an edge "+str(batch.edge_label_index.T.tolist()[2])+".\n" | ||
message += "Question 4: predict whether there tends to form an edge "+str(batch.edge_label_index.T.tolist()[0])+".\n" | ||
if reason: | ||
message += "Answer yes or no and show reasoning process.\n\n" | ||
else: | ||
message += "Answer yes or no and don't show any reasoning process.\n\n" | ||
|
||
return message |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters