Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mostly cosmetic, adding some type hints, and clarifying comments #16

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion deepsnap/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def __init__(self, batch=None, **kwargs):
self.__slices__ = None

@staticmethod
def collate(follow_batch=[], transform=None, **kwargs):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this pattern appears in a few places i think:
https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,

Thanks for proposing this. It seems that it is not a bug because we want the users to pass in a list for follow_batch and if they don't pass in the list we will make the follow_batch an empty list. We are not going to append elements to the follow_batch as appears in the link.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, def not a bug from what i can tell so it's minor. however, this pattern can lead to very tricky issues to debug and is generally advised against so wanted to point it out. no worries if you prefer it the way it is.

def collate(follow_batch=None, transform=None, **kwargs):
    """Build a collate callable for a data loader.

    Uses ``None`` as the default for ``follow_batch`` (rather than ``[]``)
    to avoid the shared-mutable-default-argument pitfall; a fresh empty
    list is substituted on each call when the caller passes nothing.
    """
    keys_to_follow = follow_batch if follow_batch is not None else []
    # Defer the actual batching to Batch.from_data_list at load time.
    return lambda batch: Batch.from_data_list(
        batch, keys_to_follow, transform, **kwargs
    )
Expand Down Expand Up @@ -70,6 +72,7 @@ def from_data_list(
batch, cumsum = Batch._init_batch_fields(keys, follow_batch)
batch.__data_class__ = data_list[0].__class__
batch.batch = []
num_nodes = None
for i, data in enumerate(data_list):
# Note: in heterogeneous graph, __inc__ logic is different
Batch._collate_dict(
Expand Down
27 changes: 13 additions & 14 deletions deepsnap/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,8 +661,8 @@ def _split_transductive(
Split the dataset assuming training process is transductive.

Args:
split_ratio: number of data splitted into train, validation
(and test) set.
split_ratio: ratio of data to be split into train, validation
(and test) sets.

Returns:
list: A list of 3 (2) lists of :class:`deepsnap.graph.Graph`
Expand Down Expand Up @@ -978,11 +978,13 @@ def split(
transductive: whether the training process is transductive
or inductive. Inductive split is always used for graph-level
tasks (self.task == 'graph').
split_ratio: number of data splitted into train, validation
(and test) set.
split_ratio: ratio of data to be split into train, validation
(and test) sets. These ratios must sum to 1.
If `None` the default splits are 0.8, 0.1,
and 0.1 for train, validation, and test sets respectively.

Returns:
list: a list of 3 (2) lists of :class:`deepsnap.graph.Graph`
list: a list of 2 (or 3) lists of :class:`deepsnap.graph.Graph`
objects corresponding to train, validation (and test) set.
"""
if self.graphs is None:
Expand All @@ -1006,8 +1008,8 @@ def split(
for split_ratio_i in split_ratio
):
raise TypeError("Split ratio must contain all floats.")
if not all(split_ratio_i > 0 for split_ratio_i in split_ratio):
raise ValueError("Split ratio must contain all positivevalues.")
if not all(0 < split_ratio_i < 1 for split_ratio_i in split_ratio):
raise ValueError("Split ratios must be between 0 and 1.")

# store the most recent split types
self._split_types = split_types
Expand All @@ -1019,31 +1021,28 @@ def split(
graph.edge_label = graph._edge_label

# list of num_splits datasets
dataset_return = []
if transductive:
if self.task == "graph":
raise ValueError(
"in transductive mode, self.task is graph does not "
"in transductive mode, self.task == `graph` does not "
"make sense."
)
dataset_return = (
return (
self._split_transductive(
split_ratio, split_types, shuffle=shuffle
)
)
else:
dataset_return = (
return (
self._split_inductive(
split_ratio,
split_types,
shuffle=shuffle
)
)

return dataset_return

def resample_disjoint(self):
r""" Resample disjoint edge split of message passing and objective links.
r"""Resample disjoint edge split of message passing and objective links.

Note that if apply_transform (on the message passing graph)
was used before this resampling, it needs to be
Expand Down
21 changes: 12 additions & 9 deletions deepsnap/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
from typing import (
Dict,
List,
Optional,
Union
)
import warnings
import deepsnap
import networkx as nx


class Graph(object):
Expand All @@ -23,13 +25,13 @@ class Graph(object):
Args:
G (:class:`networkx.classes.graph`): The NetworkX graph object which
contains features and labels for the tasks.
**kwargs: keyworded argument list with keys such
**kwargs: keyword argument list with keys such
as :obj:`"node_feature"`, :obj:`"node_label"` and
corresponding attributes.
"""

def __init__(self, G=None, netlib=None, **kwargs):
self.G = G
def __init__(self, G: Optional[nx.Graph] = None, netlib=None, **kwargs):
self.G: nx.Graph = G
if netlib is not None:
deepsnap._netlib = netlib
keys = [
Expand Down Expand Up @@ -303,7 +305,8 @@ def get_num_dims(self, key: str, as_label: bool = False) -> int:
Returns the number of dimensions for one graph/node/edge property.

Args:
as_label: if as_label, treat the tensor as labels (
key: the `key` to return the dimension for
as_label: if `as_label`, treat the tensor as labels
"""
if as_label:
# treat as label
Expand Down Expand Up @@ -1147,7 +1150,7 @@ def split(
else:
raise ValueError("Unknown task.")

def _split_node(self, split_ratio: float, shuffle: bool = True):
def _split_node(self, split_ratio: List[float], shuffle: bool = True):
r"""
Split the graph into len(split_ratio) graphs for node prediction.
Internally this splits node indices, and the model will only compute
Expand Down Expand Up @@ -1217,7 +1220,7 @@ def _split_node(self, split_ratio: float, shuffle: bool = True):
split_graphs.append(graph_new)
return split_graphs

def _split_edge(self, split_ratio: float, shuffle: bool = True):
def _split_edge(self, split_ratio: List[float], shuffle: bool = True):
r"""
Split the graph into len(split_ratio) graphs for node prediction.
Internally this splits node indices, and the model will only compute
Expand Down Expand Up @@ -1393,14 +1396,14 @@ def split_link_pred(
nodes in each split graph.
This is only used for transductive link prediction task
In this task, different part of graph is observed in train/val/test
Note: this functon will be called twice,
Note: this function will be called twice,
if during training, we further split the training graph so that
message edges and objective edges are different
"""
if isinstance(split_ratio, float):
split_ratio = [split_ratio, 1 - split_ratio]
if len(split_ratio) < 2 or len(split_ratio) > 3:
raise ValueError("Unrecoginzed number of splits")
raise ValueError("Unrecognized number of splits")
if self.num_edges < len(split_ratio):
raise ValueError(
"In _split_link_pred num of edges are smaller than"
Expand Down Expand Up @@ -1526,7 +1529,7 @@ def split_link_pred(
else:
return [graph_train, graph_val]

def _edge_subgraph_with_isonodes(self, G, edges):
def _edge_subgraph_with_isonodes(self, G: nx.Graph, edges: List):
r"""
Generate a new networkx graph with same nodes and their attributes.

Expand Down
3 changes: 2 additions & 1 deletion deepsnap/hetero_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import (
Dict,
List,
Optional,
Union
)
import warnings
Expand Down Expand Up @@ -120,7 +121,7 @@ def message_types(self):
"""
return list(self["edge_index"].keys())

def num_nodes(self, node_type: Union[str, List[str]] = None):
def num_nodes(self, node_type: Optional[Union[str, List[str]]] = None) -> Dict:
r"""
Return number of nodes for a node type or list of node types.

Expand Down