Skip to content

Commit

Permalink
[Graphbolt] Add proper error message for BuildinDataset. (#6690)
Browse files Browse the repository at this point in the history
Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
frozenbugs and Ubuntu authored Dec 6, 2023
1 parent 8a85253 commit cbb6f50
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions python/dgl/graphbolt/impl/ondisk_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def __init__(
yaml_path = preprocess_ondisk_dataset(path, include_original_edge_id)
with open(yaml_path) as f:
self._yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)
self._loaded = False

def _convert_yaml_path_to_absolute_path(self):
"""Convert the path in YAML file to absolute path."""
Expand Down Expand Up @@ -384,6 +385,7 @@ def load(self):
self._feature = TorchBasedFeatureStore(self._meta.feature_data)
self._tasks = self._init_tasks(self._meta.tasks)
self._all_nodes_set = self._init_all_nodes_set(self._graph)
self._loaded = True
return self

@property
Expand All @@ -394,26 +396,31 @@ def yaml_data(self) -> Dict:
@property
def tasks(self) -> List[Task]:
"""Return the tasks."""
self._check_loaded()
return self._tasks

@property
def graph(self) -> SamplingGraph:
"""Return the graph."""
self._check_loaded()
return self._graph

@property
def feature(self) -> TorchBasedFeatureStore:
"""Return the feature."""
self._check_loaded()
return self._feature

@property
def dataset_name(self) -> str:
"""Return the dataset name."""
self._check_loaded()
return self._dataset_name

@property
def all_nodes_set(self) -> Union[ItemSet, ItemSetDict]:
"""Return the itemset containing all nodes."""
self._check_loaded()
return self._all_nodes_set

def _init_tasks(self, tasks: List[OnDiskTaskData]) -> List[OnDiskTask]:
Expand All @@ -432,6 +439,12 @@ def _init_tasks(self, tasks: List[OnDiskTaskData]) -> List[OnDiskTask]:
)
return ret

def _check_loaded(self):
assert self._loaded, (
"Please ensure that you have called the OnDiskDataset.load() method"
+ " to properly load the data."
)

def _load_graph(
self, graph_topology: OnDiskGraphTopology
) -> FusedCSCSamplingGraph:
Expand Down

0 comments on commit cbb6f50

Please sign in to comment.