Skip to content

Commit

Permalink
Stanley's node_classification code for #7462
Browse files Browse the repository at this point in the history
  • Loading branch information
Stanley (Guang) Yang committed Jul 24, 2024
1 parent c227b5a commit 8446b23
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 13 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# On stanyang branch
profile.svg
examples/graphbolt/utils.py
example.py

# IDE
.idea

Expand Down
93 changes: 82 additions & 11 deletions examples/graphbolt/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@
import torchmetrics.functional as MF
from tqdm import tqdm

# gb.seed(123)
# torch.manual_seed(123)

Check warning on line 53 in examples/graphbolt/node_classification.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
def create_dataloader(
graph, features, itemset, batch_size, fanout, device, num_workers, job
graph, features, itemset, batch_size, fanout, device, num_workers, job, probs_name=None
):
"""
[HIGHLIGHT]
Expand Down Expand Up @@ -117,7 +119,7 @@ def create_dataloader(
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
datapipe = getattr(datapipe, args.sample_mode)(
graph, fanout if job != "infer" else [-1]
graph, fanout if job != "infer" else [-1], probs_name=probs_name
)

############################################################################
Expand Down Expand Up @@ -228,8 +230,9 @@ def inference(self, graph, features, dataloader, storage_device):

@torch.no_grad()
def layerwise_infer(
args, graph, features, test_set, all_nodes_set, model, num_classes
args, graph, features, test_set, all_nodes_set, model, num_classes, probs_name=None
):
graph = graph.to(args.device)
model.eval()
dataloader = create_dataloader(
graph=graph,
Expand All @@ -240,6 +243,7 @@ def layerwise_infer(
device=args.device,
num_workers=args.num_workers,
job="infer",
probs_name=probs_name,
)
pred = model.inference(graph, features, dataloader, args.storage_device)
pred = pred[test_set._items[0]]
Expand All @@ -254,7 +258,8 @@ def layerwise_infer(


@torch.no_grad()
def evaluate(args, model, graph, features, itemset, num_classes):
def evaluate(args, model, graph, features, itemset, num_classes, probs_name=None):
graph = graph.to(args.device)
model.eval()
y = []
y_hats = []
Expand All @@ -267,8 +272,9 @@ def evaluate(args, model, graph, features, itemset, num_classes):
device=args.device,
num_workers=args.num_workers,
job="evaluate",
probs_name=probs_name,
)

# TODO: change the code in api or somewhere else ...
for step, data in tqdm(enumerate(dataloader), "Evaluating"):
x = data.node_features["feat"]
y.append(data.labels)
Expand All @@ -282,10 +288,23 @@ def evaluate(args, model, graph, features, itemset, num_classes):
)


def train(args, graph, features, train_set, valid_set, num_classes, model):
def train(args, graph, features, train_set, valid_set, num_classes, model, probs_name=None):
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=5e-4
)
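# When `probs_name` is set, attach random per-edge sampling data (float weights or a boolean mask) to the graph under that attribute name.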

num_edges = graph.total_num_edges
if probs_name is not None:
prob_data = torch.rand(num_edges, device=args.device)
if probs_name == "weight":
prob_data[torch.randperm(num_edges, device=args.device)[: int(num_edges * 1)]] = 0.0
elif probs_name == "mask":
prob_data = prob_data > 0.2 # original: 0.2
graph.add_edge_attribute(probs_name, prob_data)
graph = graph.to(args.device)
print("In node_classification: self.edge_attributes = ", graph.edge_attribute(probs_name), "probs_name = ", probs_name)

dataloader = create_dataloader(
graph=graph,
features=features,
Expand All @@ -295,13 +314,57 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
device=args.device,
num_workers=args.num_workers,
job="train",
probs_name=probs_name,
)

for epoch in range(args.epochs):
t0 = time.time()
epoch_start_time = time.time()
model.train()
total_loss = 0
total_sampling_time = 0
total_batch_time = 0

# NOTE: uncomment the code below to generate new prob/mask data at the start of each epoch
# num_edges = graph.total_num_edges
# print("num_edges = ", num_edges)
# if probs_name is not None:
# prob_data = torch.rand(num_edges, device=args.device)
# if probs_name == "weight":
# prob_data[torch.randperm(num_edges, device=args.device)[: int(num_edges * 0.5)]] = 0.0
# # print(prob_data)
# elif probs_name == "mask":
# prob_data = prob_data > 0.2
# graph.add_edge_attribute(probs_name, prob_data)
# graph = graph.to(args.device)
# # print("(new epoch): In node_classification: self.edge_attributes = ", graph.edge_attribute(probs_name), "probs_name = ", probs_name)

# graph.add_edge_attribute(probs_name, prob_data_list[epoch])
# graph = graph.to(args.device)
# print("(new epoch): In node_classification: self.edge_attributes = ", graph.edge_attribute(probs_name), "probs_name = ", probs_name)
for step, data in tqdm(enumerate(dataloader), "Training"):

# NOTE: uncomment the code below to get a new prob/mask data for each iteration
# num_edges = graph.total_num_edges
# # print("num_edges = ", num_edges)
# if probs_name is not None:
# prob_data = torch.rand(num_edges, device=args.device)
# if probs_name == "weight":
# prob_data[torch.randperm(num_edges, device=args.device)[: int(num_edges * 0.5)]] = 0.0
# # print(prob_data)
# elif probs_name == "mask":
# prob_data = prob_data > 0.2
# graph.add_edge_attribute(probs_name, prob_data)
# graph = graph.to(args.device)
# # print("(new iteration): In node_classification: self.edge_attributes = ", graph.edge_attribute(probs_name), "probs_name = ", probs_name)

# Measure sampling time (time taken to fetch the next batch)
sampling_start_time = time.time()
if step > 0:
total_sampling_time += sampling_start_time - batch_end_time

# Measure batch processing time
batch_start_time = time.time()

# The input features from the source nodes in the first layer's
# computation graph.
x = data.node_features["feat"]
Expand All @@ -321,12 +384,17 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):

total_loss += loss.item()

t1 = time.time()
batch_end_time = time.time()
total_batch_time += batch_end_time - batch_start_time

epoch_end_time = time.time()
# Evaluate the model.
acc = evaluate(args, model, graph, features, valid_set, num_classes)
acc = evaluate(args, model, graph, features, valid_set, num_classes, probs_name=probs_name)
print(
f"Epoch {epoch:05d} | Loss {total_loss / (step + 1):.4f} | "
f"Accuracy {acc.item():.4f} | Time {t1 - t0:.4f}"
f"Accuracy {acc.item():.4f} | Time {epoch_end_time - epoch_start_time:.4f} | "
f"Average Sampling Time {total_sampling_time / (step + 1):.4f} | "
f"Average Batch Time {total_batch_time / (step + 1):.4f}"
)


Expand Down Expand Up @@ -427,9 +495,11 @@ def main(args):
assert len(args.fanout) == len(model.layers)
model = model.to(args.device)

probs_name_GLOBAL = None # options: None | "mask" | "weight"

# Model training.
print("Training...")
train(args, graph, features, train_set, valid_set, num_classes, model)
train(args, graph, features, train_set, valid_set, num_classes, model, probs_name=probs_name_GLOBAL)

# Test the model.
print("Testing...")
Expand All @@ -441,6 +511,7 @@ def main(args):
all_nodes_set,
model,
num_classes,
probs_name=probs_name_GLOBAL
)
print(f"Test accuracy {test_acc.item():.4f}")

Expand Down
4 changes: 2 additions & 2 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,15 +552,15 @@ def __init__(
graph,
fanouts,
replace=False,
prob_name=None,
probs_name=None,
deduplicate=True,
):
super().__init__(
datapipe,
graph,
fanouts,
replace,
prob_name,
probs_name,
deduplicate,
graph.sample_neighbors,
)
Expand Down

0 comments on commit 8446b23

Please sign in to comment.