Skip to content

Commit

Permalink
import&None
Browse files Browse the repository at this point in the history
  • Loading branch information
Skeleton003 committed Feb 2, 2024
1 parent 4307cfd commit ae8db43
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions python/dgl/graphbolt/impl/ondisk_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import torch
import yaml

import dgl.sparse as dglsp

from ...base import dgl_warning
from ...data.utils import download, extract_archive
from ..base import etype_str_to_tuple, ORIGINAL_EDGE_ID
Expand Down Expand Up @@ -62,6 +60,8 @@ def _graph_data_to_fused_csc_sampling_graph(
sampling_graph : FusedCSCSamplingGraph
The FusedCSCSamplingGraph constructed from the raw data.
"""
from ...sparse import spmatrix

is_homogeneous = (
len(graph_data["nodes"]) == 1
and len(graph_data["edges"]) == 1
Expand All @@ -77,7 +77,7 @@ def _graph_data_to_fused_csc_sampling_graph(
num_nodes = graph_data["nodes"][0]["num"]
num_edges = len(src)
coo_tensor = torch.tensor([src, dst])
sparse_matrix = dglsp.spmatrix(coo_tensor)
sparse_matrix = spmatrix(coo_tensor)
indptr, indices, value_indices = sparse_matrix.csc()
node_type_offset = None
type_per_edge = None
Expand Down Expand Up @@ -121,7 +121,7 @@ def _graph_data_to_fused_csc_sampling_graph(
coo_dst = torch.cat(coo_dst_list)
coo_etype = torch.cat(coo_etype_list)

Check warning on line 123 in python/dgl/graphbolt/impl/ondisk_dataset.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
sparse_matrix = dglsp.spmatrix(
sparse_matrix = spmatrix(
indices=torch.stack((coo_src, coo_dst), dim=0)
)
indptr, indices, value_indices = sparse_matrix.csc()
Expand Down Expand Up @@ -243,6 +243,11 @@ def _graph_data_to_fused_csc_sampling_graph(
] = feat
edge_attributes[feat_name] = feat_tensor

if not bool(node_attributes):
node_attributes = None
if not bool(edge_attributes):
edge_attributes = None

# Construct the FusedCSCSamplingGraph.
sampling_graph = fused_csc_sampling_graph(
csc_indptr=indptr,
Expand Down

0 comments on commit ae8db43

Please sign in to comment.