Skip to content

Commit

Permalink
[GraphBolt] add node_attributes into FusedCSCSamplingGraph (#6757)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Dec 15, 2023
1 parent cad7cae commit e181ef1
Show file tree
Hide file tree
Showing 6 changed files with 333 additions and 51 deletions.
33 changes: 28 additions & 5 deletions graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
public:
using NodeTypeToIDMap = torch::Dict<std::string, int64_t>;
using EdgeTypeToIDMap = torch::Dict<std::string, int64_t>;
using NodeAttrMap = torch::Dict<std::string, torch::Tensor>;
using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>;
/** @brief Default constructor. */
FusedCSCSamplingGraph() = default;
Expand All @@ -66,16 +67,18 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present.
* @param node_attributes A dictionary of node attributes, if present.
* @param edge_attributes A dictionary of edge attributes, if present.
*
*/
FusedCSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<EdgeAttrMap>& edge_attributes);
const torch::optional<torch::Tensor>& node_type_offset = torch::nullopt,
const torch::optional<torch::Tensor>& type_per_edge = torch::nullopt,
const torch::optional<NodeTypeToIDMap>& node_type_to_id = torch::nullopt,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id = torch::nullopt,
const torch::optional<NodeAttrMap>& node_attributes = torch::nullopt,
const torch::optional<EdgeAttrMap>& edge_attributes = torch::nullopt);

/**
* @brief Create a fused CSC graph from tensors of CSC format.
Expand All @@ -89,6 +92,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present.
* @param node_attributes A dictionary of node attributes, if present.
* @param edge_attributes A dictionary of edge attributes, if present.
*
* @return FusedCSCSamplingGraph
Expand All @@ -99,6 +103,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<NodeAttrMap>& node_attributes,
const torch::optional<EdgeAttrMap>& edge_attributes);

/** @brief Get the number of nodes. */
Expand Down Expand Up @@ -139,6 +144,11 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
return edge_type_to_id_;
}

/**
 * @brief Get the node attributes dictionary.
 *
 * @return The stored optional node attribute dictionary, returned by value.
 * It holds no value if no node attributes are stored.
 */
inline const torch::optional<NodeAttrMap> NodeAttributes() const {
  return node_attributes_;
}

/** @brief Get the edge attributes dictionary. */
inline const torch::optional<EdgeAttrMap> EdgeAttributes() const {
return edge_attributes_;
Expand Down Expand Up @@ -180,6 +190,12 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
edge_type_to_id_ = edge_type_to_id;
}

/**
 * @brief Set the node attributes dictionary.
 *
 * @param node_attributes The optional dictionary of node attributes to store.
 * The value is assigned as-is; no per-attribute length check is performed
 * in this setter.
 */
inline void SetNodeAttributes(
    const torch::optional<NodeAttrMap>& node_attributes) {
  node_attributes_ = node_attributes;
}

/** @brief Set the edge attributes dictionary. */
inline void SetEdgeAttributes(
const torch::optional<EdgeAttrMap>& edge_attributes) {
Expand Down Expand Up @@ -367,6 +383,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
*/
torch::optional<EdgeTypeToIDMap> edge_type_to_id_;

/**
* @brief A dictionary of node attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value.
* The length of each value should match the total number of nodes.
*/
torch::optional<NodeAttrMap> node_attributes_;

/**
* @brief A dictionary of edge attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value.
Expand Down
46 changes: 44 additions & 2 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,15 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<NodeAttrMap>& node_attributes,
const torch::optional<EdgeAttrMap>& edge_attributes)
: indptr_(indptr),
indices_(indices),
node_type_offset_(node_type_offset),
type_per_edge_(type_per_edge),
node_type_to_id_(node_type_to_id),
edge_type_to_id_(edge_type_to_id),
node_attributes_(node_attributes),
edge_attributes_(edge_attributes) {
TORCH_CHECK(indptr.dim() == 1);
TORCH_CHECK(indices.dim() == 1);
Expand All @@ -75,6 +77,7 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<NodeAttrMap>& node_attributes,
const torch::optional<EdgeAttrMap>& edge_attributes) {
if (node_type_offset.has_value()) {
auto& offset = node_type_offset.value();
Expand All @@ -89,14 +92,19 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0));
TORCH_CHECK(edge_type_to_id.has_value());
}
if (node_attributes.has_value()) {
for (const auto& pair : node_attributes.value()) {
TORCH_CHECK(pair.value().size(0) == indptr.size(0) - 1);
}
}
if (edge_attributes.has_value()) {
for (const auto& pair : edge_attributes.value()) {
TORCH_CHECK(pair.value().size(0) == indices.size(0));
}
}
return c10::make_intrusive<FusedCSCSamplingGraph>(
indptr, indices, node_type_offset, type_per_edge, node_type_to_id,
edge_type_to_id, edge_attributes);
edge_type_to_id, node_attributes, edge_attributes);
}

void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
Expand Down Expand Up @@ -150,6 +158,25 @@ void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
edge_type_to_id_ = std::move(edge_type_to_id);
}

// Optional node attributes.
torch::IValue has_node_attributes;
if (archive.try_read(
"FusedCSCSamplingGraph/has_node_attributes", has_node_attributes) &&
has_node_attributes.toBool()) {
torch::Dict<torch::IValue, torch::IValue> generic_dict =
read_from_archive(archive, "FusedCSCSamplingGraph/node_attributes")
.toGenericDict();
NodeAttrMap target_dict;
for (const auto& pair : generic_dict) {
std::string key = pair.key().toStringRef();
torch::Tensor value = pair.value().toTensor();
// Use move to avoid copy.
target_dict.insert(std::move(key), std::move(value));
}
// Same as above.
node_attributes_ = std::move(target_dict);
}

// Optional edge attributes.
torch::IValue has_edge_attributes;
if (archive.try_read(
Expand Down Expand Up @@ -203,6 +230,13 @@ void FusedCSCSamplingGraph::Save(
archive.write(
"FusedCSCSamplingGraph/edge_type_to_id", edge_type_to_id_.value());
}
archive.write(
"FusedCSCSamplingGraph/has_node_attributes",
node_attributes_.has_value());
if (node_attributes_) {
archive.write(
"FusedCSCSamplingGraph/node_attributes", node_attributes_.value());
}
archive.write(
"FusedCSCSamplingGraph/has_edge_attributes",
edge_attributes_.has_value());
Expand Down Expand Up @@ -238,6 +272,9 @@ void FusedCSCSamplingGraph::SetState(
if (state.find("edge_type_to_id") != state.end()) {
edge_type_to_id_ = DetensorizeDict(state.at("edge_type_to_id"));
}
if (state.find("node_attributes") != state.end()) {
node_attributes_ = state.at("node_attributes");
}
if (state.find("edge_attributes") != state.end()) {
edge_attributes_ = state.at("edge_attributes");
}
Expand Down Expand Up @@ -268,6 +305,9 @@ FusedCSCSamplingGraph::GetState() const {
if (edge_type_to_id_.has_value()) {
state.insert("edge_type_to_id", TensorizeDict(edge_type_to_id_).value());
}
if (node_attributes_.has_value()) {
state.insert("node_attributes", node_attributes_.value());
}
if (edge_attributes_.has_value()) {
state.insert("edge_attributes", edge_attributes_.value());
}
Expand Down Expand Up @@ -596,10 +636,11 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
auto type_per_edge = helper.ReadTorchTensor();
auto node_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict());
auto edge_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict());
auto node_attributes = helper.ReadTorchTensorDict();
auto edge_attributes = helper.ReadTorchTensorDict();
auto graph = c10::make_intrusive<FusedCSCSamplingGraph>(
indptr.value(), indices.value(), node_type_offset, type_per_edge,
node_type_to_id, edge_type_to_id, edge_attributes);
node_type_to_id, edge_type_to_id, node_attributes, edge_attributes);
auto shared_memory = helper.ReleaseSharedMemory();
graph->HoldSharedMemoryObject(
std::move(shared_memory.first), std::move(shared_memory.second));
Expand All @@ -616,6 +657,7 @@ FusedCSCSamplingGraph::CopyToSharedMemory(
helper.WriteTorchTensor(type_per_edge_);
helper.WriteTorchTensorDict(TensorizeDict(node_type_to_id_));
helper.WriteTorchTensorDict(TensorizeDict(edge_type_to_id_));
helper.WriteTorchTensorDict(node_attributes_);
helper.WriteTorchTensorDict(edge_attributes_);
helper.Flush();
return BuildGraphFromSharedMemoryHelper(std::move(helper));
Expand Down
7 changes: 1 addition & 6 deletions graphbolt/src/index_select.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
TORCH_CHECK(
c10::isIntegralType(indices.scalar_type(), false),
"IndexSelectCSC is not implemented to slice noninteger types yet.");
torch::optional<torch::Tensor> temp;
torch::optional<sampling::FusedCSCSamplingGraph::NodeTypeToIDMap> temp2;
torch::optional<sampling::FusedCSCSamplingGraph::EdgeTypeToIDMap> temp3;
torch::optional<sampling::FusedCSCSamplingGraph::EdgeAttrMap> temp4;
sampling::FusedCSCSamplingGraph g(
indptr, indices, temp, temp, temp2, temp3, temp4);
sampling::FusedCSCSamplingGraph g(indptr, indices);
const auto res = g.InSubgraph(nodes);
return std::make_tuple(res->indptr, res->indices);
}
Expand Down
2 changes: 2 additions & 0 deletions graphbolt/src/python_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ TORCH_LIBRARY(graphbolt, m) {
.def("type_per_edge", &FusedCSCSamplingGraph::TypePerEdge)
.def("node_type_to_id", &FusedCSCSamplingGraph::NodeTypeToID)
.def("edge_type_to_id", &FusedCSCSamplingGraph::EdgeTypeToID)
.def("node_attributes", &FusedCSCSamplingGraph::NodeAttributes)
.def("edge_attributes", &FusedCSCSamplingGraph::EdgeAttributes)
.def("set_csc_indptr", &FusedCSCSamplingGraph::SetCSCIndptr)
.def("set_indices", &FusedCSCSamplingGraph::SetIndices)
.def("set_node_type_offset", &FusedCSCSamplingGraph::SetNodeTypeOffset)
.def("set_type_per_edge", &FusedCSCSamplingGraph::SetTypePerEdge)
.def("set_node_type_to_id", &FusedCSCSamplingGraph::SetNodeTypeToID)
.def("set_edge_type_to_id", &FusedCSCSamplingGraph::SetEdgeTypeToID)
.def("set_node_attributes", &FusedCSCSamplingGraph::SetNodeAttributes)
.def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes)
.def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors)
Expand Down
35 changes: 34 additions & 1 deletion python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,27 @@ def edge_type_to_id(
"""Sets the edge type to id dictionary if present."""
self._c_csc_graph.set_edge_type_to_id(edge_type_to_id)

@property
def node_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
"""Returns the node attributes dictionary.

Returns
-------
Dict[str, torch.Tensor] or None
If present, returns a dictionary of node attributes. Each key
represents the attribute's name, while the corresponding value
holds the attribute's specific value. The length of each value
should match the total number of nodes.
"""
return self._c_csc_graph.node_attributes()

@node_attributes.setter
def node_attributes(
self, node_attributes: Optional[Dict[str, torch.Tensor]]
) -> None:
"""Sets the node attributes dictionary.

Parameters
----------
node_attributes : Dict[str, torch.Tensor] or None
The value passed through unchanged to ``set_node_attributes`` on
the wrapped ``_c_csc_graph`` object.
"""
self._c_csc_graph.set_node_attributes(node_attributes)

@property
def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
"""Returns the edge attributes dictionary.
Expand Down Expand Up @@ -892,6 +913,9 @@ def _to(x, device):
self.type_per_edge = recursive_apply(
self.type_per_edge, lambda x: _to(x, device)
)
self.node_attributes = recursive_apply(
self.node_attributes, lambda x: _to(x, device)
)
self.edge_attributes = recursive_apply(
self.edge_attributes, lambda x: _to(x, device)
)
Expand All @@ -906,6 +930,7 @@ def fused_csc_sampling_graph(
type_per_edge: Optional[torch.tensor] = None,
node_type_to_id: Optional[Dict[str, int]] = None,
edge_type_to_id: Optional[Dict[str, int]] = None,
node_attributes: Optional[Dict[str, torch.tensor]] = None,
edge_attributes: Optional[Dict[str, torch.tensor]] = None,
) -> FusedCSCSamplingGraph:
"""Create a FusedCSCSamplingGraph object from a CSC representation.
Expand All @@ -926,6 +951,8 @@ def fused_csc_sampling_graph(
Map node types to ids, by default None.
edge_type_to_id : Optional[Dict[str, int]], optional
Map edge types to ids, by default None.
node_attributes: Optional[Dict[str, torch.tensor]], optional
Node attributes of the graph, by default None.
edge_attributes: Optional[Dict[str, torch.tensor]], optional
Edge attributes of the graph, by default None.
Expand All @@ -946,7 +973,7 @@ def fused_csc_sampling_graph(
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes, edge_type_to_id=etypes,
... edge_attributes=None,)
... node_attributes=None, edge_attributes=None,)
>>> print(graph)
FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]),
Expand Down Expand Up @@ -997,6 +1024,7 @@ def fused_csc_sampling_graph(
type_per_edge,
node_type_to_id,
edge_type_to_id,
node_attributes,
edge_attributes,
),
)
Expand Down Expand Up @@ -1037,6 +1065,8 @@ def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
meta_str += f", node_type_to_id={graph.node_type_to_id}"
if graph.edge_type_to_id is not None:
meta_str += f", edge_type_to_id={graph.edge_type_to_id}"
if graph.node_attributes is not None:
meta_str += f", node_attributes={graph.node_attributes}"
if graph.edge_attributes is not None:
meta_str += f", edge_attributes={graph.edge_attributes}"

Expand Down Expand Up @@ -1094,6 +1124,8 @@ def from_dglgraph(
# Assign edge type according to the order of CSC matrix.
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]

node_attributes = {}

edge_attributes = {}
if include_original_edge_id:
# Assign edge attributes according to the original eids mapping.
Expand All @@ -1107,6 +1139,7 @@ def from_dglgraph(
type_per_edge,
node_type_to_id,
edge_type_to_id,
node_attributes,
edge_attributes,
),
)
Loading

0 comments on commit e181ef1

Please sign in to comment.