diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index da0ec10..2fb654f 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,27 +1,36 @@
# How to contribute to the BESS-KGE project
-You can contribute to the development of the BESS-KGE project, even if you don't have access to IPUs (you can use the [IPUModel](https://docs.graphcore.ai/projects/poptorch-user-guide/en/3.2.0/reference.html#poptorch.Options.useIpuModel) to emulate most functionalities of the physical hardware).
+![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)
+
+You can contribute to the development of the BESS-KGE project, even if you don't have access to IPUs (you can use the [IPUModel](https://docs.graphcore.ai/projects/poptorch-user-guide/en/3.2.0/reference.html#poptorch.Options.useIpuModel) to emulate most functionalities of the physical hardware).
## VS Code server on Paperspace
-Setting up a VS Code server on [Paperspace](https://www.paperspace.com/graphcore) will allow you to tunnel into a machine with IPUs from the VS Code web editor or the desktop app. This requires minimum effort and is an excellent solution for developing and testing code directly on IPU hardware.
+Setting up a VS Code server on [Paperspace](https://www.paperspace.com/graphcore) allows you to tunnel into a machine with IPUs from the VS Code web editor or the desktop app. This requires minimal effort and is an excellent way to develop and test code directly on IPU hardware. Here's how to do it.
-You can launch a 6-hours session on a Paperspace machine with access to 4 IPUs **for free** by clicking on this button:
+1. Fork the [BESS-KGE repository](https://github.com/graphcore-research/bess-kge).
-Start the machine (this will also clone the repo for you) and open up a terminal from the left pane.
+2. You can launch a 6-hour session on a Paperspace machine with access to 4 IPUs **for free** by using a link of the form:
+ ```
+ https://console.paperspace.com/github/{USERID}/{REPONAME}?container=graphcore%2Fpytorch-paperspace%3A3.3.0-ubuntu-20.04-20230703&machine=Free-IPU-POD4
+ ```
-![terminal_pane](docs/source/images/Terminal1.png "height=200")
+   where `{USERID}/{REPONAME}` is the GitHub address of the forked repository (e.g. `graphcore-research/bess-kge` for the original repo).
-In the terminal, run the command
-```shell
-bash .gradient/launch_vscode_server.sh {tunnel-name}
-```
+3. Start the machine (this will also clone the repo for you) and open up a terminal from the left pane.
+
+ ![terminal_pane](docs/source/images/Terminal3.png)
-where `tunnel-name` is an optional argument that you can use to define the name of the remote tunnel (if not set, it will default to `ipu-paperspace`).
+4. In the terminal, run the command
+ ```shell
+ bash .gradient/launch_vscode_server.sh {tunnel-name}
+ ```
-The script will download and install all dependencies and start the tunnel. You will be asked to authorize the tunnel through GitHub, before being provided with the tunnel link. Please refer to [this notebook](https://ipu.dev/fmo4AZ) for additional details on these steps and to connect the VS Code desktop app to the remote tunnel.
+ where `tunnel-name` is an optional argument that you can use to define the name of the remote tunnel (if not set, it will default to `ipu-paperspace`). The script will download and install all dependencies and start the tunnel.
-Once VS Code is connected to the Paperspace machine, run `./dev build` to build all custom ops. You are now ready to create a new git branch and start developing!
+5. When asked, authorize the tunnel through GitHub (using an account with write access to the forked repository). You will then be provided with the tunnel link. Please refer to [this notebook](https://ipu.dev/fmo4AZ) for additional details on these steps and for instructions on connecting the VS Code desktop app to the remote tunnel.
+
+6. Once VS Code is connected to the Paperspace machine, run `./dev build` to build all custom ops. You are now ready to start developing!
When closing a session and stopping the Paperspace machine, remember to unregister the tunnel in VS Code as explained in the "Common Issues" paragraph of the [notebook](https://ipu.dev/fmo4AZ). To resume your work, just access the clone of the BESS-KGE repo in the "Projects" section of your Paperspace profile, start a new machine and repeat the operations above. All code changes to the local repo, as well as VS Code settings and extensions installed, will persist across sessions.
@@ -43,10 +52,14 @@ pip install $POPLAR_SDK_ENABLED/../poptorch-*.whl
pip install -r requirements-dev.txt
```
-Finally, build all custom ops by running `./dev build`
+Finally, clone your fork of the BESS-KGE repository and build all custom ops by running `./dev build`.
+
+## Development tips
+
+The `./dev` command can be used to run several utility scripts during development. Check `./dev --help` for a list of dev options.
-## Tips
+Before submitting a PR to the upstream repo, use `./dev ci` to run all CI checks locally. In particular, be mindful of our formatting requirements: you can check for formatting errors by running `./dev format` and `./dev lint` (both commands are automatically run inside `./dev ci`).
-Run `./dev --help` for a list of dev options. In particular, use `./dev ci` to run all CI checks locally. Run individual tests with pattern matching filtering `./dev tests -k FILTER`.
+Add unit tests to the `tests` folder. You can run individual unit tests by filtering with a pattern: `./dev tests -k FILTER`.
Add `.cpp` custom ops to `besskge/custom_ops`. Also, update the [Makefile](Makefile) when adding custom ops.
diff --git a/NOTICE.md b/NOTICE.md
index cd44ef3..18e28d2 100644
--- a/NOTICE.md
+++ b/NOTICE.md
@@ -2,9 +2,7 @@ Copyright (c) 2023 Graphcore Ltd. Licensed under the MIT License.
The included code is released under an MIT license, (see [LICENSE](LICENSE)).
-The ogbl-biokg and ogbl-wikikg2 datasets are licensed under CC-0.
-
-The [YAGO3 dataset](https://yago-knowledge.org/downloads/yago-3) by the [YAGO team](https://yago-knowledge.org/contributors) of the [Max-Planck Institute for Informatics](https://www.mpi-inf.mpg.de/home/) and [Telcom Paris](https://www.telecom-paris.fr/) is licensed under [CC BY 3.0](https://creativecommons.org/licenses/by/3.0/).
+## Dependencies
Our dependencies are (see [requirements.txt](requirements.txt)):
@@ -18,8 +16,21 @@ Our dependencies are (see [requirements.txt](requirements.txt)):
We also use additional Python dependencies for development/testing/documentation (see [requirements-dev.txt](requirements-dev.txt)).
+## Dataset disclaimer
+
+This repository provides dataloaders for third-party datasets. The use of these datasets is at your own risk and Graphcore offers no warranties of any kind. It is the user's responsibility to comply with all license requirements for datasets downloaded with the dataloaders in this repository.
+
+The tutorial notebooks make use of the following datasets:
+
+* [ogbl-biokg](https://ogb.stanford.edu/docs/linkprop/#ogbl-biokg), licensed under CC-0;
+
+* [ogbl-wikikg2](https://ogb.stanford.edu/docs/linkprop/#ogbl-wikikg2), licensed under CC-0;
+
+* [YAGO3 dataset](https://yago-knowledge.org/downloads/yago-3) by the [YAGO team](https://yago-knowledge.org/contributors) of the [Max-Planck Institute for Informatics](https://www.mpi-inf.mpg.de/home/) and [Telecom Paris](https://www.telecom-paris.fr/), licensed under [CC BY 3.0](https://creativecommons.org/licenses/by/3.0/).
+
+## Derived work
-**This directory includes derived work from the following:**
+This directory includes derived work from the following:
---
diff --git a/README.md b/README.md
index 909b229..20cd1ff 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,6 @@
# BESS-KGE
![Continuous integration](https://github.com/graphcore-research/bess-kge/actions/workflows/ci.yaml/badge.svg)
+![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)
[**Installation guide**](#usage)
| [**Tutorials**](#paperspace-notebook-tutorials)
@@ -77,6 +78,17 @@ Additional variations of the distribution scheme are detailed in the [BESS-KGE d
All APIs are documented in the [BESS-KGE API documentation](https://graphcore-research.github.io/bess-kge/API_reference.html).
+### Datasets
+
+BESS-KGE provides built-in dataloaders for the following datasets. Note that the use of these datasets is at your own risk and Graphcore offers no warranties of any kind. It is the user's responsibility to comply with all license requirements for datasets downloaded with the dataloaders in this repository. A minimal loading example is shown after the table.
+
+| Dataset | Builder method | Entities | Entity types | Relation types | Triples | License |
+| --- | --- | --- | --- | --- | --- | --- |
+| [ogbl-biokg](https://ogb.stanford.edu/docs/linkprop/#ogbl-biokg) | [KGDataset.build_ogbl_biokg](https://graphcore-research.github.io/bess-kge/generated/besskge.dataset.KGDataset.html#besskge.dataset.KGDataset.build_ogbl_biokg) | 93,773 | 5 | 51 | 5,088,434 | CC-0 |
+| [ogbl-wikikg2](https://ogb.stanford.edu/docs/linkprop/#ogbl-wikikg2) | [KGDataset.build_ogbl_wikikg2](https://graphcore-research.github.io/bess-kge/generated/besskge.dataset.KGDataset.html#besskge.dataset.KGDataset.build_ogbl_wikikg2) | 2,500,604 | 1 | 535 | 16,968,094 | CC-0 |
+| [YAGO3-10](https://yago-knowledge.org/downloads/yago-3) | [KGDataset.build_yago310](https://graphcore-research.github.io/bess-kge/generated/besskge.dataset.KGDataset.html#besskge.dataset.KGDataset.build_yago310) | 123,182 | 1 | 37 | 1,089,040 | CC BY 3.0 |
+| [OpenBioLink2020](https://github.com/openbiolink/openbiolink#benchmark-dataset) | [KGDataset.build_openbiolink](https://graphcore-research.github.io/bess-kge/generated/besskge.dataset.KGDataset.html#besskge.dataset.KGDataset.build_openbiolink) | 184,635 | 7 | 28 | 4,563,405 | [link](https://github.com/openbiolink/openbiolink#Source-databases-and-their-licenses) |
+
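+For example, a built-in dataset can be loaded with its builder method (a minimal sketch; the local path below is illustrative, and the dataset is downloaded there on first use):
+
+```python
+from pathlib import Path
+
+from besskge.dataset import KGDataset
+
+# Download (if needed) and preprocess ogbl-biokg under the given root directory
+biokg = KGDataset.build_ogbl_biokg(root=Path("datasets/biokg"))
+
+print(biokg.n_entity, biokg.n_relation_type)
+```
+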
### Known limitations
* BESS-KGE supports distribution for up to 16 IPUs.
@@ -178,10 +190,11 @@ For a walkthrough of the `besskge` library functionalities, see our Jupyter note
2. [Link prediction on the YAGO3-10 dataset](notebooks/2_yago_topk_prediction.ipynb) [![Run on Gradient](docs/gradient-badge.svg)](https://console.paperspace.com/github/graphcore-research/bess-kge?container=graphcore%2Fpytorch-paperspace%3A3.3.0-ubuntu-20.04-20230703&machine=Free-IPU-POD4&file=%2Fnotebooks%2F2_yago_topk_prediction.ipynb)
3. [FP16 weights and compute on the OGBL-WikiKG2 dataset](notebooks/3_wikikg2_fp16.ipynb) [![Run on Gradient](docs/gradient-badge.svg)](https://console.paperspace.com/github/graphcore-research/bess-kge?container=graphcore%2Fpytorch-paperspace%3A3.3.0-ubuntu-20.04-20230703&machine=Free-IPU-POD4&file=%2Fnotebooks%2F3_wikikg2_fp16.ipynb)
+For pointers on how to run BESS-KGE on a custom Knowledge Graph dataset, see the notebook [Using BESS-KGE with your own data](notebooks/0_custom_KG_dataset.ipynb) [![Run on Gradient](docs/gradient-badge.svg)](https://console.paperspace.com/github/graphcore-research/bess-kge?container=graphcore%2Fpytorch-paperspace%3A3.3.0-ubuntu-20.04-20230703&machine=Free-IPU-POD4&file=%2Fnotebooks%2F0_custom_KG_dataset.ipynb)
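+
+As a minimal sketch of that workflow (assuming `df` is a pandas DataFrame of labelled triples with columns `"h"`, `"r"`, `"t"`; all names here are illustrative), a custom dataset can be built and randomly split with:
+
+```python
+import pandas as pd
+
+from besskge.dataset import KGDataset
+
+# Toy triples; in practice, load your own labelled (head, relation, tail) data
+df = pd.DataFrame(
+    {"h": ["a", "b", "c"], "r": ["likes", "likes", "knows"], "t": ["b", "c", "a"]}
+)
+
+# Assign entity/relation IDs and randomly split the triples 80/10/10
+dataset = KGDataset.from_dataframe(
+    df, head_column="h", relation_column="r", tail_column="t", split=(0.8, 0.1, 0.1)
+)
+```
+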
## Contributing
-You can contribute to the BESS-KGE project. See [How to contribute to the BESS-KGE project](CONTRIBUTING.md)
+You can contribute to the BESS-KGE project: PRs are welcome! For details, see [How to contribute to the BESS-KGE project](CONTRIBUTING.md).
## References
BESS: Balanced Entity Sampling and Sharing for Large-Scale Knowledge Graph Completion ([arXiv](https://arxiv.org/abs/2211.12281))
@@ -190,6 +203,6 @@ BESS: Balanced Entity Sampling and Sharing for Large-Scale Knowledge Graph Compl
Copyright (c) 2023 Graphcore Ltd. Licensed under the MIT License.
-The included code is released under the MIT license, (see [details of the license](LICENSE)).
+The included code is released under the MIT license (see [details of the license](LICENSE)).
See [notices](NOTICE.md) for dependencies, credits, derived work and further details.
\ No newline at end of file
diff --git a/besskge/dataset.py b/besskge/dataset.py
index 85e760a..17a12d1 100644
--- a/besskge/dataset.py
+++ b/besskge/dataset.py
@@ -8,9 +8,10 @@
import dataclasses
import pickle
import tarfile
+import zipfile
from io import BytesIO
from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import ogb.linkproppred
@@ -75,7 +76,153 @@ def ht_types(self) -> Optional[Dict[str, NDArray[np.int32]]]:
return None
@classmethod
- def build_biokg(cls, root: Path) -> "KGDataset":
+ def from_triples(
+ cls,
+ data: NDArray[np.int32],
+ split: Tuple[float, float, float] = (0.7, 0.15, 0.15),
+ seed: int = 1234,
+ entity_dict: Optional[List[str]] = None,
+ relation_dict: Optional[List[str]] = None,
+ type_offsets: Optional[Dict[str, int]] = None,
+ ) -> "KGDataset":
+ """
+ Build a dataset from an array of triples, where IDs for entities
+ and relations have already been assigned. Note that, if entities have
+ types, entities of the same type need to have contiguous IDs.
+        Triples are randomly split into train/validation/test sets.
+        If a pre-defined train/validation/test split is required, the KGDataset
+        class should be instantiated manually.
+
+ :param data:
+ Numpy array of triples [head_id, relation_id, tail_id]. Shape
+ (num_triples, 3).
+ :param split:
+ Tuple to set the train/validation/test split.
+ :param seed:
+ Random seed for the train/validation/test split.
+ :param entity_dict:
+ Optional entity labels by ID.
+ :param relation_dict:
+ Optional relation labels by ID.
+ :param type_offsets:
+            Offsets of entity types.
+
+ :return: Instance of the KGDataset class.
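+
+        A minimal usage sketch (the random triples below are purely
+        illustrative)::
+
+            import numpy as np
+
+            rng = np.random.default_rng(0)
+            heads = rng.integers(0, 100, size=(1000,))
+            rels = rng.integers(0, 10, size=(1000,))
+            tails = rng.integers(0, 100, size=(1000,))
+            data = np.stack([heads, rels, tails], axis=1).astype(np.int32)
+            dataset = KGDataset.from_triples(data, split=(0.8, 0.1, 0.1))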
+ """
+ num_triples = data.shape[0]
+ num_train = int(num_triples * split[0])
+ num_valid = int(num_triples * split[1])
+
+ rng = np.random.default_rng(seed=seed)
+ rng.shuffle(data, axis=0)
+
+ triples = dict()
+ triples["train"], triples["valid"], triples["test"] = np.split(
+ data, (num_train, num_train + num_valid), axis=0
+ )
+
+ return cls(
+ n_entity=data[:, [0, 2]].max() + 1,
+ n_relation_type=data[:, 1].max() + 1,
+ entity_dict=entity_dict,
+ relation_dict=relation_dict,
+ type_offsets=type_offsets,
+ triples=triples,
+ )
+
+ @classmethod
+ def from_dataframe(
+ cls,
+ df: Union[pd.DataFrame, Dict[str, pd.DataFrame]],
+ head_column: Union[int, str],
+ relation_column: Union[int, str],
+ tail_column: Union[int, str],
+ entity_types: Optional[Union["pd.Series[str]", Dict[str, str]]] = None,
+ split: Tuple[float, float, float] = (0.7, 0.15, 0.15),
+ seed: int = 1234,
+ ) -> "KGDataset":
+ """
+ Build a KGDataset from a pandas DataFrame of labeled (h,r,t) triples.
+ IDs for entities and relations are automatically assigned based on labels
+ in such a way that entities of the same type have contiguous IDs.
+
+ :param df:
+ Pandas DataFrame of all triples in the knowledge graph dataset,
+            or a dictionary of DataFrames of triples, one for each part of the dataset split.
+ :param head_column:
+ Name of the DataFrame column storing head entities
+ :param relation_column:
+ Name of the DataFrame column storing relations
+ :param tail_column:
+ Name of the DataFrame column storing tail entities
+ :param entity_types:
+ If entities have types, dictionary or pandas Series of mappings
+ entity label -> entity type.
+ :param split:
+ Tuple to set the train/validation/test split.
+ Only used if no pre-defined dataset split is specified,
+ i.e. if `df` is not a dictionary.
+ :param seed:
+ Random seed for the train/validation/test split.
+ Only used if no pre-defined dataset split is specified,
+ i.e. if `df` is not a dictionary.
+
+ :return: Instance of the KGDataset class.
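+
+        A minimal usage sketch, assuming ``train_df``/``valid_df``/``test_df``
+        are DataFrames of labelled triples with columns "h", "r", "t" and
+        ``ent_types`` is a pandas Series mapping entity labels to entity types
+        (all names here are illustrative)::
+
+            dataset = KGDataset.from_dataframe(
+                {"train": train_df, "valid": valid_df, "test": test_df},
+                head_column="h",
+                relation_column="r",
+                tail_column="t",
+                entity_types=ent_types,
+            )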
+ """
+
+ df_dict = {"all": df} if isinstance(df, pd.DataFrame) else df
+ unique_ent = pd.concat(
+ [
+ pd.concat([dfp[head_column], dfp[tail_column]])
+ for dfp in df_dict.values()
+ ]
+ ).unique()
+ ent2id = pd.Series(np.arange(len(unique_ent)), index=unique_ent, name="ent_id")
+ unique_rel = pd.concat(
+ [dfp[relation_column] for dfp in df_dict.values()]
+ ).unique()
+ rel2id = pd.Series(np.arange(len(unique_rel)), index=unique_rel, name="rel_id")
+
+ if entity_types is not None:
+ ent2type = pd.Series(entity_types, name="ent_type")
+ ent2id_type = pd.merge(
+ ent2id, ent2type, how="left", left_index=True, right_index=True
+ ).sort_values("ent_type")
+ ent2id.index = ent2id_type.index
+ type_off = (
+ ent2id_type.groupby("ent_type")["ent_type"].count().cumsum().shift(1)
+ )
+ type_off.iloc[0] = 0
+ type_offsets = type_off.astype("int64").to_dict()
+ else:
+ type_offsets = None
+
+ entity_dict = ent2id.index.tolist()
+ relation_dict = rel2id.index.tolist()
+
+ triples = {}
+ for part, dfp in df_dict.items():
+ heads = dfp[head_column].map(ent2id).values.astype(np.int32)
+ tails = dfp[tail_column].map(ent2id).values.astype(np.int32)
+ rels = dfp[relation_column].map(rel2id).values.astype(np.int32)
+ triples[part] = np.stack([heads, rels, tails], axis=1)
+
+ if isinstance(df, pd.DataFrame):
+ return KGDataset.from_triples(
+ triples["all"], split, seed, entity_dict, relation_dict, type_offsets
+ )
+ else:
+ return cls(
+ n_entity=len(entity_dict),
+ n_relation_type=len(relation_dict),
+ entity_dict=entity_dict,
+ relation_dict=relation_dict,
+ type_offsets=type_offsets,
+ triples=triples,
+ )
+
+ @classmethod
+ def build_ogbl_biokg(cls, root: Path) -> "KGDataset":
"""
Build the ogbl-biokg dataset :cite:p:`OGB`
@@ -138,7 +285,7 @@ def build_biokg(cls, root: Path) -> "KGDataset":
)
@classmethod
- def build_wikikg2(cls, root: Path) -> "KGDataset":
+ def build_ogbl_wikikg2(cls, root: Path) -> "KGDataset":
"""
Build the ogbl-wikikg2 dataset :cite:p:`OGB`
@@ -215,130 +362,82 @@ def build_yago310(cls, root: Path) -> "KGDataset":
with tarfile.open(fileobj=BytesIO(res.content)) as tarf:
tarf.extractall(path=root)
- train = np.loadtxt(root.joinpath("train.txt"), delimiter="\t", dtype=str)
- valid = np.loadtxt(root.joinpath("valid.txt"), delimiter="\t", dtype=str)
- test = np.loadtxt(root.joinpath("test.txt"), delimiter="\t", dtype=str)
-
- entity_dict, entity_id = np.unique(
- np.concatenate(
- [
- train[:, 0],
- train[:, 2],
- valid[:, 0],
- valid[:, 2],
- test[:, 0],
- test[:, 2],
- ]
- ),
- return_inverse=True,
- )
- entity_split_limits = np.cumsum(
- [
- train.shape[0],
- train.shape[0],
- valid.shape[0],
- valid.shape[0],
- test.shape[0],
- ]
+ train_triples = pd.read_csv(
+ root.joinpath("train.txt"), delimiter="\t", dtype=str, header=None
)
- (
- train_head_id,
- train_tail_id,
- validation_head_id,
- validation_tail_id,
- test_head_id,
- test_tail_id,
- ) = np.split(entity_id, entity_split_limits)
-
- rel_dict, rel_id = np.unique(
- np.concatenate([train[:, 1], valid[:, 1], test[:, 1]]),
- return_inverse=True,
+ valid_triples = pd.read_csv(
+ root.joinpath("valid.txt"), delimiter="\t", dtype=str, header=None
)
- relation_split_limits = np.cumsum([train.shape[0], valid.shape[0]])
- train_rel_id, validation_rel_id, test_rel_id = np.split(
- rel_id, relation_split_limits
+ test_triples = pd.read_csv(
+ root.joinpath("test.txt"), delimiter="\t", dtype=str, header=None
)
- triples = {
- "train": np.concatenate(
- [train_head_id[:, None], train_rel_id[:, None], train_tail_id[:, None]],
- axis=1,
- ),
- "validation": np.concatenate(
- [
- validation_head_id[:, None],
- validation_rel_id[:, None],
- validation_tail_id[:, None],
- ],
- axis=1,
- ),
- "test": np.concatenate(
- [test_head_id[:, None], test_rel_id[:, None], test_tail_id[:, None]],
- axis=1,
- ),
- }
-
- return cls(
- n_entity=len(entity_dict),
- n_relation_type=len(rel_dict),
- entity_dict=entity_dict.tolist(),
- relation_dict=rel_dict.tolist(),
- type_offsets=None,
- triples=triples,
- neg_heads=None,
- neg_tails=None,
+ return cls.from_dataframe(
+ {"train": train_triples, "valid": valid_triples, "test": test_triples},
+ head_column=0,
+ relation_column=1,
+ tail_column=2,
)
@classmethod
- def from_triples(
- cls,
- data: NDArray[np.int32],
- split: Tuple[float, float, float] = (0.7, 0.15, 0.15),
- seed: int = 1234,
- entity_dict: Optional[List[str]] = None,
- relation_dict: Optional[List[str]] = None,
- type_offsets: Optional[Dict[str, int]] = None,
- ) -> "KGDataset":
+ def build_openbiolink(cls, root: Path) -> "KGDataset":
"""
- Build a dataset from an array of triples. Note that if a pre-defined
- train/validation/test split is wanted the KGDataset class should be instantiated
- manually.
+ Build the high-quality version of the OpenBioLink2020
+ dataset :cite:p:`openbiolink`
- :param data:
- Numpy array of triples [head_id, relation_id, tail_id]. Shape
- (num_triples, 3).
- :param split:
- Tuple to set the train/validation/test split.
- :param seed:
- Random seed for the train/validation/test split.
- :param entity_dict:
- Optional entity labels by ID.
- :param relation_dict:
- Optional relation labels by ID.
- :param type_offsets:
- Offset of entity types
+ .. seealso:: https://github.com/openbiolink/openbiolink#benchmark-dataset
- :return: Instance of the KGDataset class.
- """
- num_triples = data.shape[0]
- num_train = int(num_triples * split[0])
- num_valid = int(num_triples * split[1])
+ :param root:
+ Local path to the dataset. If the dataset is not present in this
+ location, then it is downloaded and stored here.
- rng = np.random.default_rng(seed=seed)
- rng.shuffle(data, axis=0)
+ :return: The HQ OpenBioLink2020 KGDataset.
+ """
- triples = dict()
- triples["train"], triples["valid"], triples["test"] = np.split(
- data, (num_train, num_train + num_valid), axis=0
+ if not (
+ root.joinpath("HQ_DIR/train_test_data/train_sample.csv").is_file()
+ and root.joinpath("HQ_DIR/train_test_data/val_sample.csv").is_file()
+ and root.joinpath("HQ_DIR/train_test_data/test_sample.csv").is_file()
+ and root.joinpath("HQ_DIR/train_test_data/train_val_nodes.csv").is_file()
+ ):
+ print("Downloading dataset...")
+ res = requests.get(url="https://zenodo.org/record/3834052/files/HQ_DIR.zip")
+ with zipfile.ZipFile(BytesIO(res.content)) as zip_f:
+ zip_f.extractall(path=root)
+
+ column_names = ["h_label", "r_label", "t_label", "quality", "TP/TN", "source"]
+ train_triples = pd.read_csv(
+ root.joinpath("HQ_DIR/train_test_data/train_sample.csv"),
+ header=None,
+ names=column_names,
+ sep="\t",
+ )
+ valid_triples = pd.read_csv(
+ root.joinpath("HQ_DIR/train_test_data/val_sample.csv"),
+ header=None,
+ names=column_names,
+ sep="\t",
+ )
+ test_triples = pd.read_csv(
+ root.joinpath("HQ_DIR/train_test_data/test_sample.csv"),
+ header=None,
+ names=column_names,
+ sep="\t",
)
- return cls(
- n_entity=data[:, [0, 2]].max() + 1,
- n_relation_type=data[:, 1].max() + 1,
- entity_dict=entity_dict,
- relation_dict=relation_dict,
- type_offsets=type_offsets,
- triples=triples,
+ entity_types = pd.read_csv(
+ root.joinpath("HQ_DIR/train_test_data/train_val_nodes.csv"),
+ header=None,
+ names=["ent_label", "ent_type"],
+ sep="\t",
+ ).set_index("ent_label")["ent_type"]
+
+ return cls.from_dataframe(
+ {"train": train_triples, "valid": valid_triples, "test": test_triples},
+ head_column="h_label",
+ relation_column="r_label",
+ tail_column="t_label",
+ entity_types=entity_types,
)
def save(self, out_file: Path) -> None:
diff --git a/docs/source/KGbib.bib b/docs/source/KGbib.bib
index fa9e948..d1ad271 100644
--- a/docs/source/KGbib.bib
+++ b/docs/source/KGbib.bib
@@ -192,3 +192,13 @@ @inproceedings{TranS
pages ={1202--1208},
year ={2022}
}
+
+@article{openbiolink,
+  author = {Breit, Anna and Ott, Simon and Agibetov, Asan and Samwald, Matthias},
+ title = {{OpenBioLink: a benchmarking framework for large-scale biomedical link prediction}},
+ journal = {Bioinformatics},
+ volume = {36},
+ number = {13},
+  pages = {4097--4098},
+ year = {2020},
+}
diff --git a/docs/source/images/Terminal1.png b/docs/source/images/Terminal1.png
deleted file mode 100644
index 98ca224..0000000
Binary files a/docs/source/images/Terminal1.png and /dev/null differ
diff --git a/docs/source/images/Terminal3.png b/docs/source/images/Terminal3.png
new file mode 100644
index 0000000..fc684dc
Binary files /dev/null and b/docs/source/images/Terminal3.png differ
diff --git a/notebooks/0_custom_KG_dataset.ipynb b/notebooks/0_custom_KG_dataset.ipynb
new file mode 100644
index 0000000..f788047
--- /dev/null
+++ b/notebooks/0_custom_KG_dataset.ipynb
@@ -0,0 +1,988 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "# Using BESS-KGE with Your Own Data\n",
+ "\n",
+ "Copyright (c) 2023 Graphcore Ltd. All rights reserved.\n",
+ "\n",
+ "BESS-KGE (`besskge`) is a PyTorch library for knowledge graph embedding (KGE) models on IPUs implementing the distribution framework [BESS](https://arxiv.org/abs/2211.12281), with embedding tables stored in the IPU SRAM.\n",
+ "\n",
+ "In this notebook we will show how to use the `besskge.dataset.KGDataset` class to easily pre-process a custom knowledge graph dataset for use with BESS-KGE.\n",
+ "\n",
+ "As an example, we will download and build the [OGBL-BioKG](https://ogb.stanford.edu/docs/linkprop/#ogbl-biokg) biomedical knowledge graph. While BESS-KGE provides a built-in dataloader for this dataset (see [besskge.dataset.KGDataset.build_ogbl_biokg](https://graphcore-research.github.io/bess-kge/generated/besskge.dataset.KGDataset.html#besskge.dataset.KGDataset.build_ogbl_biokg)), in this notebook — for didactic purposes — we will show how to import the dataset from scratch."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Environment setup\n",
+ "\n",
+ "While this notebook doesn't contain any code that needs to be run on IPUs or other accelerating hardware, the best way to run it is on Paperspace Gradient's cloud IPUs because everything is already set up for you.\n",
+ "\n",
+ " [![Run on Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://console.paperspace.com/github/graphcore-research/bess-kge?container=graphcore%2Fpytorch-paperspace%3A3.3.0-ubuntu-20.04-20230703&machine=Free-IPU-POD4&file=%2Fnotebooks%2F0_custom_KG_dataset.ipynb)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Dependencies"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We recommend that you install `besskge` directly from the GitHub sources:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Found existing installation: besskge 0.1\n",
+ "Uninstalling besskge-0.1:\n",
+ " Successfully uninstalled besskge-0.1\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "!{sys.executable} -m pip uninstall -y besskge\n",
+ "!pip install -q git+https://github.com/graphcore-research/bess-kge.git"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, import the necessary dependencies:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from pathlib import Path\n",
+ "\n",
+ "import ogb\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "\n",
+ "from besskge.dataset import KGDataset\n",
+ "\n",
+ "dataset_directory = os.getenv(\"DATASET_DIR\", \"../datasets/\") + \"/biokg/\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## The dataset\n",
+ "\n",
+ "We download the OGBL-BioKG knowledge graph using the `ogb` package (see the OGB [description of the data loader](https://ogb.stanford.edu/docs/linkprop/#data-loader) for details on how to use it). It shouldn't take more than a couple of minutes for the dataset to download."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Downloading http://snap.stanford.edu/ogb/data/linkproppred/biokg.zip\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloaded 0.90 GB: 100%|██████████| 920/920 [01:18<00:00, 11.67it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Extracting ../datasets//biokg_test/biokg.zip\n",
+ "Loading necessary files...\n",
+ "This might take a while.\n",
+ "Processing graphs...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 1/1 [00:00<00:00, 4860.14it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Saving...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = ogb.linkproppred.LinkPropPredDataset(\n",
+ " name=\"ogbl-biokg\", root=dataset_directory\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Since the objective here is to show how to build the BESS-KGE `KGDataset` class from scratch, we will not use any of the pre-processing utilities provided by `ogb`. Instead we will use the raw source files directly. To start from the most generic case, we will actually undo some of the preprocessing already performed on the data, namely the mapping from entity labels to entity indices for the different entity types."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "# train triples: 4,762,678\n",
+ "# validation triples: 162,886\n",
+ "# test triples: 162,870\n"
+ ]
+ }
+ ],
+ "source": [
+ "label_dict = {}\n",
+ "for l in [\"disease\", \"drug\", \"function\", \"protein\", \"sideeffect\"]:\n",
+ " # labels for the entities of type l\n",
+ " # we append the entity type to the label to prevent label collisions across different types\n",
+ " # notice that some entities, like proteins, have labels that are still numerical\n",
+ " label_dict[l] = (\n",
+ " pd.read_csv(\n",
+ " Path(dataset_directory).joinpath(\n",
+ " f\"ogbl_biokg/mapping/{l}_entidx2name.csv.gz\"\n",
+ " )\n",
+ " ).set_index(\"ent idx\")[\"ent name\"]\n",
+ " + f\" ({l})\"\n",
+ " ).values\n",
+ "\n",
+ "# labels for relation types\n",
+ "rel_dict = (\n",
+ " pd.read_csv(\n",
+ " Path(dataset_directory).joinpath(f\"ogbl_biokg/mapping/relidx2relname.csv.gz\")\n",
+ " )\n",
+ " .set_index(\"rel idx\")[\"rel name\"]\n",
+ " .values\n",
+ ")\n",
+ "\n",
+ "# collect triples in train, valid, test DataFrames\n",
+ "# replacing the entity IDs with their original labels.\n",
+ "df_dict = {}\n",
+ "for split in {\"test\", \"train\", \"valid\"}:\n",
+ " triples = []\n",
+ " data = torch.load(\n",
+ " Path(dataset_directory).joinpath(f\"ogbl_biokg/split/random/{split}.pt\")\n",
+ " )\n",
+ " for h, h_type, t, t_type, r in zip(\n",
+ " data[\"head\"],\n",
+ " data[\"head_type\"],\n",
+ " data[\"tail\"],\n",
+ " data[\"tail_type\"],\n",
+ " data[\"relation\"],\n",
+ " ):\n",
+ " h_label = label_dict[h_type][h]\n",
+ " t_label = label_dict[t_type][t]\n",
+ " r_label = rel_dict[r]\n",
+ " triples.append((h_label, r_label, t_label))\n",
+ " df_dict[split] = pd.DataFrame(\n",
+ " triples, columns=[\"head_label\", \"relation_label\", \"tail_label\"]\n",
+ " )\n",
+ "\n",
+ "print(f\"# train triples: {df_dict['train'].shape[0]:,}\")\n",
+ "print(f\"# validation triples: {df_dict['valid'].shape[0]:,}\")\n",
+ "print(f\"# test triples: {df_dict['test'].shape[0]:,}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " head_label relation_label tail_label\n",
+ "0 C0038586 (disease) disease-protein 1653 (protein)\n",
+ "1 C0751849 (disease) disease-protein 718 (protein)\n",
+ "2 C1320474 (disease) disease-protein 8622 (protein)\n",
+ "3 C0270844 (disease) disease-protein 3569 (protein)\n",
+ "4 C4279912 (disease) disease-protein 8856 (protein)"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_dict[\"train\"].head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We put ourselves in a very generic starting point, where the **edges of the knowledge graph are represented as a list of (head, relation, tail) triples**, with *unique labels* for entities and relations.\n",
+ "\n",
+    "In our knowledge graph, moreover, entities are of different types (disease, drug, function, protein and side-effect). This is not always the case, but BESS-KGE can leverage this additional information, for instance by constructing negative samples that corrupt entities only with entities of the same type (see [besskge.negative_sampler.TypeBasedShardedNegativeSampler](https://graphcore-research.github.io/bess-kge/API_ref/negative_sampler.html#besskge.negative_sampler.TypeBasedShardedNegativeSampler)). We store this information by creating a mapping from entity labels to entity types."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ " ent_label ent_type\n",
+ "0 C0000737 (disease) disease\n",
+ "1 C0000744 (disease) disease\n",
+ "2 C0000768 (disease) disease\n",
+ "3 C0000771 (disease) disease\n",
+ "4 C0000772 (disease) disease\n",
+ "... ... ...\n",
+ "93768 C3665444 (sideeffect) sideeffect\n",
+ "93769 C3665596 (sideeffect) sideeffect\n",
+ "93770 C3665609 (sideeffect) sideeffect\n",
+ "93771 C3665624 (sideeffect) sideeffect\n",
+ "93772 C3665888 (sideeffect) sideeffect\n",
+ "\n",
+ "[93773 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "entity_types = {l: t for t in label_dict.keys() for l in label_dict[t]}\n",
+ "entity_types = pd.DataFrame(\n",
+ " {\"ent_label\": entity_types.keys(), \"ent_type\": entity_types.values()}\n",
+ ")\n",
+ "entity_types"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ent_type\n",
+ "disease 10687\n",
+ "drug 10533\n",
+ "function 45085\n",
+ "protein 17499\n",
+ "sideeffect 9969\n",
+ "Name: ent_type, dtype: int64"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# entity type counts\n",
+ "entity_types.groupby(\"ent_type\")[\"ent_type\"].count()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## The KGDataset class in BESS-KGE\n",
+ "\n",
+ "When using BESS-KGE, the knowledge graph data is stored in an instance of the [besskge.dataset.KGDataset](https://graphcore-research.github.io/bess-kge/API_ref/dataset.html#besskge.dataset.KGDataset) class. This class has built-in methods to download and build some commonly-used knowledge graph datasets, or it can be instantiated manually with custom data by specifying all the required attributes.\n",
+ "\n",
+    "The `besskge.dataset.KGDataset.from_dataframe` method is perfect for **building a custom dataset from labelled triples with minimal effort**. It simply requires a pandas DataFrame containing all labelled triples (or a dictionary of DataFrames, one for each of the dataset splits). If entities have different types, as in OGBL-BioKG, this can be communicated to `KGDataset` by providing it with a mapping of entity labels to entity types, in the form of a dictionary or a pandas Series indexed over the entity labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ent_label\n",
+ "C0000737 (disease) disease\n",
+ "C0000744 (disease) disease\n",
+ "C0000768 (disease) disease\n",
+ "C0000771 (disease) disease\n",
+ "C0000772 (disease) disease\n",
+ " ... \n",
+ "C3665444 (sideeffect) sideeffect\n",
+ "C3665596 (sideeffect) sideeffect\n",
+ "C3665609 (sideeffect) sideeffect\n",
+ "C3665624 (sideeffect) sideeffect\n",
+ "C3665888 (sideeffect) sideeffect\n",
+ "Name: ent_type, Length: 93773, dtype: object"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "entity_types_series = entity_types.set_index(\"ent_label\")[\"ent_type\"]\n",
+ "entity_types_series"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To build `KGDataset` we only need to call `KGDataset.from_dataframe`, specifying the names of the columns of the dataframes which contain the entity and relation labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "biokg = KGDataset.from_dataframe(\n",
+ " df_dict,\n",
+ " head_column=\"head_label\",\n",
+ " relation_column=\"relation_label\",\n",
+ " tail_column=\"tail_label\",\n",
+ " entity_types=entity_types_series,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "That was easy, wasn't it? Let us have a closer look at the attributes of the `KGDataset` class we just created."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of entities: 93,773\n",
+ "\n",
+ "Number of relation types: 51\n",
+ "\n",
+ "Number of triples: \n",
+ " training: 4,762,678 \n",
+ " validation: 162,886\n",
+ " test: 162,870\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"Number of entities: {biokg.n_entity:,}\\n\")\n",
+ "print(f\"Number of relation types: {biokg.n_relation_type}\\n\")\n",
+ "print(\n",
+ " f\"Number of triples: \\n training: {biokg.triples['train'].shape[0]:,} \\n validation: {biokg.triples['valid'].shape[0]:,}\"\n",
+ " f\"\\n test: {biokg.triples['test'].shape[0]:,}\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note that only the entities that appear as head or tail in at least one of the dataset's triples are counted in `KGDataset.n_entity` (and the same for relations in `KGDataset.n_relation_type`).\n",
+ "\n",
+ "If we take a look at the `triples` attribute of `KGDataset`, we see that it is still structured as a dictionary, with the same keys that we used in `df_dict` to identify the different dataset splits."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "dict_keys(['valid', 'test', 'train'])"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "biokg.triples.keys()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[10120, 0, 79687],\n",
+ " [10137, 0, 79168],\n",
+ " [ 5659, 0, 78209],\n",
+ " ...,\n",
+ " [73643, 50, 73701],\n",
+ " [80650, 50, 72104],\n",
+ " [80643, 50, 76000]], dtype=int32)"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "biokg.triples[\"train\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Each row of these NumPy arrays corresponds to a (h,r,t) triple in the dataset, but now the **labels for entities and relations have been replaced by numerical IDs**. We can use `KGDataset.entity_dict` and `KGDataset.relation_dict` to recover the mapping from IDs to labels for entities and relation types respectively."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['C0393778 (disease)',\n",
+ " 'C2677109 (disease)',\n",
+ " 'C0796093 (disease)',\n",
+ " 'C2931278 (disease)',\n",
+ " 'C0149910 (disease)',\n",
+ " 'C3668942 (disease)',\n",
+ " 'C4225263 (disease)',\n",
+ " 'C0393814 (disease)',\n",
+ " 'C0342883 (disease)',\n",
+ " 'C4017556 (disease)']"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "biokg.entity_dict[:10]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "meaning that the entity with ID 0 is \"C0393778 (disease)\", the entity with ID 1 is \"C2677109 (disease)\", etc. (and similarly for `biokg.relation_dict`).\n",
+ "\n",
+ "Let's do a quick sanity check."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "('C0751495 (disease)', 'disease-protein', '7249 (protein)')"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "part = \"test\"\n",
+ "triple_number = 1234\n",
+ "\n",
+ "head_label = biokg.entity_dict[biokg.triples[part][triple_number, 0]]\n",
+ "relation_label = biokg.relation_dict[biokg.triples[part][triple_number, 1]]\n",
+ "tail_label = biokg.entity_dict[biokg.triples[part][triple_number, 2]]\n",
+ "\n",
+ "head_label, relation_label, tail_label"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In the original dataframes, this should coincide with:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "head_label C0751495 (disease)\n",
+ "relation_label disease-protein\n",
+ "tail_label 7249 (protein)\n",
+ "Name: 1234, dtype: object"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_dict[part].iloc[triple_number][[\"head_label\", \"relation_label\", \"tail_label\"]]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "It checks out!\n",
+ "\n",
+ "It is important to note that, when entities have different types, the numerical entity IDs need to be assigned so that **entities of the same type have contiguous IDs**! This is done automatically when using `KGDataset.from_dataframe`, but it needs to be kept in mind if you are instantiating the `KGDataset` class manually.\n",
+ "\n",
+ "Since entity IDs are now clustered by type, we only need to know the ID ranges corresponding to the different types, which are stored in `KGDataset.type_offsets`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'disease': 0,\n",
+ " 'drug': 10687,\n",
+ " 'function': 21220,\n",
+ " 'protein': 66305,\n",
+ " 'sideeffect': 83804}"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "biokg.type_offsets"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This means that entities with ID from 0 to 10686 are of type 'disease', from 10687 to 21219 are of type 'drug' and so on.\n",
+ "\n",
+ "The type IDs (assigned following the order of the keys in `KGDataset.type_offsets`) for heads and tails of all triples in the dataset can be immediately recovered using `KGDataset.ht_types`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "dict_keys(['valid', 'test', 'train'])"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "biokg.ht_types.keys()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(4762678, 2) (162886, 2) (162870, 2)\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\n",
+ " biokg.ht_types[\"train\"].shape,\n",
+ " biokg.ht_types[\"valid\"].shape,\n",
+ " biokg.ht_types[\"test\"].shape\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Each row of these arrays stores the ID of the type of the head entity and tail entity of the corresponding triple, for example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0, 3])"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "part = \"valid\"\n",
+ "triple_number = 0\n",
+ "\n",
+ "biokg.ht_types[part][triple_number]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This means that the first validation triple has a head entity of type 'disease' (the 0-th key of `biokg.type_offsets`) and a tail entity of type 'protein' (the key of `biokg.type_offsets` with index 3). Indeed, we can check:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "('disease', 'protein')"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "type_head = entity_types_series[df_dict[part].iloc[triple_number][\"head_label\"]]\n",
+ "type_tail = entity_types_series[df_dict[part].iloc[triple_number][\"tail_label\"]]\n",
+ "\n",
+ "type_head, type_tail"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Random dataset split\n",
+ "\n",
+ "What if no custom train/validation/test split is provided? `KGDataset.from_dataframe` can perform a random split, with the desired ratios between the three parts. This happens whenever it is provided with a single pandas DataFrame, instead of a dictionary of DataFrames as before."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total number of triples: 5,088,434\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Merge all triples in a single DataFrame\n",
+ "df_all_triples = pd.concat(\n",
+ " [df_dict[\"train\"], df_dict[\"valid\"], df_dict[\"test\"]], axis=0\n",
+ ")\n",
+ "\n",
+ "print(f\"Total number of triples: {df_all_triples.shape[0]:,}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of entities: 93,773\n",
+ "\n",
+ "Number of relation types: 51\n",
+ "\n",
+ "Number of triples: \n",
+ " training: 4,070,747 \n",
+ " validation: 508,843\n",
+ " test: 508,844\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 80/10/10 train/valid/test split\n",
+ "split_ratios = (0.8, 0.1, 0.1)\n",
+ "\n",
+ "biokg_random = KGDataset.from_dataframe(\n",
+ " df_all_triples,\n",
+ " head_column=\"head_label\",\n",
+ " relation_column=\"relation_label\",\n",
+ " tail_column=\"tail_label\",\n",
+ " entity_types=entity_types_series,\n",
+ " split=split_ratios,\n",
+ ")\n",
+ "\n",
+ "print(f\"Number of entities: {biokg_random.n_entity:,}\\n\")\n",
+ "print(f\"Number of relation types: {biokg_random.n_relation_type}\\n\")\n",
+ "print(\n",
+ " f\"Number of triples: \\n training: {biokg_random.triples['train'].shape[0]:,} \\n validation: {biokg_random.triples['valid'].shape[0]:,}\"\n",
+ " f\"\\n test: {biokg_random.triples['test'].shape[0]:,}\"\n",
+ ")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Conclusions and next steps\n",
+ "\n",
+ "The `KGDataset` class has a few additional attributes, which can be defined when instantiating the class manually. For instance, for each triple it allows you to specify a set of negative heads and tails that should be used to corrupt that triple. This is useful when you want to find the best completion for a (h,r,?)/(?,r,t) query only among a specific set of candidate nodes, or you have already identified good negative samples that you want to use during training.\n",
+ "\n",
+ "For more information on the `KGDataset` class, have a look at the [BESS-KGE documentation](https://graphcore-research.github.io/bess-kge/API_ref/dataset.html#besskge.dataset.KGDataset).\n",
+ "\n",
+ "Once you have built your dataset as a `KGDataset` class, you are ready to use BESS to train your preferred KGE model and perform inference with it! To learn how, we suggest starting from the introductory [KGE Training and Inference on OGBL-BioKG](1_biokg_training_inference.ipynb) notebook."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv_3.2",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/1_biokg_training_inference.ipynb b/notebooks/1_biokg_training_inference.ipynb
index 1d1b59f..3e7a78f 100644
--- a/notebooks/1_biokg_training_inference.ipynb
+++ b/notebooks/1_biokg_training_inference.ipynb
@@ -112,7 +112,7 @@
"source": [
"## Sharding entities and triples\n",
"\n",
- "The OGBL-BioKG dataset can be downloaded and preprocessed with the built-in `KGDataset` method, `build_biokg`. `KGDataset` is the standard class which holds data from the knowledge graph dataset, such as head-relation-tail triples (with entities and relation types suitably converted to their integer IDs), triple-specific data (for example negative heads/tails to be used to corrupt the triple), ID -> label lists for entities and relation types, and so on."
+ "The OGBL-BioKG dataset can be downloaded and preprocessed with the built-in `KGDataset` method, `build_ogbl_biokg`. `KGDataset` is the standard class which holds data from the knowledge graph dataset, such as head-relation-tail triples (with entities and relation types suitably converted to their integer IDs), triple-specific data (for example negative heads/tails to be used to corrupt the triple), ID -> label lists for entities and relation types, and so on."
]
},
{
@@ -137,12 +137,16 @@
}
],
"source": [
- "biokg = KGDataset.build_biokg(root=pathlib.Path(dataset_directory))\n",
+ "biokg = KGDataset.build_ogbl_biokg(root=pathlib.Path(dataset_directory))\n",
"\n",
"print(f\"Number of entities: {biokg.n_entity:,}\\n\")\n",
"print(f\"Number of relation types: {biokg.n_relation_type}\\n\")\n",
- "print(f\"Number of triples: \\n training: {biokg.triples['train'].shape[0]:,} \\n validation/test: {biokg.triples['valid'].shape[0]:,}\\n\")\n",
- "print(f\"Number of negative heads/tails for validation/test triples: {biokg.neg_heads['valid'].shape[-1]}\")"
+ "print(\n",
+ " f\"Number of triples: \\n training: {biokg.triples['train'].shape[0]:,} \\n validation/test: {biokg.triples['valid'].shape[0]:,}\\n\"\n",
+ ")\n",
+ "print(\n",
+ " f\"Number of negative heads/tails for validation/test triples: {biokg.neg_heads['valid'].shape[-1]}\"\n",
+ ")"
]
},
{
@@ -227,7 +231,12 @@
"seed = 1234\n",
"n_shard = 4\n",
"\n",
- "sharding = Sharding.create(n_entity=biokg.n_entity, n_shard=n_shard, seed=seed, type_offsets=np.fromiter(biokg.type_offsets.values(), dtype=np.int32))\n",
+ "sharding = Sharding.create(\n",
+ " n_entity=biokg.n_entity,\n",
+ " n_shard=n_shard,\n",
+ " seed=seed,\n",
+ " type_offsets=np.fromiter(biokg.type_offsets.values(), dtype=np.int32),\n",
+ ")\n",
"\n",
"print(f\"Number of shards: {sharding.n_shard}\\n\")\n",
"\n",
@@ -235,8 +244,11 @@
"\n",
"print(f\"Global entity IDs on {n_shard} shards:\\n {sharding.shard_and_idx_to_entity}\\n\")\n",
"\n",
- "# If the number of entities is not divisible by n_shard, some shards will have one trailing padding entity (ID >= n_entity)\n",
- "print(f\"Number of actual (=non-padding) entities per shard:\\n {sharding.shard_counts}\\n\")\n",
+ "# If the number of entities is not divisible by n_shard,\n",
+ "# some shards will have one trailing padding entity (ID >= n_entity)\n",
+ "print(\n",
+ " f\"Number of actual (=non-padding) entities per shard:\\n {sharding.shard_counts}\\n\"\n",
+ ")\n",
"\n",
"# Entities of the same type maintain contiguous local IDs in each shard\n",
"print(f\"Type offsets per shard: \\n\", sharding.entity_type_offsets)"
@@ -280,7 +292,9 @@
}
],
"source": [
- "train_triples = PartitionedTripleSet.create_from_dataset(dataset=biokg, part=\"train\", sharding=sharding, partition_mode=\"ht_shardpair\")\n",
+ "train_triples = PartitionedTripleSet.create_from_dataset(\n",
+ " dataset=biokg, part=\"train\", sharding=sharding, partition_mode=\"ht_shardpair\"\n",
+ ")\n",
"\n",
"# train_triples.triple_counts[i,j] is the number of triples with head entity in shard i and tail entity in shard j\n",
"print(f\"Number of triples per (h,t) shard-pair:\\n {train_triples.triple_counts}\")"
@@ -316,8 +330,8 @@
"# Put original triples in the same order as train_triples.triples\n",
"triple_sorted = biokg.triples[\"train\"][train_triples.triple_sort_idx]\n",
"# Pass from global IDs to local IDs with sharding.entity_to_idx\n",
- "triple_sorted[:,0] = sharding.entity_to_idx[triple_sorted[:,0]]\n",
- "triple_sorted[:,2] = sharding.entity_to_idx[triple_sorted[:,2]]\n",
+ "triple_sorted[:, 0] = sharding.entity_to_idx[triple_sorted[:, 0]]\n",
+ "triple_sorted[:, 2] = sharding.entity_to_idx[triple_sorted[:, 2]]\n",
"# Compare with the content of train_triples.triples\n",
"np.all(triple_sorted == train_triples.triples)"
]
@@ -356,8 +370,14 @@
"metadata": {},
"outputs": [],
"source": [
- "neg_sampler = RandomShardedNegativeSampler(n_negative=1, sharding=sharding, seed=seed, corruption_scheme=\"ht\",\n",
- " local_sampling=False, flat_negative_format=False)"
+ "neg_sampler = RandomShardedNegativeSampler(\n",
+ " n_negative=1,\n",
+ " sharding=sharding,\n",
+ " seed=seed,\n",
+ " corruption_scheme=\"ht\",\n",
+ " local_sampling=False,\n",
+ " flat_negative_format=False,\n",
+ ")"
]
},
{
@@ -394,15 +414,20 @@
"# Micro-batch size, which means the number of positive triples processed on each device at each step\n",
"shard_bs = 240\n",
"\n",
- "batch_sampler = RigidShardedBatchSampler(partitioned_triple_set=train_triples, negative_sampler=neg_sampler,\n",
- " shard_bs=shard_bs, batches_per_step=device_iterations*accum_factor, seed=seed)\n",
+ "batch_sampler = RigidShardedBatchSampler(\n",
+ " partitioned_triple_set=train_triples,\n",
+ " negative_sampler=neg_sampler,\n",
+ " shard_bs=shard_bs,\n",
+ " batches_per_step=device_iterations * accum_factor,\n",
+ " seed=seed,\n",
+ ")\n",
"\n",
"\n",
"print(f\"# triples per shard-pair per step: {batch_sampler.positive_per_partition} \\n\")\n",
"\n",
"# Example batch\n",
"idx_sampler = iter(batch_sampler.get_dataloader_sampler(shuffle=True))\n",
- "for k,v in batch_sampler[next(idx_sampler)].items():\n",
+ "for k, v in batch_sampler[next(idx_sampler)].items():\n",
" print(f\"{k:<12} {str(v.shape):<30} {v.dtype};\")"
]
},
@@ -456,7 +481,9 @@
"options._popart.setPatterns(dict(RemoveAllReducePattern=True))\n",
"\n",
"# Construction similar to PyTorch dataloader\n",
- "train_dl = batch_sampler.get_dataloader(options=options, shuffle=True, num_workers=5, persistent_workers=True)"
+ "train_dl = batch_sampler.get_dataloader(\n",
+ " options=options, shuffle=True, num_workers=5, persistent_workers=True\n",
+ ")"
]
},
{
@@ -486,17 +513,23 @@
"# Loss function\n",
"logsigmoid_loss_fn = LogSigmoidLoss(margin=12.0, negative_adversarial_sampling=True)\n",
"# KGE model\n",
- "rotate_score_fn = RotatE(negative_sample_sharing=True, scoring_norm=1, sharding=sharding,\n",
- " n_relation_type=biokg.n_relation_type, embedding_size=64)\n",
+ "rotate_score_fn = RotatE(\n",
+ " negative_sample_sharing=True,\n",
+ " scoring_norm=1,\n",
+ " sharding=sharding,\n",
+ " n_relation_type=biokg.n_relation_type,\n",
+ " embedding_size=64,\n",
+ ")\n",
"# BESS wrapper\n",
- "model = EmbeddingMovingBessKGE(negative_sampler=neg_sampler, score_fn=rotate_score_fn,\n",
- " loss_fn=logsigmoid_loss_fn)\n",
+ "model = EmbeddingMovingBessKGE(\n",
+ " negative_sampler=neg_sampler, score_fn=rotate_score_fn, loss_fn=logsigmoid_loss_fn\n",
+ ")\n",
"\n",
"# Optimizer\n",
"opt = poptorch.optim.AdamW(\n",
- " model.parameters(),\n",
- " lr=0.001,\n",
- " )\n",
+ " model.parameters(),\n",
+ " lr=0.001,\n",
+ ")\n",
"\n",
"# PopTorch wrapper\n",
"poptorch_model = poptorch.trainingModel(model, options=options, optimizer=opt)\n",
@@ -504,10 +537,10 @@
"# The variable entity_embedding needs to hold different values on each replica,\n",
"# corresponding to the distinct shards of the entity embedding table\n",
"poptorch_model.entity_embedding.replicaGrouping(\n",
- " poptorch.CommGroupType.NoGrouping,\n",
- " 0,\n",
- " poptorch.VariableRetrievalMode.OnePerGroup,\n",
- " )\n",
+ " poptorch.CommGroupType.NoGrouping,\n",
+ " 0,\n",
+ " poptorch.VariableRetrievalMode.OnePerGroup,\n",
+ ")\n",
"\n",
"# Compile model\n",
"batch = next(iter(train_dl))\n",
@@ -605,11 +638,20 @@
" cumulative_triples += triple_mask.numel()\n",
" res = poptorch_model(**{k: v.flatten(end_dim=1) for k, v in batch.items()})\n",
" # res[\"loss\"] contains the summed loss of elements in the last batch, for each IPU\n",
- " ep_log.append(dict(loss=float(torch.sum(res[\"loss\"])) / triple_mask[-1].numel(), step_time=(time.time()-step_start_time)))\n",
- " ep_loss = [v['loss'] for v in ep_log]\n",
- " training_loss.extend([v['loss'] for v in ep_log])\n",
- " print(f\"Epoch {ep+1} loss: {np.mean(ep_loss):.6f} --- positive triples processed: {cumulative_triples:.2e}\")\n",
- " print(f\"Epoch duration (sec): {(time.time() - ep_start_time):.5f} (average step time: {np.mean([v['step_time'] for v in ep_log]):.5f})\")\n",
+ " ep_log.append(\n",
+ " dict(\n",
+ " loss=float(torch.sum(res[\"loss\"])) / triple_mask[-1].numel(),\n",
+ " step_time=(time.time() - step_start_time),\n",
+ " )\n",
+ " )\n",
+ " ep_loss = [v[\"loss\"] for v in ep_log]\n",
+ " training_loss.extend([v[\"loss\"] for v in ep_log])\n",
+ " print(\n",
+ " f\"Epoch {ep+1} loss: {np.mean(ep_loss):.6f} --- positive triples processed: {cumulative_triples:.2e}\"\n",
+ " )\n",
+ " print(\n",
+ " f\"Epoch duration (sec): {(time.time() - ep_start_time):.5f} (average step time: {np.mean([v['step_time'] for v in ep_log]):.5f})\"\n",
+ " )\n",
"\n",
"# Plot loss as a function of the number of positive triples processed\n",
"total_triples = np.cumsum(n_epochs * len(train_dl) * [triple_mask.numel()])\n",
@@ -656,15 +698,28 @@
}
],
"source": [
- "valid_triples = PartitionedTripleSet.create_from_dataset(dataset=biokg, part=\"valid\", sharding=sharding, partition_mode=\"ht_shardpair\")\n",
- "ns_valid = TripleBasedShardedNegativeSampler(negative_heads=valid_triples.neg_heads, negative_tails=valid_triples.neg_tails,\n",
- " sharding=sharding, corruption_scheme=\"ht\", seed=seed)\n",
- "bs_valid = RigidShardedBatchSampler(partitioned_triple_set=valid_triples, negative_sampler=ns_valid, shard_bs=shard_bs, batches_per_step=10,\n",
- " seed=seed, duplicate_batch=True)\n",
+ "valid_triples = PartitionedTripleSet.create_from_dataset(\n",
+ " dataset=biokg, part=\"valid\", sharding=sharding, partition_mode=\"ht_shardpair\"\n",
+ ")\n",
+ "ns_valid = TripleBasedShardedNegativeSampler(\n",
+ " negative_heads=valid_triples.neg_heads,\n",
+ " negative_tails=valid_triples.neg_tails,\n",
+ " sharding=sharding,\n",
+ " corruption_scheme=\"ht\",\n",
+ " seed=seed,\n",
+ ")\n",
+ "bs_valid = RigidShardedBatchSampler(\n",
+ " partitioned_triple_set=valid_triples,\n",
+ " negative_sampler=ns_valid,\n",
+ " shard_bs=shard_bs,\n",
+ " batches_per_step=10,\n",
+ " seed=seed,\n",
+ " duplicate_batch=True,\n",
+ ")\n",
"\n",
"# Example batch\n",
"idx_sampler = iter(bs_valid.get_dataloader_sampler(shuffle=False))\n",
- "for k,v in bs_valid[next(idx_sampler)].items():\n",
+ "for k, v in bs_valid[next(idx_sampler)].items():\n",
" print(f\"{k:<15} {str(v.shape):<35} {v.dtype};\")"
]
},
@@ -710,15 +765,17 @@
"# With reduction=\"sum\" the returned res[\"metrics\"] has shape (batches_per_step * n_shard, n_metrics)\n",
"evaluator = Evaluation([\"mrr\", \"hits@1\", \"hits@5\", \"hits@10\"], reduction=\"sum\")\n",
"# BESS wrapper\n",
- "model_inf = ScoreMovingBessKGE(negative_sampler=ns_valid, score_fn=rotate_score_fn, evaluation=evaluator)\n",
+ "model_inf = ScoreMovingBessKGE(\n",
+ " negative_sampler=ns_valid, score_fn=rotate_score_fn, evaluation=evaluator\n",
+ ")\n",
"\n",
"# PopTorch wrapper\n",
"poptorch_model_inf = poptorch.inferenceModel(model_inf, options=val_options)\n",
"poptorch_model_inf.entity_embedding.replicaGrouping(\n",
- " poptorch.CommGroupType.NoGrouping,\n",
- " 0,\n",
- " poptorch.VariableRetrievalMode.OnePerGroup,\n",
- " )\n",
+ " poptorch.CommGroupType.NoGrouping,\n",
+ " 0,\n",
+ " poptorch.VariableRetrievalMode.OnePerGroup,\n",
+ ")\n",
"\n",
"# Compile model\n",
"batch = next(iter(valid_dl))\n",
@@ -750,13 +807,18 @@
"n_val_queries = 0\n",
"for batch_val in valid_dl:\n",
" res = poptorch_model_inf(**{k: v.flatten(end_dim=1) for k, v in batch_val.items()})\n",
- " \n",
+ "\n",
" n_val_queries += batch_val[\"triple_mask\"].sum()\n",
" # By transposing res[\"metrics\"] we separate the outputs for the different metrics\n",
- " val_log.append({k: v.sum() for k, v in zip(\n",
- " evaluator.metrics.keys(),\n",
- " res[\"metrics\"].T,\n",
- " )})\n",
+ " val_log.append(\n",
+ " {\n",
+ " k: v.sum()\n",
+ " for k, v in zip(\n",
+ " evaluator.metrics.keys(),\n",
+ " res[\"metrics\"].T,\n",
+ " )\n",
+ " }\n",
+ " )\n",
"\n",
"for metric in val_log[0].keys():\n",
" reduced_metric = sum([l[metric] for l in val_log]) / n_val_queries\n",
@@ -835,16 +897,23 @@
"source": [
"n_shard = 1\n",
"\n",
- "sharding = Sharding.create(n_entity=biokg.n_entity, n_shard=n_shard, seed=seed, type_offsets=np.fromiter(biokg.type_offsets.values(), dtype=np.int32))\n",
+ "sharding = Sharding.create(\n",
+ " n_entity=biokg.n_entity,\n",
+ " n_shard=n_shard,\n",
+ " seed=seed,\n",
+ " type_offsets=np.fromiter(biokg.type_offsets.values(), dtype=np.int32),\n",
+ ")\n",
"\n",
"print(f\"Number of shards: {sharding.n_shard}\\n\")\n",
"\n",
"# All entities in the KG are now on a single shard\n",
"print(f\"Number of entities in each shard: {sharding.max_entity_per_shard}\\n\")\n",
"\n",
- "train_triples = PartitionedTripleSet.create_from_dataset(dataset=biokg, part=\"train\", sharding=sharding, partition_mode=\"ht_shardpair\")\n",
+ "train_triples = PartitionedTripleSet.create_from_dataset(\n",
+ " dataset=biokg, part=\"train\", sharding=sharding, partition_mode=\"ht_shardpair\"\n",
+ ")\n",
"\n",
- "# There is now a single (h,t) shard-pair, containing all training triples \n",
+ "# There is now a single (h,t) shard-pair, containing all training triples\n",
"print(f\"Number of triples per (h,t) shard-pair:\\n {train_triples.triple_counts}\")"
]
},
@@ -862,9 +931,16 @@
"metadata": {},
"outputs": [],
"source": [
- "# Reducing the number of shards by a factor of 4, the number of negatives per triple per shard (i.e. n_negative) needs to increase by the same factor\n",
- "neg_sampler = RandomShardedNegativeSampler(n_negative=4, sharding=sharding, seed=seed, corruption_scheme=\"ht\",\n",
- " local_sampling=False, flat_negative_format=False)"
+ "# When reducing the number of shards by a factor of 4,\n",
+ "# the number of negatives per triple per shard (i.e. n_negative) needs to increase by the same factor\n",
+ "neg_sampler = RandomShardedNegativeSampler(\n",
+ " n_negative=4,\n",
+ " sharding=sharding,\n",
+ " seed=seed,\n",
+ " corruption_scheme=\"ht\",\n",
+ " local_sampling=False,\n",
+ " flat_negative_format=False,\n",
+ ")"
]
},
{
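The factor-of-4 statement in the comment above follows from keeping the total number of negatives seen by each positive triple constant: with corruption drawn on every shard, that total is (under this sketch's assumption) roughly `n_negative * n_shard`, and negative sample sharing scales both configurations by the same factor. A minimal check of the arithmetic:

```python
# Hedged check of the trade-off described in the comment above.
# Assumption: effective negatives per positive triple ~= n_negative * n_shard;
# negative sample sharing multiplies both configurations equally.
def effective_negatives(n_negative: int, n_shard: int) -> int:
    return n_negative * n_shard


# 4 shards x 1 negative == 1 shard x 4 negatives == 4 negatives per positive
assert effective_negatives(n_negative=1, n_shard=4) == effective_negatives(n_negative=4, n_shard=1)
```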
@@ -899,15 +975,20 @@
"accum_factor = 24\n",
"shard_bs = 240\n",
"\n",
- "batch_sampler = RigidShardedBatchSampler(partitioned_triple_set=train_triples, negative_sampler=neg_sampler,\n",
- " shard_bs=shard_bs, batches_per_step=device_iterations*accum_factor, seed=seed)\n",
+ "batch_sampler = RigidShardedBatchSampler(\n",
+ " partitioned_triple_set=train_triples,\n",
+ " negative_sampler=neg_sampler,\n",
+ " shard_bs=shard_bs,\n",
+ " batches_per_step=device_iterations * accum_factor,\n",
+ " seed=seed,\n",
+ ")\n",
"\n",
"\n",
"print(f\"# triples per shard-pair per step: {batch_sampler.positive_per_partition} \\n\")\n",
"\n",
"# Example batch\n",
"idx_sampler = iter(batch_sampler.get_dataloader_sampler(shuffle=True))\n",
- "for k,v in batch_sampler[next(idx_sampler)].items():\n",
+ "for k, v in batch_sampler[next(idx_sampler)].items():\n",
" print(f\"{k:<12} {str(v.shape):<30} {v.dtype};\")"
]
},
@@ -937,26 +1018,34 @@
"options.deviceIterations(device_iterations)\n",
"options.Training.gradientAccumulation(accum_factor)\n",
"\n",
- "train_dl = batch_sampler.get_dataloader(options=options, shuffle=True, num_workers=5, persistent_workers=True)\n",
+ "train_dl = batch_sampler.get_dataloader(\n",
+ " options=options, shuffle=True, num_workers=5, persistent_workers=True\n",
+ ")\n",
"\n",
"logsigmoid_loss_fn = LogSigmoidLoss(margin=12.0, negative_adversarial_sampling=True)\n",
- "rotate_score_fn = RotatE(negative_sample_sharing=True, scoring_norm=1, sharding=sharding,\n",
- " n_relation_type=biokg.n_relation_type, embedding_size=64)\n",
+ "rotate_score_fn = RotatE(\n",
+ " negative_sample_sharing=True,\n",
+ " scoring_norm=1,\n",
+ " sharding=sharding,\n",
+ " n_relation_type=biokg.n_relation_type,\n",
+ " embedding_size=64,\n",
+ ")\n",
"\n",
- "model = EmbeddingMovingBessKGE(negative_sampler=neg_sampler, score_fn=rotate_score_fn,\n",
- " loss_fn=logsigmoid_loss_fn)\n",
+ "model = EmbeddingMovingBessKGE(\n",
+ " negative_sampler=neg_sampler, score_fn=rotate_score_fn, loss_fn=logsigmoid_loss_fn\n",
+ ")\n",
"\n",
"opt = poptorch.optim.AdamW(\n",
- " model.parameters(),\n",
- " lr=0.001,\n",
- " )\n",
+ " model.parameters(),\n",
+ " lr=0.001,\n",
+ ")\n",
"\n",
"poptorch_model = poptorch.trainingModel(model, options=options, optimizer=opt)\n",
"poptorch_model.entity_embedding.replicaGrouping(\n",
- " poptorch.CommGroupType.NoGrouping,\n",
- " 0,\n",
- " poptorch.VariableRetrievalMode.OnePerGroup,\n",
- " )\n",
+ " poptorch.CommGroupType.NoGrouping,\n",
+ " 0,\n",
+ " poptorch.VariableRetrievalMode.OnePerGroup,\n",
+ ")\n",
"\n",
"\n",
"# Compile model\n",
@@ -1011,11 +1100,20 @@
" triple_mask = batch.pop(\"triple_mask\")\n",
" cumulative_triples += triple_mask.numel()\n",
" res = poptorch_model(**{k: v.flatten(end_dim=1) for k, v in batch.items()})\n",
- " ep_log.append(dict(loss=float(torch.sum(res[\"loss\"])) / triple_mask[-1].numel(), step_time=(time.time()-step_start_time)))\n",
- " ep_loss = [v['loss'] for v in ep_log]\n",
- " training_loss.extend([v['loss'] for v in ep_log])\n",
- " print(f\"Epoch {ep+1} loss: {np.mean(ep_loss):.6f} --- positive triples processed: {cumulative_triples:.2e}\")\n",
- " print(f\"Epoch duration (sec): {(time.time() - ep_start_time):.5f} (average step time: {np.mean([v['step_time'] for v in ep_log]):.5f})\")\n",
+ " ep_log.append(\n",
+ " dict(\n",
+ " loss=float(torch.sum(res[\"loss\"])) / triple_mask[-1].numel(),\n",
+ " step_time=(time.time() - step_start_time),\n",
+ " )\n",
+ " )\n",
+ " ep_loss = [v[\"loss\"] for v in ep_log]\n",
+ " training_loss.extend([v[\"loss\"] for v in ep_log])\n",
+ " print(\n",
+ " f\"Epoch {ep+1} loss: {np.mean(ep_loss):.6f} --- positive triples processed: {cumulative_triples:.2e}\"\n",
+ " )\n",
+ " print(\n",
+ " f\"Epoch duration (sec): {(time.time() - ep_start_time):.5f} (average step time: {np.mean([v['step_time'] for v in ep_log]):.5f})\"\n",
+ " )\n",
"\n",
"poptorch_model.detachFromDevice()\n",
"del train_dl"
@@ -1037,7 +1135,7 @@
"## Conclusions and next steps\n",
"\n",
"To recap, these are the basic steps to run distributed training/inference with BESS-KGE:\n",
- "* wrap your knowledge graph dataset with `besskge.dataset.KGDataset`;\n",
+ "* wrap your knowledge graph dataset with `besskge.dataset.KGDataset` (see also [Using BESS-KGE with your own data](0_custom_KG_dataset.ipynb));\n",
"* shard entities in the graph based on the number of IPUs you want to use, by using `besskge.sharding.Sharding`, and partition triples accordingly with `besskge.sharding.PartitionedTripleSet`;\n",
"* select a negative sampler from `besskge.negative_sampler` to sample the entities used to corrupt positive triples;\n",
"* use a batch sampler from `besskge.batch_sampler` to create the batch dataloader; \n",
diff --git a/notebooks/2_yago_topk_prediction.ipynb b/notebooks/2_yago_topk_prediction.ipynb
index e1b4d62..ff9e0ef 100644
--- a/notebooks/2_yago_topk_prediction.ipynb
+++ b/notebooks/2_yago_topk_prediction.ipynb
@@ -153,12 +153,16 @@
"\n",
"print(f\"Number of entities: {yago.n_entity:,}\\n\")\n",
"print(f\"Number of relation types: {yago.n_relation_type}\\n\")\n",
- "print(f\"Number of triples: \\n training: {yago.triples['train'].shape[0]:,} \\n validation/test: {yago.triples['validation'].shape[0]:,}\\n\")\n",
+ "print(\n",
+ " f\"Number of triples: \\n training: {yago.triples['train'].shape[0]:,} \\n validation/test: {yago.triples['valid'].shape[0]:,}\\n\"\n",
+ ")\n",
"\n",
"# Print example triple retrieving labels from yago.entity_dict and yago.relation_dict\n",
"ex_triple_id = 2500\n",
"ex_triple = yago.triples[\"train\"][ex_triple_id]\n",
- "print(f'Example triple: {yago.entity_dict[ex_triple[0]], yago.relation_dict[ex_triple[1]], yago.entity_dict[ex_triple[2]]}')"
+ "print(\n",
+ " f\"Example triple: {yago.entity_dict[ex_triple[0]], yago.relation_dict[ex_triple[1]], yago.entity_dict[ex_triple[2]]}\"\n",
+ ")"
]
},
{
@@ -205,7 +209,9 @@
"\n",
"# The global entity IDs can be recovered, as a function of the shard ID and the local ID on the shard, by\n",
"print(\"\\nReconstructed global entity IDs:\")\n",
- "print(sharding.shard_and_idx_to_entity[sharding.entity_to_shard, sharding.entity_to_idx])\n",
+ "print(\n",
+ " sharding.shard_and_idx_to_entity[sharding.entity_to_shard, sharding.entity_to_idx]\n",
+ ")\n",
"\n",
"train_triples = PartitionedTripleSet.create_from_dataset(yago, \"train\", sharding)\n",
"\n",
@@ -237,10 +243,21 @@
"device_iterations = 20\n",
"accum_factor = 2\n",
"shard_bs = 720\n",
- "neg_sampler = RandomShardedNegativeSampler(n_negative=1, sharding=sharding, seed=seed, corruption_scheme=\"ht\",\n",
- " local_sampling=False, flat_negative_format=False)\n",
- "bs = RigidShardedBatchSampler(partitioned_triple_set=train_triples, negative_sampler=neg_sampler, shard_bs=shard_bs,\n",
- " batches_per_step=device_iterations*accum_factor, seed=seed)"
+ "neg_sampler = RandomShardedNegativeSampler(\n",
+ " n_negative=1,\n",
+ " sharding=sharding,\n",
+ " seed=seed,\n",
+ " corruption_scheme=\"ht\",\n",
+ " local_sampling=False,\n",
+ " flat_negative_format=False,\n",
+ ")\n",
+ "bs = RigidShardedBatchSampler(\n",
+ " partitioned_triple_set=train_triples,\n",
+ " negative_sampler=neg_sampler,\n",
+ " shard_bs=shard_bs,\n",
+ " batches_per_step=device_iterations * accum_factor,\n",
+ " seed=seed,\n",
+ ")"
]
},
{
@@ -268,11 +285,13 @@
"options._popart.setPatterns(dict(RemoveAllReducePattern=True))\n",
"\n",
"# Construct the dataloader with the dedicated utility function\n",
- "train_dl = bs.get_dataloader(options=options, shuffle=True, num_workers=5, persistent_workers=True)\n",
+ "train_dl = bs.get_dataloader(\n",
+ " options=options, shuffle=True, num_workers=5, persistent_workers=True\n",
+ ")\n",
"\n",
"# Example batch\n",
"batch = next(iter(train_dl))\n",
- "for k,v in batch.items():\n",
+ "for k, v in batch.items():\n",
" print(f\"{k:<12} {str(v.shape):<30}\")"
]
},
@@ -301,9 +320,15 @@
],
"source": [
"loss_fn = LogSigmoidLoss(margin=12.0, negative_adversarial_sampling=True)\n",
- "complex_score_fn = ComplEx(negative_sample_sharing=True, sharding=sharding, n_relation_type=yago.n_relation_type, embedding_size=128)\n",
- "model = EmbeddingMovingBessKGE(negative_sampler=neg_sampler, score_fn=complex_score_fn,\n",
- " loss_fn=loss_fn)\n",
+ "complex_score_fn = ComplEx(\n",
+ " negative_sample_sharing=True,\n",
+ " sharding=sharding,\n",
+ " n_relation_type=yago.n_relation_type,\n",
+ " embedding_size=128,\n",
+ ")\n",
+ "model = EmbeddingMovingBessKGE(\n",
+ " negative_sampler=neg_sampler, score_fn=complex_score_fn, loss_fn=loss_fn\n",
+ ")\n",
"\n",
"print(f\"# model parameters: {model.n_embedding_parameters:,}\")"
]
@@ -326,19 +351,35 @@
"source": [
"def evaluate_mrr_cpu(triples, evaluation):\n",
" # Unshard entity embedding table\n",
- " ent_table = complex_score_fn.entity_embedding.detach()[sharding.entity_to_shard, sharding.entity_to_idx]\n",
+ " ent_table = complex_score_fn.entity_embedding.detach()[\n",
+ " sharding.entity_to_shard, sharding.entity_to_idx\n",
+ " ]\n",
"\n",
" # Score query (h,r,?) against all entities in the knowledge graph and select top-10 scores\n",
- " scores = complex_score_fn.score_tails(ent_table[triples[:,0]], torch.from_numpy(triples[:,1]), ent_table.unsqueeze(0))\n",
+ " scores = complex_score_fn.score_tails(\n",
+ " ent_table[triples[:, 0]],\n",
+ " torch.from_numpy(triples[:, 1]),\n",
+ " ent_table.unsqueeze(0),\n",
+ " )\n",
" top_k = torch.topk(scores, dim=-1, k=10)\n",
"\n",
" # Use evaluation.ranks_from_indices to rank the ground truth, if present, among the predictions\n",
- " ranks = evaluation.ranks_from_indices(torch.from_numpy(triples[:,2]), top_k.indices.squeeze())\n",
- " return {k: v / triples.shape[0] for k, v in evaluation.dict_metrics_from_ranks(ranks).items()}\n",
+ " ranks = evaluation.ranks_from_indices(\n",
+ " torch.from_numpy(triples[:, 2]), top_k.indices.squeeze()\n",
+ " )\n",
+ " return {\n",
+ " k: v / triples.shape[0]\n",
+ " for k, v in evaluation.dict_metrics_from_ranks(ranks).items()\n",
+ " }\n",
+ "\n",
"\n",
"# Sample validation queries\n",
"n_val_triples = 500\n",
- "val_triple_subset = yago.triples[\"validation\"][np.random.default_rng(seed=1000).choice(yago.triples[\"validation\"].shape[0], n_val_triples)]\n",
+ "val_triple_subset = yago.triples[\"valid\"][\n",
+ " np.random.default_rng(seed=1000).choice(\n",
+ " yago.triples[\"valid\"].shape[0], n_val_triples\n",
+ " )\n",
+ "]\n",
"evaluation = Evaluation([\"mrr\"], worst_rank_infty=True, reduction=\"sum\")"
]
},
@@ -357,19 +398,19 @@
],
"source": [
"opt = poptorch.optim.AdamW(\n",
- " model.parameters(),\n",
- " lr=0.0016,\n",
- " )\n",
+ " model.parameters(),\n",
+ " lr=0.0016,\n",
+ ")\n",
"\n",
"poptorch_model = poptorch.trainingModel(model, options=options, optimizer=opt)\n",
"\n",
"# The variable entity_embedding needs to hold different values on each replica,\n",
"# corresponding to the shards of the entity embedding table\n",
"poptorch_model.entity_embedding.replicaGrouping(\n",
- " poptorch.CommGroupType.NoGrouping,\n",
- " 0,\n",
- " poptorch.VariableRetrievalMode.OnePerGroup,\n",
- " )\n",
+ " poptorch.CommGroupType.NoGrouping,\n",
+ " 0,\n",
+ " poptorch.VariableRetrievalMode.OnePerGroup,\n",
+ ")\n",
"\n",
"# Graph compilation\n",
"_ = batch.pop(\"triple_mask\")\n",
@@ -471,11 +512,20 @@
" triple_mask = batch.pop(\"triple_mask\")\n",
" cumulative_triples += triple_mask.numel()\n",
" res = poptorch_model(**{k: v.flatten(end_dim=1) for k, v in batch.items()})\n",
- " ep_log.append(dict(loss=float(torch.sum(res[\"loss\"])) / triple_mask[-1].numel(), step_time=(time.time()-step_start_time)))\n",
- " ep_loss = [v['loss'] for v in ep_log]\n",
- " training_loss.extend([v['loss'] for v in ep_log])\n",
- " print(f\"Epoch {ep+1} loss: {np.mean(ep_loss):.6f} --- positive triples processed: {cumulative_triples:.2e}\")\n",
- " print(f\"Epoch duration (sec): {(time.time() - ep_start_time):.5f} (average step time: {np.mean([v['step_time'] for v in ep_log]):.5f})\")\n",
+ " ep_log.append(\n",
+ " dict(\n",
+ " loss=float(torch.sum(res[\"loss\"])) / triple_mask[-1].numel(),\n",
+ " step_time=(time.time() - step_start_time),\n",
+ " )\n",
+ " )\n",
+ " ep_loss = [v[\"loss\"] for v in ep_log]\n",
+ " training_loss.extend([v[\"loss\"] for v in ep_log])\n",
+ " print(\n",
+ " f\"Epoch {ep+1} loss: {np.mean(ep_loss):.6f} --- positive triples processed: {cumulative_triples:.2e}\"\n",
+ " )\n",
+ " print(\n",
+ " f\"Epoch duration (sec): {(time.time() - ep_start_time):.5f} (average step time: {np.mean([v['step_time'] for v in ep_log]):.5f})\"\n",
+ " )\n",
" if ep % val_ep_interval == 0:\n",
" ep_mrr = evaluate_mrr_cpu(val_triple_subset, evaluation)[\"mrr\"]\n",
" val_mrr.append(ep_mrr)\n",
@@ -487,9 +537,14 @@
"# Plot loss and sample MRR as a function of the number of positive triples processed\n",
"total_triples = np.cumsum(n_epochs * len(train_dl) * [triple_mask.numel()])\n",
"ax0, ax1 = plt.gca(), plt.twinx()\n",
- "line0, = ax0.plot(total_triples, training_loss)\n",
- "line1, = ax1.plot(np.concatenate([total_triples[::val_ep_interval * len(train_dl)], total_triples[-1:]]),\n",
- " val_mrr, color=\"r\")\n",
+ "(line0,) = ax0.plot(total_triples, training_loss)\n",
+ "(line1,) = ax1.plot(\n",
+ " np.concatenate(\n",
+ " [total_triples[:: val_ep_interval * len(train_dl)], total_triples[-1:]]\n",
+ " ),\n",
+ " val_mrr,\n",
+ " color=\"r\",\n",
+ ")\n",
"ax0.set_xlabel(\"Positive triples\")\n",
"ax0.set_ylabel(\"Loss\")\n",
"ax1.set_ylabel(\"Sample MRR\")\n",
@@ -539,10 +594,19 @@
"device_iterations = 1\n",
"shard_bs = 1440\n",
"\n",
- "validation_triples = PartitionedTripleSet.create_from_dataset(yago, \"validation\", sharding, partition_mode=\"h_shard\")\n",
+ "validation_triples = PartitionedTripleSet.create_from_dataset(\n",
+ " yago, \"valid\", sharding, partition_mode=\"h_shard\"\n",
+ ")\n",
"candidate_sampler = PlaceholderNegativeSampler(corruption_scheme=\"t\", seed=seed)\n",
- "bs_valid = RigidShardedBatchSampler(partitioned_triple_set=validation_triples, negative_sampler=candidate_sampler, shard_bs=shard_bs, batches_per_step=device_iterations,\n",
- " seed=seed, duplicate_batch=False, return_triple_idx=True)\n",
+ "bs_valid = RigidShardedBatchSampler(\n",
+ " partitioned_triple_set=validation_triples,\n",
+ " negative_sampler=candidate_sampler,\n",
+ " shard_bs=shard_bs,\n",
+ " batches_per_step=device_iterations,\n",
+ " seed=seed,\n",
+ " duplicate_batch=False,\n",
+ " return_triple_idx=True,\n",
+ ")\n",
"\n",
"print(\"Number of triples per h_shard:\")\n",
"print(validation_triples.triple_counts)"
@@ -571,11 +635,13 @@
"val_options.deviceIterations(bs_valid.batches_per_step)\n",
"val_options.outputMode(poptorch.OutputMode.All)\n",
"\n",
- "valid_dl = bs_valid.get_dataloader(options=val_options, shuffle=False, num_workers=5, persistent_workers=True)\n",
+ "valid_dl = bs_valid.get_dataloader(\n",
+ " options=val_options, shuffle=False, num_workers=5, persistent_workers=True\n",
+ ")\n",
"\n",
"# Example batch\n",
"batch = next(iter(valid_dl))\n",
- "for k,v in batch.items():\n",
+ "for k, v in batch.items():\n",
" print(f\"{k:<12} {str(v.shape):<30}\")"
]
},
@@ -607,9 +673,9 @@
],
"source": [
"# Put original validation triple in the same order as validation_triples.triples\n",
- "triple_sorted = yago.triples[\"validation\"][validation_triples.triple_sort_idx]\n",
+ "triple_sorted = yago.triples[\"valid\"][validation_triples.triple_sort_idx]\n",
"# Pass from global IDs to local IDs just for the heads\n",
- "triple_sorted[:,0] = sharding.entity_to_idx[triple_sorted[:,0]]\n",
+ "triple_sorted[:, 0] = sharding.entity_to_idx[triple_sorted[:, 0]]\n",
"# Compare with validation_triples.triples\n",
"np.all(triple_sorted == validation_triples.triples)"
]
@@ -639,16 +705,24 @@
}
],
"source": [
- "evaluation = Evaluation([\"mrr\", \"hits@3\", \"hits@10\"], worst_rank_infty=True, reduction=\"sum\")\n",
- "inf_model = TopKQueryBessKGE(k=10, candidate_sampler=candidate_sampler, score_fn=complex_score_fn, evaluation=evaluation, window_size=500)\n",
+ "evaluation = Evaluation(\n",
+ " [\"mrr\", \"hits@3\", \"hits@10\"], worst_rank_infty=True, reduction=\"sum\"\n",
+ ")\n",
+ "inf_model = TopKQueryBessKGE(\n",
+ " k=10,\n",
+ " candidate_sampler=candidate_sampler,\n",
+ " score_fn=complex_score_fn,\n",
+ " evaluation=evaluation,\n",
+ " window_size=500,\n",
+ ")\n",
"\n",
"poptorch_inf_model = poptorch.inferenceModel(inf_model, options=val_options)\n",
"\n",
"poptorch_inf_model.entity_embedding.replicaGrouping(\n",
- " poptorch.CommGroupType.NoGrouping,\n",
- " 0,\n",
- " poptorch.VariableRetrievalMode.OnePerGroup,\n",
- " )\n",
+ " poptorch.CommGroupType.NoGrouping,\n",
+ " 0,\n",
+ " poptorch.VariableRetrievalMode.OnePerGroup,\n",
+ ")\n",
"\n",
"_ = batch.pop(\"triple_idx\")\n",
"res = poptorch_inf_model(**{k: v.flatten(end_dim=1) for k, v in batch.items()})"
@@ -689,10 +763,15 @@
" # triple_mask is now passed to the model to filter out the metrics of padding triples\n",
" res = poptorch_inf_model(**{k: v.flatten(end_dim=1) for k, v in batch_val.items()})\n",
" n_val_queries += batch_val[\"triple_mask\"].sum()\n",
- " val_log.append({k: v.sum() for k, v in zip(\n",
- " evaluation.metrics.keys(),\n",
- " res[\"metrics\"].T,\n",
- " )})\n",
+ " val_log.append(\n",
+ " {\n",
+ " k: v.sum()\n",
+ " for k, v in zip(\n",
+ " evaluation.metrics.keys(),\n",
+ " res[\"metrics\"].T,\n",
+ " )\n",
+ " }\n",
+ " )\n",
"\n",
"print(f\"Validation time (sec): {(time.time() - start_time):.5f}\\n\")\n",
"\n",
@@ -729,7 +808,7 @@
],
"source": [
"start_time = time.time()\n",
- "cpu_res = evaluate_mrr_cpu(yago.triples[\"validation\"], evaluation)\n",
+ "cpu_res = evaluate_mrr_cpu(yago.triples[\"valid\"], evaluation)\n",
"\n",
"print(f\"CPU Validation time (sec): {(time.time() - start_time):.5f}\\n\")\n",
"print(f\"CPU validation MRR: {cpu_res['mrr']:.6f}\")"
@@ -792,18 +871,21 @@
"source": [
"def check_prediction(val_triple_id):\n",
" # Recover the non-padding triples seen in the last batch using triple_idx and triple_mask\n",
- " triples = yago.triples[\"validation\"][validation_triples.triple_sort_idx][triple_idx[batch_val[\"triple_mask\"]]]\n",
- " h,r,t = triples[val_triple_id]\n",
+ " triples = yago.triples[\"valid\"][validation_triples.triple_sort_idx][\n",
+ " triple_idx[batch_val[\"triple_mask\"]]\n",
+ " ]\n",
+ " h, r, t = triples[val_triple_id]\n",
" # res[\"topk_global_id\"] contains the top-10 tails predicted by the KGE model\n",
" top10_t = res[\"topk_global_id\"][batch_val[\"triple_mask\"].flatten()][val_triple_id]\n",
- " \n",
- " print(f'Example query: ({yago.entity_dict[h]}, {yago.relation_dict[r]}, ?)\\n')\n",
+ "\n",
+ " print(f\"Example query: ({yago.entity_dict[h]}, {yago.relation_dict[r]}, ?)\\n\")\n",
" print(f\"Correct tail: {yago.entity_dict[t]}\\n\")\n",
" print(f\"10 most likely predicted tails:\")\n",
" for i, pt in enumerate(top10_t):\n",
" print(f\"{i+1}) {yago.entity_dict[pt]}\" + (\" <-----\" if pt == t else \"\"))\n",
" print(\"\\n\")\n",
"\n",
+ "\n",
"check_prediction(10)\n",
"check_prediction(1000)"
]
@@ -835,7 +917,9 @@
"source": [
"# complex_score_fn.entity_embedding has shape [n_shard, max_entity_per_shard, embedding_size]\n",
"\n",
- "print(f\"Current embedding table ({sharding.n_shard} shards): {complex_score_fn.entity_embedding.shape}\")\n",
+ "print(\n",
+ " f\"Current embedding table ({sharding.n_shard} shards): {complex_score_fn.entity_embedding.shape}\"\n",
+ ")\n",
"\n",
"# New entity sharding with a single shard - to use on 1 IPU\n",
"new_val_sharding = Sharding.create(yago.n_entity, n_shard=1, seed=seed)\n",
@@ -843,7 +927,9 @@
"# Update sharding of embedding tables, stored in the scoring function\n",
"complex_score_fn.update_sharding(new_sharding=new_val_sharding)\n",
"\n",
- "print(f\"Refactored embedding table (1 shard): {complex_score_fn.entity_embedding.shape}\")"
+ "print(\n",
+ " f\"Refactored embedding table (1 shard): {complex_score_fn.entity_embedding.shape}\"\n",
+ ")"
]
},
{
@@ -879,9 +965,18 @@
"device_iterations = 4\n",
"shard_bs = 1440\n",
"\n",
- "validation_triples = PartitionedTripleSet.create_from_dataset(yago, \"validation\", new_val_sharding, partition_mode=\"h_shard\")\n",
- "bs_valid = RigidShardedBatchSampler(partitioned_triple_set=validation_triples, negative_sampler=candidate_sampler, shard_bs=shard_bs, batches_per_step=device_iterations,\n",
- " seed=seed, duplicate_batch=False, return_triple_idx=True)\n",
+ "validation_triples = PartitionedTripleSet.create_from_dataset(\n",
+ " yago, \"valid\", new_val_sharding, partition_mode=\"h_shard\"\n",
+ ")\n",
+ "bs_valid = RigidShardedBatchSampler(\n",
+ " partitioned_triple_set=validation_triples,\n",
+ " negative_sampler=candidate_sampler,\n",
+ " shard_bs=shard_bs,\n",
+ " batches_per_step=device_iterations,\n",
+ " seed=seed,\n",
+ " duplicate_batch=False,\n",
+ " return_triple_idx=True,\n",
+ ")\n",
"\n",
"print(\"Number of triples per h_shard:\")\n",
"print(validation_triples.triple_counts)\n",
@@ -891,11 +986,13 @@
"val_options.deviceIterations(bs_valid.batches_per_step)\n",
"val_options.outputMode(poptorch.OutputMode.All)\n",
"\n",
- "valid_dl = bs_valid.get_dataloader(options=val_options, shuffle=False, num_workers=5, persistent_workers=True)\n",
+ "valid_dl = bs_valid.get_dataloader(\n",
+ " options=val_options, shuffle=False, num_workers=5, persistent_workers=True\n",
+ ")\n",
"\n",
"print(\"Example batch:\")\n",
"batch = next(iter(valid_dl))\n",
- "for k,v in batch.items():\n",
+ "for k, v in batch.items():\n",
" print(f\"{k:<12} {str(v.shape):<30}\")"
]
},
@@ -927,15 +1024,21 @@
}
],
"source": [
- "inf_model = TopKQueryBessKGE(k=10, candidate_sampler=candidate_sampler, score_fn=complex_score_fn, evaluation=evaluation, window_size=500)\n",
+ "inf_model = TopKQueryBessKGE(\n",
+ " k=10,\n",
+ " candidate_sampler=candidate_sampler,\n",
+ " score_fn=complex_score_fn,\n",
+ " evaluation=evaluation,\n",
+ " window_size=500,\n",
+ ")\n",
"\n",
"poptorch_inf_model = poptorch.inferenceModel(inf_model, options=val_options)\n",
"\n",
"poptorch_inf_model.entity_embedding.replicaGrouping(\n",
- " poptorch.CommGroupType.NoGrouping,\n",
- " 0,\n",
- " poptorch.VariableRetrievalMode.OnePerGroup,\n",
- " )\n",
+ " poptorch.CommGroupType.NoGrouping,\n",
+ " 0,\n",
+ " poptorch.VariableRetrievalMode.OnePerGroup,\n",
+ ")\n",
"\n",
"# Compile model\n",
"_ = batch.pop(\"triple_idx\")\n",
@@ -949,12 +1052,17 @@
" triple_idx = batch_val.pop(\"triple_idx\")\n",
" step_start_time = time.time()\n",
" res = poptorch_inf_model(**{k: v.flatten(end_dim=1) for k, v in batch_val.items()})\n",
- " \n",
+ "\n",
" n_val_queries += batch_val[\"triple_mask\"].sum()\n",
- " val_log.append({k: v.sum() for k, v in zip(\n",
- " evaluation.metrics.keys(),\n",
- " res[\"metrics\"].T,\n",
- " )})\n",
+ " val_log.append(\n",
+ " {\n",
+ " k: v.sum()\n",
+ " for k, v in zip(\n",
+ " evaluation.metrics.keys(),\n",
+ " res[\"metrics\"].T,\n",
+ " )\n",
+ " }\n",
+ " )\n",
"\n",
"print(f\"Validation time (sec): {(time.time() - start_time):.5f}\\n\")\n",
"\n",
diff --git a/notebooks/3_wikikg2_fp16.ipynb b/notebooks/3_wikikg2_fp16.ipynb
index 3b2306e..26de2a0 100644
--- a/notebooks/3_wikikg2_fp16.ipynb
+++ b/notebooks/3_wikikg2_fp16.ipynb
@@ -118,7 +118,7 @@
"source": [
"## Sharding entities and triples\n",
"\n",
- "The OGBL-WikiKG2 dataset can be downloaded and preprocessed with the built-in method of `KGDataset`, `build_wikikg2`. Sharding of entities and triples is performed as shown in the [KGE Training and Inference on OGBL-BioKG](1_biokg_training_inference.ipynb) notebook."
+ "The OGBL-WikiKG2 dataset can be downloaded and preprocessed with the built-in method of `KGDataset`, `build_ogbl_wikikg2`. Sharding of entities and triples is performed as shown in the [KGE Training and Inference on OGBL-BioKG](1_biokg_training_inference.ipynb) notebook."
]
},
{
@@ -143,12 +143,16 @@
}
],
"source": [
- "wikikg = KGDataset.build_wikikg2(root=pathlib.Path(dataset_directory))\n",
+ "wikikg = KGDataset.build_ogbl_wikikg2(root=pathlib.Path(dataset_directory))\n",
"\n",
"print(f\"Number of entities: {wikikg.n_entity:,}\\n\")\n",
"print(f\"Number of relation types: {wikikg.n_relation_type}\\n\")\n",
- "print(f\"Number of triples: \\n training: {wikikg.triples['train'].shape[0]:,} \\n validation/test: {wikikg.triples['valid'].shape[0]:,}\\n\")\n",
- "print(f\"Number of negative heads/tails for validation/test triples: {wikikg.neg_heads['valid'].shape[-1]}\")"
+ "print(\n",
+ " f\"Number of triples: \\n training: {wikikg.triples['train'].shape[0]:,} \\n validation/test: {wikikg.triples['valid'].shape[0]:,}\\n\"\n",
+ ")\n",
+ "print(\n",
+ " f\"Number of negative heads/tails for validation/test triples: {wikikg.neg_heads['valid'].shape[-1]}\"\n",
+ ")"
]
},
{
@@ -189,7 +193,7 @@
"\n",
"print(f\"Number of entities in each shard: {sharding.max_entity_per_shard:,}\\n\")\n",
"\n",
- "print(f\"Global entity IDs on {n_shard} shards:\\n {sharding.shard_and_idx_to_entity}\\n\")"
+ "print(f\"Global entity IDs on {n_shard} shards:\\n {sharding.shard_and_idx_to_entity}\\n\")\n"
]
},
{
@@ -210,7 +214,9 @@
}
],
"source": [
- "train_triples = PartitionedTripleSet.create_from_dataset(dataset=wikikg, part=\"train\", sharding=sharding, partition_mode=\"ht_shardpair\")\n",
+ "train_triples = PartitionedTripleSet.create_from_dataset(\n",
+ " dataset=wikikg, part=\"train\", sharding=sharding, partition_mode=\"ht_shardpair\"\n",
+ ")\n",
"\n",
"print(f\"Number of triples per (h,t) shard-pair:\\n {train_triples.triple_counts}\")"
]
@@ -250,18 +256,29 @@
"accum_factor = 1\n",
"shard_bs = 512\n",
"\n",
- "neg_sampler = RandomShardedNegativeSampler(n_negative=32, sharding=sharding, seed=seed, corruption_scheme=\"t\",\n",
- " local_sampling=False, flat_negative_format=True)\n",
+ "neg_sampler = RandomShardedNegativeSampler(\n",
+ " n_negative=32,\n",
+ " sharding=sharding,\n",
+ " seed=seed,\n",
+ " corruption_scheme=\"t\",\n",
+ " local_sampling=False,\n",
+ " flat_negative_format=True,\n",
+ ")\n",
"\n",
- "batch_sampler = RandomShardedBatchSampler(partitioned_triple_set=train_triples, negative_sampler=neg_sampler,\n",
- " shard_bs=shard_bs, batches_per_step=device_iterations*accum_factor, seed=seed)\n",
+ "batch_sampler = RandomShardedBatchSampler(\n",
+ " partitioned_triple_set=train_triples,\n",
+ " negative_sampler=neg_sampler,\n",
+ " shard_bs=shard_bs,\n",
+ " batches_per_step=device_iterations * accum_factor,\n",
+ " seed=seed,\n",
+ ")\n",
"\n",
"\n",
"print(f\"# triples per shard-pair per step: {batch_sampler.positive_per_partition} \\n\")\n",
"\n",
"# Example batch\n",
"idx_sampler = iter(batch_sampler.get_dataloader_sampler(shuffle=True))\n",
- "for k,v in batch_sampler[next(idx_sampler)].items():\n",
+ "for k, v in batch_sampler[next(idx_sampler)].items():\n",
" print(f\"{k:<12} {str(v.shape):<30} {v.dtype};\")"
]
},
@@ -289,7 +306,9 @@
"# Enable stochastic rounding on IPU for more stable half-precision training\n",
"options.Precision.enableStochasticRounding(True)\n",
"\n",
- "train_dl = batch_sampler.get_dataloader(options=options, shuffle=True, num_workers=3, persistent_workers=True)"
+ "train_dl = batch_sampler.get_dataloader(\n",
+ " options=options, shuffle=True, num_workers=3, persistent_workers=True\n",
+ ")"
]
},
{
@@ -321,12 +340,22 @@
"loss_fn = SampledSoftmaxCrossEntropyLoss(n_entity=wikikg.n_entity)\n",
"# Initializer for entity and relation embeddings\n",
"emb_initializer = [init_KGE_normal]\n",
- "transe_score_fn = TransE(negative_sample_sharing=True, scoring_norm=1, sharding=sharding,\n",
- " n_relation_type=wikikg.n_relation_type, embedding_size=100,\n",
- " entity_initializer=emb_initializer, relation_initializer=emb_initializer)\n",
+ "transe_score_fn = TransE(\n",
+ " negative_sample_sharing=True,\n",
+ " scoring_norm=1,\n",
+ " sharding=sharding,\n",
+ " n_relation_type=wikikg.n_relation_type,\n",
+ " embedding_size=100,\n",
+ " entity_initializer=emb_initializer,\n",
+ " relation_initializer=emb_initializer,\n",
+ ")\n",
"\n",
- "model = EmbeddingMovingBessKGE(negative_sampler=neg_sampler, score_fn=transe_score_fn,\n",
- " loss_fn=loss_fn, augment_negative=True)\n",
+ "model = EmbeddingMovingBessKGE(\n",
+ " negative_sampler=neg_sampler,\n",
+ " score_fn=transe_score_fn,\n",
+ " loss_fn=loss_fn,\n",
+ " augment_negative=True,\n",
+ ")\n",
"\n",
"print(f\"# model parameters: {model.n_embedding_parameters:,}\")"
]
@@ -360,21 +389,21 @@
"model.half()\n",
"\n",
"opt = poptorch.optim.SGD(\n",
- " model.parameters(),\n",
- " lr=0.001,\n",
- " momentum=0.95,\n",
- " velocity_accum_type=torch.float16,\n",
- " )\n",
+ " model.parameters(),\n",
+ " lr=0.001,\n",
+ " momentum=0.95,\n",
+ " velocity_accum_type=torch.float16,\n",
+ ")\n",
"\n",
"poptorch_model = poptorch.trainingModel(model, options=options, optimizer=opt)\n",
"\n",
"# The variable entity_embedding needs to hold different values on each replica,\n",
"# corresponding to the shards of the entity embedding table\n",
"poptorch_model.entity_embedding.replicaGrouping(\n",
- " poptorch.CommGroupType.NoGrouping,\n",
- " 0,\n",
- " poptorch.VariableRetrievalMode.OnePerGroup,\n",
- " )\n",
+ " poptorch.CommGroupType.NoGrouping,\n",
+ " 0,\n",
+ " poptorch.VariableRetrievalMode.OnePerGroup,\n",
+ ")\n",
"\n",
"# Compile model\n",
"batch = next(iter(train_dl))\n",
@@ -413,16 +442,32 @@
"n_sample_queries = 4000\n",
"\n",
"val_device_iterations = 2\n",
- "val_shard_bs = 512 \n",
+ "val_shard_bs = 512\n",
"\n",
"# Partition a random subset of n_sample_queries triples taken from wikikg.triples[\"valid\"]\n",
- "subset_val_triples = wikikg.triples[\"valid\"][np.random.default_rng(seed=seed).choice(wikikg.triples[\"valid\"].shape[0], n_sample_queries)]\n",
- "sample_val_triples = PartitionedTripleSet.create_from_queries(wikikg, sharding, queries=subset_val_triples[:,:2],\n",
- " query_mode=\"hr\", ground_truth=subset_val_triples[:,2]) \n",
+ "subset_val_triples = wikikg.triples[\"valid\"][\n",
+ " np.random.default_rng(seed=seed).choice(\n",
+ " wikikg.triples[\"valid\"].shape[0], n_sample_queries\n",
+ " )\n",
+ "]\n",
+ "sample_val_triples = PartitionedTripleSet.create_from_queries(\n",
+ " wikikg,\n",
+ " sharding,\n",
+ " queries=subset_val_triples[:, :2],\n",
+ " query_mode=\"hr\",\n",
+ " ground_truth=subset_val_triples[:, 2],\n",
+ ")\n",
"\n",
"candidate_sampler = PlaceholderNegativeSampler(corruption_scheme=\"t\", seed=seed)\n",
- "bs_sample = RigidShardedBatchSampler(partitioned_triple_set=sample_val_triples, negative_sampler=candidate_sampler, shard_bs=val_shard_bs,\n",
- " batches_per_step=val_device_iterations, seed=seed, duplicate_batch=False, return_triple_idx=False)\n",
+ "bs_sample = RigidShardedBatchSampler(\n",
+ " partitioned_triple_set=sample_val_triples,\n",
+ " negative_sampler=candidate_sampler,\n",
+ " shard_bs=val_shard_bs,\n",
+ " batches_per_step=val_device_iterations,\n",
+ " seed=seed,\n",
+ " duplicate_batch=False,\n",
+ " return_triple_idx=False,\n",
+ ")\n",
"\n",
"print(\"Number of triples per h_shard:\")\n",
"print(sample_val_triples.triple_counts)"
@@ -450,11 +495,13 @@
"val_options.deviceIterations(bs_sample.batches_per_step)\n",
"val_options.outputMode(poptorch.OutputMode.All)\n",
"\n",
- "sample_valid_dl = bs_sample.get_dataloader(options=val_options, shuffle=False, num_workers=2, persistent_workers=True)\n",
+ "sample_valid_dl = bs_sample.get_dataloader(\n",
+ " options=val_options, shuffle=False, num_workers=2, persistent_workers=True\n",
+ ")\n",
"\n",
"# Example batch\n",
"val_batch = next(iter(sample_valid_dl))\n",
- "for k,v in val_batch.items():\n",
+ "for k, v in val_batch.items():\n",
" print(f\"{k:<12} {str(v.shape):<30}\")"
]
},
@@ -479,15 +526,21 @@
"\n",
"evaluation = Evaluation([\"mrr\"], worst_rank_infty=True, reduction=\"sum\")\n",
"\n",
- "inf_model = TopKQueryBessKGE(k=10, candidate_sampler=candidate_sampler, score_fn=transe_score_fn, evaluation=evaluation, window_size=500)\n",
+ "inf_model = TopKQueryBessKGE(\n",
+ " k=10,\n",
+ " candidate_sampler=candidate_sampler,\n",
+ " score_fn=transe_score_fn,\n",
+ " evaluation=evaluation,\n",
+ " window_size=500,\n",
+ ")\n",
"\n",
"poptorch_inf_model = poptorch.inferenceModel(inf_model, options=val_options)\n",
"\n",
"poptorch_inf_model.entity_embedding.replicaGrouping(\n",
- " poptorch.CommGroupType.NoGrouping,\n",
- " 0,\n",
- " poptorch.VariableRetrievalMode.OnePerGroup,\n",
- " )\n",
+ " poptorch.CommGroupType.NoGrouping,\n",
+ " 0,\n",
+ " poptorch.VariableRetrievalMode.OnePerGroup,\n",
+ ")\n",
"\n",
"# Compile inference model\n",
"val_res = poptorch_inf_model(**{k: v.flatten(end_dim=1) for k, v in val_batch.items()})\n",
@@ -693,11 +746,20 @@
" step_start_time = time.time()\n",
" cumulative_triples += batch[\"head\"].numel()\n",
" res = poptorch_model(**{k: v.flatten(end_dim=1) for k, v in batch.items()})\n",
- " ep_log.append(dict(loss= float(torch.sum(res[\"loss\"])) / batch[\"head\"][0].numel(), step_time=(time.time()-step_start_time)))\n",
- " ep_loss = [v['loss'] for v in ep_log]\n",
- " training_loss.extend([v['loss'] for v in ep_log])\n",
- " print(f\"Epoch {ep+1} loss: {np.mean(ep_loss):.6f} --- positive triples processed: {cumulative_triples:.2e}\")\n",
- " print(f\"Epoch duration (sec): {(time.time() - ep_start_time):.5f} (average step time: {np.mean([v['step_time'] for v in ep_log]):.5f})\")\n",
+ " ep_log.append(\n",
+ " dict(\n",
+ " loss=float(torch.sum(res[\"loss\"])) / batch[\"head\"][0].numel(),\n",
+ " step_time=(time.time() - step_start_time),\n",
+ " )\n",
+ " )\n",
+ " ep_loss = [v[\"loss\"] for v in ep_log]\n",
+ " training_loss.extend([v[\"loss\"] for v in ep_log])\n",
+ " print(\n",
+ " f\"Epoch {ep+1} loss: {np.mean(ep_loss):.6f} --- positive triples processed: {cumulative_triples:.2e}\"\n",
+ " )\n",
+ " print(\n",
+ " f\"Epoch duration (sec): {(time.time() - ep_start_time):.5f} (average step time: {np.mean([v['step_time'] for v in ep_log]):.5f})\"\n",
+ " )\n",
" if ep % val_ep_interval == 0:\n",
" poptorch_model.detachFromDevice()\n",
" poptorch_inf_model.attachToDevice()\n",
@@ -707,18 +769,24 @@
" val_start_time = time.time()\n",
" ep_mrr = 0.0\n",
" for batch_val in sample_valid_dl:\n",
- " ep_mrr += poptorch_inf_model(**{k: v.flatten(end_dim=1) for k, v in batch_val.items()})[\"metrics\"].sum()\n",
- " ep_mrr /= n_sample_queries \n",
+ " ep_mrr += poptorch_inf_model(\n",
+ " **{k: v.flatten(end_dim=1) for k, v in batch_val.items()}\n",
+ " )[\"metrics\"].sum()\n",
+ " ep_mrr /= n_sample_queries\n",
" val_mrr.append(ep_mrr)\n",
- " print(f\"Epoch {ep+1} sample MRR: {ep_mrr:.4f} (validation time: {(time.time() - val_start_time):.5f})\")\n",
+ " print(\n",
+ " f\"Epoch {ep+1} sample MRR: {ep_mrr:.4f} (validation time: {(time.time() - val_start_time):.5f})\"\n",
+ " )\n",
" poptorch_inf_model.detachFromDevice()\n",
" poptorch_model.attachToDevice()\n",
"\n",
"# Plot loss and sample MRR as a function of the number of positive triples processed\n",
"total_triples = np.cumsum(n_epochs * len(train_dl) * [batch[\"head\"].numel()])\n",
"ax0, ax1 = plt.gca(), plt.twinx()\n",
- "line0, = ax0.plot(total_triples, training_loss)\n",
- "line1, = ax1.plot(total_triples[::val_ep_interval * len(train_dl)], val_mrr, color=\"r\")\n",
+ "(line0,) = ax0.plot(total_triples, training_loss)\n",
+ "(line1,) = ax1.plot(\n",
+ " total_triples[:: val_ep_interval * len(train_dl)], val_mrr, color=\"r\"\n",
+ ")\n",
"ax0.set_xlabel(\"Positive triples\")\n",
"ax0.set_ylabel(\"Loss\")\n",
"ax1.set_ylabel(\"Sample MRR\")\n",
@@ -752,9 +820,16 @@
}
],
"source": [
- "validation_triples = PartitionedTripleSet.create_from_dataset(wikikg, \"valid\", sharding, partition_mode=\"h_shard\")\n",
- "bs_valid = RigidShardedBatchSampler(partitioned_triple_set=validation_triples, negative_sampler=candidate_sampler, shard_bs=val_shard_bs,\n",
- " batches_per_step=val_device_iterations, seed=seed)\n",
+ "validation_triples = PartitionedTripleSet.create_from_dataset(\n",
+ " wikikg, \"valid\", sharding, partition_mode=\"h_shard\"\n",
+ ")\n",
+ "bs_valid = RigidShardedBatchSampler(\n",
+ " partitioned_triple_set=validation_triples,\n",
+ " negative_sampler=candidate_sampler,\n",
+ " shard_bs=val_shard_bs,\n",
+ " batches_per_step=val_device_iterations,\n",
+ " seed=seed,\n",
+ ")\n",
"\n",
"print(\"Number of triples per h_shard:\")\n",
"print(validation_triples.triple_counts)\n",
@@ -784,7 +859,9 @@
"n_val_queries = 0\n",
"for batch_val in valid_dl:\n",
" n_val_queries += batch_val[\"triple_mask\"].sum()\n",
- " val_mrr += poptorch_inf_model(**{k: v.flatten(end_dim=1) for k, v in batch_val.items()})[\"metrics\"].sum()\n",
+ " val_mrr += poptorch_inf_model(\n",
+ " **{k: v.flatten(end_dim=1) for k, v in batch_val.items()}\n",
+ " )[\"metrics\"].sum()\n",
"\n",
"print(f\"Validation MRR: {val_mrr / n_val_queries}\")\n",
"print(f\"Validation time (sec): {(time.time() - start_time):.5f}\")\n",
@@ -824,16 +901,29 @@
}
],
"source": [
- "validation_triples = PartitionedTripleSet.create_from_dataset(dataset=wikikg, part=\"valid\", sharding=sharding, partition_mode=\"ht_shardpair\")\n",
- "ns_valid = TripleBasedShardedNegativeSampler(negative_heads=validation_triples.neg_heads, negative_tails=validation_triples.neg_tails,\n",
- " sharding=sharding, corruption_scheme=\"t\", seed=seed)\n",
+ "validation_triples = PartitionedTripleSet.create_from_dataset(\n",
+ " dataset=wikikg, part=\"valid\", sharding=sharding, partition_mode=\"ht_shardpair\"\n",
+ ")\n",
+ "ns_valid = TripleBasedShardedNegativeSampler(\n",
+ " negative_heads=validation_triples.neg_heads,\n",
+ " negative_tails=validation_triples.neg_tails,\n",
+ " sharding=sharding,\n",
+ " corruption_scheme=\"t\",\n",
+ " seed=seed,\n",
+ ")\n",
"# We do not need to duplicate_batch as we only want to score negative tails\n",
- "bs_valid = RigidShardedBatchSampler(partitioned_triple_set=validation_triples, negative_sampler=ns_valid, shard_bs=256, batches_per_step=10,\n",
- " seed=seed, duplicate_batch=False)\n",
+ "bs_valid = RigidShardedBatchSampler(\n",
+ " partitioned_triple_set=validation_triples,\n",
+ " negative_sampler=ns_valid,\n",
+ " shard_bs=256,\n",
+ " batches_per_step=10,\n",
+ " seed=seed,\n",
+ " duplicate_batch=False,\n",
+ ")\n",
"\n",
"# Example batch\n",
"idx_sampler = iter(bs_valid.get_dataloader_sampler(shuffle=False))\n",
- "for k,v in bs_valid[next(idx_sampler)].items():\n",
+ "for k, v in bs_valid[next(idx_sampler)].items():\n",
" print(f\"{k:<15} {str(v.shape):<35} {v.dtype};\")"
]
},
@@ -868,21 +958,25 @@
"val_options.deviceIterations(bs_valid.batches_per_step)\n",
"val_options.outputMode(poptorch.OutputMode.All)\n",
"\n",
- "valid_dl = bs_valid.get_dataloader(options=val_options, shuffle=False, num_workers=3, persistent_workers=True)\n",
+ "valid_dl = bs_valid.get_dataloader(\n",
+ " options=val_options, shuffle=False, num_workers=3, persistent_workers=True\n",
+ ")\n",
"\n",
"# Each triple is now to be scored against a specific set of negatives, so we turn off negative sample sharing\n",
"transe_score_fn.negative_sample_sharing = False\n",
"\n",
"evaluation = Evaluation([\"mrr\", \"hits@1\", \"hits@5\", \"hits@10\"], reduction=\"sum\")\n",
- "model_inf = ScoreMovingBessKGE(negative_sampler=ns_valid, score_fn=transe_score_fn, evaluation=evaluation)\n",
+ "model_inf = ScoreMovingBessKGE(\n",
+ " negative_sampler=ns_valid, score_fn=transe_score_fn, evaluation=evaluation\n",
+ ")\n",
"\n",
"poptorch_model_inf = poptorch.inferenceModel(model_inf, options=val_options)\n",
"\n",
"poptorch_model_inf.entity_embedding.replicaGrouping(\n",
- " poptorch.CommGroupType.NoGrouping,\n",
- " 0,\n",
- " poptorch.VariableRetrievalMode.OnePerGroup,\n",
- " )\n",
+ " poptorch.CommGroupType.NoGrouping,\n",
+ " 0,\n",
+ " poptorch.VariableRetrievalMode.OnePerGroup,\n",
+ ")\n",
"\n",
"# Compile model\n",
"batch = next(iter(valid_dl))\n",
@@ -914,10 +1008,15 @@
" res = poptorch_model_inf(**{k: v.flatten(end_dim=1) for k, v in batch_val.items()})\n",
" n_val_queries += batch_val[\"triple_mask\"].sum()\n",
" # By transposing res[\"metrics\"] we separate the outputs for the different metrics\n",
- " val_log.append({k: v.sum() for k, v in zip(\n",
- " evaluation.metrics.keys(),\n",
- " res[\"metrics\"].T,\n",
- " )})\n",
+ " val_log.append(\n",
+ " {\n",
+ " k: v.sum()\n",
+ " for k, v in zip(\n",
+ " evaluation.metrics.keys(),\n",
+ " res[\"metrics\"].T,\n",
+ " )\n",
+ " }\n",
+ " )\n",
"\n",
"for metric in val_log[0].keys():\n",
" reduced_metric = sum([l[metric] for l in val_log]) / n_val_queries\n",