From 4c6f8e6bdbb27b821971244600814b5da540b39b Mon Sep 17 00:00:00 2001
From: mjlaali
Date: Sun, 16 Jun 2024 11:59:29 -0700
Subject: [PATCH 01/15] [Feature] Add TensorDict storage

---
 tensordict/nn/storage.py | 340 +++++++++++++++++++++++++++++++++++++++
 test/test_storage.py     | 153 ++++++++++++++++++
 2 files changed, 493 insertions(+)
 create mode 100644 tensordict/nn/storage.py
 create mode 100644 test/test_storage.py

diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py
new file mode 100644
index 000000000..134a491a2
--- /dev/null
+++ b/tensordict/nn/storage.py
@@ -0,0 +1,340 @@
+import abc
+from typing import Callable, Dict, Generic, List, Optional, TypeVar
+
+import torch
+
+import torch.nn as nn
+
+from tensordict import NestedKey, TensorDict, TensorDictBase
+from tensordict.nn.common import TensorDictModuleBase
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+class TensorStorage(abc.ABC, Generic[K, V]):
+    """An Abstraction for implementing different storage.
+
+    This class is for internal use, please use derived classes instead.
+    """
+
+    def clear(self) -> None:
+        raise NotImplementedError
+
+    def __getitem__(self, item: K) -> V:
+        raise NotImplementedError
+
+    def __setitem__(self, key: K, value: V) -> None:
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        raise NotImplementedError
+
+    def contain(self, item: K) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class DynamicStorage(TensorStorage[torch.Tensor, torch.Tensor]):
+    """A Dynamic Tensor Storage.
+
+    This is a storage that save its tensors in cpu memories. It
+    expands as necessary.
+    """
+
+    def __init__(self, default_tensor: torch.Tensor):
+        self.tensor_dict: Dict[int, torch.Tensor] = {}
+        self.default_tensor = default_tensor
+
+    def clear(self) -> None:
+        self.tensor_dict.clear()
+
+    def __getitem__(self, indexes: torch.Tensor) -> torch.Tensor:
+        values: List[torch.Tensor] = []
+        for index in torch.unbind(indexes):
+            value = self.tensor_dict.get(index.item())
+            if value is None:
+                value = self.default_tensor.clone()
+            values.append(value)
+
+        return torch.stack(values)
+
+    def __setitem__(self, indexes: torch.Tensor, values: torch.Tensor) -> None:
+        for index, value in zip(torch.unbind(indexes), torch.unbind(values)):
+            self.tensor_dict[index.item()] = value
+
+    def __len__(self) -> None:
+        return len(self.tensor_dict)
+
+    def contain(self, indexes: torch.Tensor) -> torch.Tensor:
+        res: List[bool] = []
+        for index in torch.unbind(indexes):
+            res.append(index.item() in self.tensor_dict)
+
+        return torch.Tensor(res).to(torch.int64)
+
+
+class FixedStorage(nn.Module, TensorStorage[torch.Tensor, torch.Tensor]):
+    """A Fixed Tensor Storage.
+
+    This is storage that backed by nn.Embedding and hence can be in any device that
+    nn.Embedding supports. The size of memory is fixed and cannot be extended.
+ """ + + def __init__( + self, embedding: nn.Embedding, init_fm: Callable[[torch.Tensor], torch.Tensor] + ): + super().__init__() + self.embedding = embedding + self.num_embedding = embedding.num_embeddings + self.flag = None + self.init_fm = init_fm + self.clear() + + def clear(self): + self.init_fm(self.embedding.weight) + self.flag = torch.zeros(size=(self.embedding.num_embeddings, 1)).to(torch.int64) + + def to_index(self, item: torch.Tensor) -> torch.Tensor: + return torch.remainder(item.to(torch.int64), self.num_embedding).to(torch.int64) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.embedding(self.to_index(x)) + + def __getitem__(self, item: torch.Tensor) -> torch.Tensor: + return self.forward(item) + + def __setitem__(self, item: torch.Tensor, value: torch.Tensor) -> None: + if value.shape[-1] != self.embedding.embedding_dim: + raise ValueError( + "The shape value does not match with storage cell shape, " + f"expected {self.embedding.embedding_dim} but got {value.shape[-1]}!" + ) + index = self.to_index(item) + with torch.no_grad(): + self.embedding.weight[index, :] = value + self.flag[index] = 1 + + def __len__(self) -> int: + return torch.sum(self.flag).item() + + def contain(self, item: torch.Tensor) -> torch.Tensor: + index = self.to_index(item) + return self.flag[index] + + +class BinaryToDecimal(torch.nn.Module): + """A Module to convert binaries encoded tensors to decimals. + + This is a utility class that allow to convert a binary encoding tensor (e.g. `1001`) to + its decimal value (e.g. `9`) + """ + + def __init__( + self, + num_bits: int, + device: torch.device, + dtype: torch.dtype, + convert_to_binary: bool, + ): + super().__init__() + self.convert_to_binary = convert_to_binary + self.bases = 2 ** torch.arange(num_bits - 1, -1, -1).to(device, dtype) + self.num_bits = num_bits + self.zero_tensor = torch.zeros((1,)) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + num_features = features.shape[-1] + if self.num_bits > num_features: + raise ValueError(f"{num_features=} is less than {self.num_bits=}") + elif num_features % self.num_bits != 0: + raise ValueError(f"{num_features=} is not divisible by {self.num_bits=}") + + binary_features = ( + torch.heaviside(features, self.zero_tensor) + if self.convert_to_binary + else features + ) + feature_parts = binary_features.reshape(shape=(-1, self.num_bits)) + digits = torch.sum(self.bases * feature_parts, -1) + digits = digits.reshape(shape=(-1, features.shape[-1] // self.num_bits)) + aggregated_digits = torch.sum(digits, dim=-1) + return aggregated_digits + + +class SipHash(torch.nn.Module): + """A Module to Compute SipHash values for given tensors. + + A hash function module based on SipHash implementation in python. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hash_values = [] + for x_i in torch.unbind(x): + hash_value = hash(x_i.detach().numpy().tobytes()) + hash_values.append(hash_value) + + return torch.Tensor(hash_values).to(torch.int64).unsqueeze(dim=-1) + + +class QueryModule(TensorDictModuleBase): + """A Module to generate compatible indexes for storage. + + A module that queries a storage and return required index of that storage. + Currently, it only outputs integer indexes (torch.int64). 
+ """ + + def __init__( + self, + in_keys: List[NestedKey], + index_key: NestedKey, + hash_module: torch.nn.Module, + aggregation_module: Optional[torch.nn.Module] = None, + ): + self.in_keys = in_keys if isinstance(in_keys, List) else [in_keys] + self.out_keys = [index_key] + + super().__init__() + + self.aggregation_module = ( + aggregation_module if aggregation_module else hash_module + ) + + self.hash_module = hash_module + self.index_key = index_key + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + hash_values = [] + + for k in self.in_keys: + hash_values.append(self.hash_module(tensordict[k])) + + td_hash_value = self.aggregation_module( + torch.stack( + hash_values, + dim=-1, + ), + ) + + output = tensordict.clone(False) + output[self.index_key] = td_hash_value + return output + + +class TensorDictStorage( + TensorDictModuleBase, TensorStorage[TensorDictModuleBase, TensorDictModuleBase] +): + """A Storage for TensorDict. + + This module resembles a memory. It takes a tensordict as its input and + returns another tensordict as output similar to TensorDictModuleBase. However, + it provides additional functionality like python map: + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> mlp = torch.nn.Linear(in_features=1, out_features=64, bias=True) + >>> binary_to_decimal = BinaryToDecimal( + ... num_bits=8, device="cpu", dtype=torch.int32, convert_to_binary=True + ... ) + >>> query_module = QueryModule( + ... in_keys=["key1", "key2"], + ... index_key="index", + ... hash_module=torch.nn.Sequential(mlp, binary_to_decimal), + ... ) + >>> embedding_storage = FixedStorage( + ... torch.nn.Embedding(num_embeddings=23, embedding_dim=1), + ... lambda x: torch.nn.init.constant_(x, 0), + ... ) + >>> tensor_dict_storage = TensorDictStorage( + ... in_keys=["key1", "key2"], + ... query_module=query_module, + ... memories={"index": embedding_storage}, + ... ) + >>> index = TensorDict( + ... { + ... "key1": torch.Tensor([[-1], [1], [3], [-3]]), + ... "key2": torch.Tensor([[0], [2], [4], [-4]]), + ... }, + ... batch_size=(4,), + ... ) + >>> value = TensorDict( + ... {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,) + ... 
) + >>> tensor_dict_storage[index] = value + >>> assert torch.sum(tensor_dict_storage.contain(index)).item() == 4 + >>> new_index = index.clone(True) + >>> new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) + >>> retrieve_value = tensor_dict_storage[new_index] + >>> assert (retrieve_value["index"] == value["index"]).all() + """ + + def __init__( + self, + query_module: QueryModule, + memories: Dict[NestedKey, TensorStorage[torch.Tensor, torch.Tensor]], + ): + self.in_keys = query_module.in_keys + self.out_keys = list(memories.keys()) + + super().__init__() + + for k in self.out_keys: + assert k in memories, f"{k} has not been assigned to a memory" + self.query_module = query_module + self.index_key = query_module.index_key + self.memories = memories + self.batch_added = False + + def clear(self) -> None: + for mem in self.memories.values(): + mem.clear() + + def to_index(self, item: TensorDictBase) -> torch.Tensor: + return self.query_module(item)[self.index_key] + + def maybe_add_batch( + self, item: TensorDictBase, value: Optional[TensorDictBase] + ) -> TensorDictBase: + self.batch_added = False + if len(item.batch_size) == 0: + self.batch_added = True + + item = item.unsqueeze(dim=0) + if value is not None: + value = value.unsqueeze(dim=0) + + return item, value + + def maybe_remove_batch(self, item: TensorDictBase) -> TensorDictBase: + if self.batch_added: + item = item.squeeze(dim=0) + return item + + def __getitem__(self, item: TensorDictBase) -> TensorDictBase: + item, _ = self.maybe_add_batch(item, None) + + index = self.to_index(item) + + res = TensorDict({}, batch_size=item.batch_size) + for k in self.out_keys: + res[k] = self.memories[k][index] + + res = self.maybe_remove_batch(res) + return res + + def __setitem__(self, item: TensorDictBase, value: TensorDictBase): + item, value = self.maybe_add_batch(item, value) + + index = self.to_index(item) + for k in self.out_keys: + self.memories[k][index] = value[k] + + def __len__(self): + return len(next(iter(self.memories.values()))) + + def contain(self, item: TensorDictBase) -> torch.Tensor: + item, _ = self.maybe_add_batch(item, None) + index = self.to_index(item) + + index = self.maybe_remove_batch(index) + return next(iter(self.memories.values())).contain(index) diff --git a/test/test_storage.py b/test/test_storage.py new file mode 100644 index 000000000..38e712a31 --- /dev/null +++ b/test/test_storage.py @@ -0,0 +1,153 @@ +import torch + +from tensordict import TensorDict +from tensordict.nn.storage import ( + BinaryToDecimal, + DynamicStorage, + FixedStorage, + QueryModule, + SipHash, + TensorDictStorage, +) + + +def test_embedding_memory(): + embedding_storage = FixedStorage( + torch.nn.Embedding(num_embeddings=10, embedding_dim=2), + lambda x: torch.nn.init.constant_(x, 0), + ) + + index = torch.Tensor([1, 2]).long() + assert len(embedding_storage) == 0 + assert not (embedding_storage[index] == torch.ones(size=(2, 2))).all() + + embedding_storage[index] = torch.ones(size=(2, 2)) + assert torch.sum(embedding_storage.contain(index)).item() == 2 + + assert (embedding_storage[index] == torch.ones(size=(2, 2))).all() + + assert len(embedding_storage) == 2 + embedding_storage.clear() + assert len(embedding_storage) == 0 + assert not (embedding_storage[index] == torch.ones(size=(2, 2))).all() + + +def test_binary_to_decimal(): + binary_to_decimal = BinaryToDecimal( + num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True + ) + binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]]) + decimal = 
+
+    assert decimal.shape == (2,)
+    assert (decimal == torch.Tensor([3, 2])).all()
+
+
+def test_query():
+    torch.manual_seed(3)
+    mlp = torch.nn.Linear(in_features=1, out_features=64, bias=True)
+    binary_to_decimal = BinaryToDecimal(
+        num_bits=8, device="cpu", dtype=torch.int32, convert_to_binary=True
+    )
+    query_module = QueryModule(
+        in_keys=["key1", "key2"],
+        index_key="index",
+        hash_module=torch.nn.Sequential(mlp, binary_to_decimal),
+    )
+
+    query = TensorDict(
+        {
+            "key1": torch.Tensor([[1], [1], [1], [2]]),
+            "key2": torch.Tensor([[3], [3], [2], [3]]),
+        },
+        batch_size=(4,),
+    )
+    res = query_module(query)
+    assert res["index"][0] == res["index"][1]
+    for i in range(1, 3):
+        assert res["index"][i].item() != res["index"][i + 1].item(), (
+            f"{i} = ({query[i]['key1']}, {query[i]['key2']}) s index and {i + 1} = ({query[i + 1]['key1']}, "
+            f"{query[i + 1]['key2']})'s index are the same!"
+        )
+
+
+def test_query_module():
+    torch.manual_seed(5)
+    mlp = torch.nn.Linear(in_features=1, out_features=64, bias=True)
+    binary_to_decimal = BinaryToDecimal(
+        num_bits=8, device="cpu", dtype=torch.int32, convert_to_binary=True
+    )
+    query_module = QueryModule(
+        in_keys=["key1", "key2"],
+        index_key="index",
+        hash_module=torch.nn.Sequential(mlp, binary_to_decimal),
+    )
+
+    embedding_storage = FixedStorage(
+        torch.nn.Embedding(num_embeddings=23, embedding_dim=1),
+        lambda x: torch.nn.init.constant_(x, 0),
+    )
+
+    tensor_dict_storage = TensorDictStorage(
+        in_keys=["key1", "key2"],
+        query_module=query_module,
+        memories={"index": embedding_storage},
+    )
+
+    index = TensorDict(
+        {
+            "key1": torch.Tensor([[-1], [1], [3], [-3]]),
+            "key2": torch.Tensor([[0], [2], [4], [-4]]),
+        },
+        batch_size=(4,),
+    )
+
+    value = TensorDict(
+        {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
+    )
+
+    tensor_dict_storage[index] = value
+    assert torch.sum(tensor_dict_storage.contain(index)).item() == 4
+
+    new_index = index.clone(True)
+    new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
+    retrieve_value = tensor_dict_storage[new_index]
+
+    assert (retrieve_value["index"] == value["index"]).all()
+
+
+def test_storage():
+    query_module = QueryModule(
+        in_keys=["key1", "key2"],
+        index_key="index",
+        hash_module=SipHash(),
+    )
+
+    embedding_storage = DynamicStorage()
+
+    tensor_dict_storage = TensorDictStorage(
+        in_keys=["key1", "key2"],
+        query_module=query_module,
+        memories={"index": embedding_storage},
+    )
+
+    index = TensorDict(
+        {
+            "key1": torch.Tensor([[-1], [1], [3], [-3]]),
+            "key2": torch.Tensor([[0], [2], [4], [-4]]),
+        },
+        batch_size=(4,),
+    )
+
+    value = TensorDict(
+        {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
+    )
+
+    tensor_dict_storage[index] = value
+    assert torch.sum(tensor_dict_storage.contain(index)).item() == 4
+
+    new_index = index.clone(True)
+    new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
+    retrieve_value = tensor_dict_storage[new_index]
+
+    assert (retrieve_value["index"] == value["index"]).all()

From 3ac4e62524912d88e7cd7c05e45652f747f29655 Mon Sep 17 00:00:00 2001
From: mjlaali
Date: Sun, 23 Jun 2024 22:58:01 -0700
Subject: [PATCH 02/15] Fix unit tests.

--- tensordict/nn/storage.py | 29 +++++++++++++++-------------- test/test_storage.py | 22 +++++----------------- 2 files changed, 20 insertions(+), 31 deletions(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index 134a491a2..4bb553447 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -77,7 +77,7 @@ class FixedStorage(nn.Module, TensorStorage[torch.Tensor, torch.Tensor]): """A Fixed Tensor Storage. This is storage that backed by nn.Embedding and hence can be in any device that - nn.Embedding supports. The size of memory is fixed and cannot be extended. + nn.Embedding supports. The size of storage is fixed and cannot be extended. """ def __init__( @@ -173,7 +173,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: hash_value = hash(x_i.detach().numpy().tobytes()) hash_values.append(hash_value) - return torch.Tensor(hash_values).to(torch.int64).unsqueeze(dim=-1) + return torch.Tensor(hash_values).to(torch.int64) class QueryModule(TensorDictModuleBase): @@ -225,7 +225,7 @@ class TensorDictStorage( ): """A Storage for TensorDict. - This module resembles a memory. It takes a tensordict as its input and + This module resembles a storage. It takes a tensordict as its input and returns another tensordict as output similar to TensorDictModuleBase. However, it provides additional functionality like python map: @@ -248,7 +248,7 @@ class TensorDictStorage( >>> tensor_dict_storage = TensorDictStorage( ... in_keys=["key1", "key2"], ... query_module=query_module, - ... memories={"index": embedding_storage}, + ... key_to_storage={"index": embedding_storage}, ... ) >>> index = TensorDict( ... { @@ -271,22 +271,22 @@ class TensorDictStorage( def __init__( self, query_module: QueryModule, - memories: Dict[NestedKey, TensorStorage[torch.Tensor, torch.Tensor]], + key_to_storage: Dict[NestedKey, TensorStorage[torch.Tensor, torch.Tensor]], ): self.in_keys = query_module.in_keys - self.out_keys = list(memories.keys()) + self.out_keys = list(key_to_storage.keys()) super().__init__() for k in self.out_keys: - assert k in memories, f"{k} has not been assigned to a memory" + assert k in key_to_storage, f"{k} has not been assigned to a memory" self.query_module = query_module self.index_key = query_module.index_key - self.memories = memories + self.key_to_storage = key_to_storage self.batch_added = False def clear(self) -> None: - for mem in self.memories.values(): + for mem in self.key_to_storage.values(): mem.clear() def to_index(self, item: TensorDictBase) -> torch.Tensor: @@ -317,7 +317,7 @@ def __getitem__(self, item: TensorDictBase) -> TensorDictBase: res = TensorDict({}, batch_size=item.batch_size) for k in self.out_keys: - res[k] = self.memories[k][index] + res[k] = self.key_to_storage[k][index] res = self.maybe_remove_batch(res) return res @@ -327,14 +327,15 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase): index = self.to_index(item) for k in self.out_keys: - self.memories[k][index] = value[k] + self.key_to_storage[k][index] = value[k] def __len__(self): - return len(next(iter(self.memories.values()))) + return len(next(iter(self.key_to_storage.values()))) def contain(self, item: TensorDictBase) -> torch.Tensor: item, _ = self.maybe_add_batch(item, None) index = self.to_index(item) - index = self.maybe_remove_batch(index) - return next(iter(self.memories.values())).contain(index) + res = next(iter(self.key_to_storage.values())).contain(index) + res = self.maybe_remove_batch(res) + return res diff --git a/test/test_storage.py 
b/test/test_storage.py index 38e712a31..5f7eede1c 100644 --- a/test/test_storage.py +++ b/test/test_storage.py @@ -44,15 +44,10 @@ def test_binary_to_decimal(): def test_query(): - torch.manual_seed(3) - mlp = torch.nn.Linear(in_features=1, out_features=64, bias=True) - binary_to_decimal = BinaryToDecimal( - num_bits=8, device="cpu", dtype=torch.int32, convert_to_binary=True - ) query_module = QueryModule( in_keys=["key1", "key2"], index_key="index", - hash_module=torch.nn.Sequential(mlp, binary_to_decimal), + hash_module=SipHash(), ) query = TensorDict( @@ -72,15 +67,10 @@ def test_query(): def test_query_module(): - torch.manual_seed(5) - mlp = torch.nn.Linear(in_features=1, out_features=64, bias=True) - binary_to_decimal = BinaryToDecimal( - num_bits=8, device="cpu", dtype=torch.int32, convert_to_binary=True - ) query_module = QueryModule( in_keys=["key1", "key2"], index_key="index", - hash_module=torch.nn.Sequential(mlp, binary_to_decimal), + hash_module=SipHash(), ) embedding_storage = FixedStorage( @@ -89,9 +79,8 @@ def test_query_module(): ) tensor_dict_storage = TensorDictStorage( - in_keys=["key1", "key2"], query_module=query_module, - memories={"index": embedding_storage}, + key_to_storage={"index": embedding_storage}, ) index = TensorDict( @@ -123,12 +112,11 @@ def test_storage(): hash_module=SipHash(), ) - embedding_storage = DynamicStorage() + embedding_storage = DynamicStorage(default_tensor=torch.zeros((1,))) tensor_dict_storage = TensorDictStorage( - in_keys=["key1", "key2"], query_module=query_module, - memories={"index": embedding_storage}, + key_to_storage={"index": embedding_storage}, ) index = TensorDict( From 7c7079a5cfb484b0e190b5d339f9a948612ddac5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 24 Jun 2024 15:52:03 +0100 Subject: [PATCH 03/15] amend --- tensordict/nn/storage.py | 25 +++++++++++++++---------- test/test_storage.py | 5 +++++ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index 4bb553447..e7d58cead 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ import abc from typing import Callable, Dict, Generic, List, Optional, TypeVar @@ -48,29 +53,29 @@ def __init__(self, default_tensor: torch.Tensor): def clear(self) -> None: self.tensor_dict.clear() - def __getitem__(self, indexes: torch.Tensor) -> torch.Tensor: + def __getitem__(self, indices: torch.Tensor) -> torch.Tensor: values: List[torch.Tensor] = [] - for index in torch.unbind(indexes): - value = self.tensor_dict.get(index.item()) + for index in indices.tolist(): + value = self.tensor_dict.get(index) if value is None: value = self.default_tensor.clone() values.append(value) return torch.stack(values) - def __setitem__(self, indexes: torch.Tensor, values: torch.Tensor) -> None: - for index, value in zip(torch.unbind(indexes), torch.unbind(values)): - self.tensor_dict[index.item()] = value + def __setitem__(self, indices: torch.Tensor, values: torch.Tensor) -> None: + for index, value in zip(indices.tolist(), values.unbind(0)): + self.tensor_dict[index] = value def __len__(self) -> None: return len(self.tensor_dict) - def contain(self, indexes: torch.Tensor) -> torch.Tensor: + def contain(self, indices: torch.Tensor) -> torch.Tensor: res: List[bool] = [] - for index in torch.unbind(indexes): - res.append(index.item() in self.tensor_dict) + for index in indices.tolist(): + res.append(index in self.tensor_dict) - return torch.Tensor(res).to(torch.int64) + return torch.tensor(res, dtype=torch.int64) class FixedStorage(nn.Module, TensorStorage[torch.Tensor, torch.Tensor]): diff --git a/test/test_storage.py b/test/test_storage.py index 5f7eede1c..c7d1dd16b 100644 --- a/test/test_storage.py +++ b/test/test_storage.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ import torch from tensordict import TensorDict From 554fb1546b61152d1ffb62fc26b8f347736b65f9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 24 Jun 2024 15:56:17 +0100 Subject: [PATCH 04/15] amend --- tensordict/nn/storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index e7d58cead..6bad8870b 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -97,7 +97,7 @@ def __init__( def clear(self): self.init_fm(self.embedding.weight) - self.flag = torch.zeros(size=(self.embedding.num_embeddings, 1)).to(torch.int64) + self.flag = torch.zeros((self.embedding.num_embeddings, 1), dtype=torch.int64) def to_index(self, item: torch.Tensor) -> torch.Tensor: return torch.remainder(item.to(torch.int64), self.num_embedding).to(torch.int64) From 91e62d1e9245b28e8a2cb71c1049cf05c2089624 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 24 Jun 2024 15:59:24 +0100 Subject: [PATCH 05/15] amend --- tensordict/nn/storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index 6bad8870b..6d2d2a90c 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -143,7 +143,7 @@ def __init__( ): super().__init__() self.convert_to_binary = convert_to_binary - self.bases = 2 ** torch.arange(num_bits - 1, -1, -1).to(device, dtype) + self.bases = 2 ** torch.arange(num_bits - 1, -1, -1, device=device, dtype=dtype) self.num_bits = num_bits self.zero_tensor = torch.zeros((1,)) From 4d8438a6253d3b4ff427215dc4bb99efd2389f1e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 24 Jun 2024 16:01:01 +0100 Subject: [PATCH 06/15] amend --- tensordict/nn/storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index 6d2d2a90c..65120a89b 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -178,7 +178,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: hash_value = hash(x_i.detach().numpy().tobytes()) hash_values.append(hash_value) - return torch.Tensor(hash_values).to(torch.int64) + return torch.tensor(hash_values, dtype=torch.int64) class QueryModule(TensorDictModuleBase): From 372b649e0ccd280ed989b358aecc405587bb4c49 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 24 Jun 2024 16:03:08 +0100 Subject: [PATCH 07/15] amend --- tensordict/nn/storage.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index 65120a89b..a0a2c6d5f 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -182,10 +182,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class QueryModule(TensorDictModuleBase): - """A Module to generate compatible indexes for storage. + """A Module to generate compatible indices for storage. A module that queries a storage and return required index of that storage. - Currently, it only outputs integer indexes (torch.int64). + Currently, it only outputs integer indices (torch.int64). 
""" def __init__( @@ -193,7 +193,7 @@ def __init__( in_keys: List[NestedKey], index_key: NestedKey, hash_module: torch.nn.Module, - aggregation_module: Optional[torch.nn.Module] = None, + aggregation_module: torch.nn.Module | None = None, ): self.in_keys = in_keys if isinstance(in_keys, List) else [in_keys] self.out_keys = [index_key] @@ -298,7 +298,7 @@ def to_index(self, item: TensorDictBase) -> torch.Tensor: return self.query_module(item)[self.index_key] def maybe_add_batch( - self, item: TensorDictBase, value: Optional[TensorDictBase] + self, item: TensorDictBase, value: TensorDictBase | None ) -> TensorDictBase: self.batch_added = False if len(item.batch_size) == 0: From 81080e0cf4e67e50793256496bca55bd6256fbc2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 24 Jun 2024 16:11:07 +0100 Subject: [PATCH 08/15] amend --- tensordict/nn/storage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index a0a2c6d5f..bbbbb19bc 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -237,7 +237,7 @@ class TensorDictStorage( Examples: >>> import torch >>> from tensordict import TensorDict - >>> mlp = torch.nn.Linear(in_features=1, out_features=64, bias=True) + >>> mlp = torch.nn.LazyLinear(out_features=64, bias=True) >>> binary_to_decimal = BinaryToDecimal( ... num_bits=8, device="cpu", dtype=torch.int32, convert_to_binary=True ... ) @@ -251,7 +251,6 @@ class TensorDictStorage( ... lambda x: torch.nn.init.constant_(x, 0), ... ) >>> tensor_dict_storage = TensorDictStorage( - ... in_keys=["key1", "key2"], ... query_module=query_module, ... key_to_storage={"index": embedding_storage}, ... ) @@ -266,6 +265,7 @@ class TensorDictStorage( ... {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,) ... ) >>> tensor_dict_storage[index] = value + >>> tensor_dict_storage[index] >>> assert torch.sum(tensor_dict_storage.contain(index)).item() == 4 >>> new_index = index.clone(True) >>> new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) From 4a284b0e308f6b64131b549f638d25395638da3b Mon Sep 17 00:00:00 2001 From: mjlaali Date: Mon, 24 Jun 2024 21:25:36 -0700 Subject: [PATCH 09/15] Address comments --- tensordict/nn/storage.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index bbbbb19bc..9df355e06 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import abc +from abc import abstractmethod from typing import Callable, Dict, Generic, List, Optional, TypeVar import torch @@ -23,18 +24,23 @@ class TensorStorage(abc.ABC, Generic[K, V]): This class is for internal use, please use derived classes instead. 
""" + @abstractmethod def clear(self) -> None: raise NotImplementedError + @abstractmethod def __getitem__(self, item: K) -> V: raise NotImplementedError + @abstractmethod def __setitem__(self, key: K, value: V) -> None: raise NotImplementedError + @abstractmethod def __len__(self) -> int: raise NotImplementedError + @abstractmethod def contain(self, item: K) -> torch.Tensor: raise NotImplementedError @@ -75,7 +81,7 @@ def contain(self, indices: torch.Tensor) -> torch.Tensor: for index in indices.tolist(): res.append(index in self.tensor_dict) - return torch.tensor(res, dtype=torch.int64) + return torch.tensor(res, dtype=torch.bool) class FixedStorage(nn.Module, TensorStorage[torch.Tensor, torch.Tensor]): @@ -97,7 +103,7 @@ def __init__( def clear(self): self.init_fm(self.embedding.weight) - self.flag = torch.zeros((self.embedding.num_embeddings, 1), dtype=torch.int64) + self.flag = torch.zeros((self.embedding.num_embeddings, 1), dtype=torch.bool) def to_index(self, item: torch.Tensor) -> torch.Tensor: return torch.remainder(item.to(torch.int64), self.num_embedding).to(torch.int64) @@ -117,7 +123,7 @@ def __setitem__(self, item: torch.Tensor, value: torch.Tensor) -> None: index = self.to_index(item) with torch.no_grad(): self.embedding.weight[index, :] = value - self.flag[index] = 1 + self.flag[index] = True def __len__(self) -> int: return torch.sum(self.flag).item() @@ -145,7 +151,7 @@ def __init__( self.convert_to_binary = convert_to_binary self.bases = 2 ** torch.arange(num_bits - 1, -1, -1, device=device, dtype=dtype) self.num_bits = num_bits - self.zero_tensor = torch.zeros((1,)) + self.zero_tensor = torch.zeros((1,), device=device) def forward(self, features: torch.Tensor) -> torch.Tensor: num_features = features.shape[-1] From 9670801cc02f489b043a48c4f4f902c9fe2872e1 Mon Sep 17 00:00:00 2001 From: mjlaali Date: Mon, 24 Jun 2024 21:50:46 -0700 Subject: [PATCH 10/15] Improve docstring --- tensordict/nn/storage.py | 80 +++++++++++++++++++++++++++++++++++----- test/test_storage.py | 21 ++++++++++- 2 files changed, 90 insertions(+), 11 deletions(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index 9df355e06..525b0dfbe 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -5,7 +5,7 @@ import abc from abc import abstractmethod -from typing import Callable, Dict, Generic, List, Optional, TypeVar +from typing import Callable, Dict, Generic, List, TypeVar import torch @@ -50,6 +50,15 @@ class DynamicStorage(TensorStorage[torch.Tensor, torch.Tensor]): This is a storage that save its tensors in cpu memories. It expands as necessary. + + Examples: + >>> storage = DynamicStorage(default_tensor=torch.zeros((1,))) + >>> index = torch.randn((3,)) + >>> value = torch.rand((2, 1)) + >>> storage[index] = value + >>> assert len(storage) == 3 + >>> assert (storage[index.clone()] == value).all() + """ def __init__(self, default_tensor: torch.Tensor): @@ -89,6 +98,22 @@ class FixedStorage(nn.Module, TensorStorage[torch.Tensor, torch.Tensor]): This is storage that backed by nn.Embedding and hence can be in any device that nn.Embedding supports. The size of storage is fixed and cannot be extended. + + Examples: + >>> embedding_storage = FixedStorage( + ... torch.nn.Embedding(num_embeddings=10, embedding_dim=2), + ... lambda x: torch.nn.init.constant_(x, 0), + ... 
) + >>> index = torch.Tensor([1, 2]).long() + >>> assert len(embedding_storage) == 0 + >>> assert not (embedding_storage[index] == torch.ones(size=(2, 2))).all() + >>> embedding_storage[index] = torch.ones(size=(2, 2)) + >>> assert torch.sum(embedding_storage.contain(index)).item() == 2 + >>> assert (embedding_storage[index] == torch.ones(size=(2, 2))).all() + >>> assert len(embedding_storage) == 2 + >>> embedding_storage.clear() + >>> assert len(embedding_storage) == 0 + >>> assert not (embedding_storage[index] == torch.ones(size=(2, 2))).all() """ def __init__( @@ -138,6 +163,15 @@ class BinaryToDecimal(torch.nn.Module): This is a utility class that allow to convert a binary encoding tensor (e.g. `1001`) to its decimal value (e.g. `9`) + + Examples: + >>> binary_to_decimal = BinaryToDecimal( + ... num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True + ... ) + >>> binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]]) + >>> decimal = binary_to_decimal(binary) + >>> assert decimal.shape == (2,) + >>> assert (decimal == torch.Tensor([3, 2])).all() """ def __init__( @@ -176,6 +210,15 @@ class SipHash(torch.nn.Module): """A Module to Compute SipHash values for given tensors. A hash function module based on SipHash implementation in python. + + Examples: + >>> from typing import cast + >>> a = torch.rand((3, 2)) + >>> b = a.clone() + >>> hash_module = SipHash() + >>> hash_a = cast(torch.Tensor, hash_module(a)) + >>> hash_b = cast(torch.Tensor, hash_module(b)) + >>> assert (hash_a == hash_b).all() """ def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -192,6 +235,27 @@ class QueryModule(TensorDictModuleBase): A module that queries a storage and return required index of that storage. Currently, it only outputs integer indices (torch.int64). + + Examples: + >>> query_module = QueryModule( + ... in_keys=["key1", "key2"], + ... index_key="index", + ... hash_module=SipHash(), + ... ) + >>> query = TensorDict( + ... { + ... "key1": torch.Tensor([[1], [1], [1], [2]]), + ... "key2": torch.Tensor([[3], [3], [2], [3]]), + ... }, + ... batch_size=(4,), + ... ) + >>> res = query_module(query) + >>> assert res["index"][0] == res["index"][1] + >>> for i in range(1, 3): + >>> assert res["index"][i].item() != res["index"][i + 1].item(), ( + ... f"{i} = ({query[i]['key1']}, {query[i]['key2']}) s index and {i + 1} = ({query[i + 1]['key1']}, " + ... f"{query[i + 1]['key2']})'s index are the same!" + ... ) """ def __init__( @@ -243,18 +307,14 @@ class TensorDictStorage( Examples: >>> import torch >>> from tensordict import TensorDict - >>> mlp = torch.nn.LazyLinear(out_features=64, bias=True) - >>> binary_to_decimal = BinaryToDecimal( - ... num_bits=8, device="cpu", dtype=torch.int32, convert_to_binary=True - ... ) + >>> from typing import cast >>> query_module = QueryModule( ... in_keys=["key1", "key2"], ... index_key="index", - ... hash_module=torch.nn.Sequential(mlp, binary_to_decimal), + ... hash_module=SipHash(), ... ) - >>> embedding_storage = FixedStorage( - ... torch.nn.Embedding(num_embeddings=23, embedding_dim=1), - ... lambda x: torch.nn.init.constant_(x, 0), + >>> embedding_storage = DynamicStorage( + ... default_tensor=torch.zeros((1,)), ... ) >>> tensor_dict_storage = TensorDictStorage( ... 
query_module=query_module, @@ -276,7 +336,7 @@ class TensorDictStorage( >>> new_index = index.clone(True) >>> new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) >>> retrieve_value = tensor_dict_storage[new_index] - >>> assert (retrieve_value["index"] == value["index"]).all() + >>> assert cast(torch.Tensor, retrieve_value["index"] == value["index"]).all() """ def __init__( diff --git a/test/test_storage.py b/test/test_storage.py index c7d1dd16b..e68fced38 100644 --- a/test/test_storage.py +++ b/test/test_storage.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from typing import cast import torch @@ -37,6 +38,15 @@ def test_embedding_memory(): assert not (embedding_storage[index] == torch.ones(size=(2, 2))).all() +def test_dynamic_storage(): + storage = DynamicStorage(default_tensor=torch.zeros((1,))) + index = torch.randn((3,)) + value = torch.rand((3, 1)) + storage[index] = value + assert len(storage) == 3 + assert (storage[index.clone()] == value).all() + + def test_binary_to_decimal(): binary_to_decimal = BinaryToDecimal( num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True @@ -48,6 +58,15 @@ def test_binary_to_decimal(): assert (decimal == torch.Tensor([3, 2])).all() +def test_sip_hash(): + a = torch.rand((3, 2)) + b = a.clone() + hash_module = SipHash() + hash_a = cast(torch.Tensor, hash_module(a)) + hash_b = cast(torch.Tensor, hash_module(b)) + assert (hash_a == hash_b).all() + + def test_query(): query_module = QueryModule( in_keys=["key1", "key2"], @@ -143,4 +162,4 @@ def test_storage(): new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) retrieve_value = tensor_dict_storage[new_index] - assert (retrieve_value["index"] == value["index"]).all() + assert cast(torch.Tensor, retrieve_value["index"] == value["index"]).all() From 0f31f7e3ef0b91d9b0f979d935d87517a00bf125 Mon Sep 17 00:00:00 2001 From: mjlaali Date: Tue, 25 Jun 2024 08:24:10 -0700 Subject: [PATCH 11/15] Rename private methods such that they start with _ --- tensordict/nn/storage.py | 39 ++++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index 525b0dfbe..b4877fb40 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -68,7 +68,14 @@ def __init__(self, default_tensor: torch.Tensor): def clear(self) -> None: self.tensor_dict.clear() + def _check_indices(self, indices: torch.Tensor) -> None: + if len(indices.shape) != 1: + raise ValueError( + f"Indices have to be a one-d vector but got {indices.shape}" + ) + def __getitem__(self, indices: torch.Tensor) -> torch.Tensor: + self._check_indices(indices) values: List[torch.Tensor] = [] for index in indices.tolist(): value = self.tensor_dict.get(index) @@ -79,6 +86,7 @@ def __getitem__(self, indices: torch.Tensor) -> torch.Tensor: return torch.stack(values) def __setitem__(self, indices: torch.Tensor, values: torch.Tensor) -> None: + self._check_indices(indices) for index, value in zip(indices.tolist(), values.unbind(0)): self.tensor_dict[index] = value @@ -86,6 +94,7 @@ def __len__(self) -> None: return len(self.tensor_dict) def contain(self, indices: torch.Tensor) -> torch.Tensor: + self._check_indices(indices) res: List[bool] = [] for index in indices.tolist(): res.append(index in self.tensor_dict) @@ -130,11 +139,11 @@ def clear(self): self.init_fm(self.embedding.weight) self.flag = torch.zeros((self.embedding.num_embeddings, 1), 
dtype=torch.bool) - def to_index(self, item: torch.Tensor) -> torch.Tensor: + def _to_index(self, item: torch.Tensor) -> torch.Tensor: return torch.remainder(item.to(torch.int64), self.num_embedding).to(torch.int64) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.embedding(self.to_index(x)) + return self.embedding(self._to_index(x)) def __getitem__(self, item: torch.Tensor) -> torch.Tensor: return self.forward(item) @@ -145,7 +154,7 @@ def __setitem__(self, item: torch.Tensor, value: torch.Tensor) -> None: "The shape value does not match with storage cell shape, " f"expected {self.embedding.embedding_dim} but got {value.shape[-1]}!" ) - index = self.to_index(item) + index = self._to_index(item) with torch.no_grad(): self.embedding.weight[index, :] = value self.flag[index] = True @@ -154,7 +163,7 @@ def __len__(self) -> int: return torch.sum(self.flag).item() def contain(self, item: torch.Tensor) -> torch.Tensor: - index = self.to_index(item) + index = self._to_index(item) return self.flag[index] @@ -360,10 +369,10 @@ def clear(self) -> None: for mem in self.key_to_storage.values(): mem.clear() - def to_index(self, item: TensorDictBase) -> torch.Tensor: + def _to_index(self, item: TensorDictBase) -> torch.Tensor: return self.query_module(item)[self.index_key] - def maybe_add_batch( + def _maybe_add_batch( self, item: TensorDictBase, value: TensorDictBase | None ) -> TensorDictBase: self.batch_added = False @@ -376,27 +385,27 @@ def maybe_add_batch( return item, value - def maybe_remove_batch(self, item: TensorDictBase) -> TensorDictBase: + def _maybe_remove_batch(self, item: TensorDictBase) -> TensorDictBase: if self.batch_added: item = item.squeeze(dim=0) return item def __getitem__(self, item: TensorDictBase) -> TensorDictBase: - item, _ = self.maybe_add_batch(item, None) + item, _ = self._maybe_add_batch(item, None) - index = self.to_index(item) + index = self._to_index(item) res = TensorDict({}, batch_size=item.batch_size) for k in self.out_keys: res[k] = self.key_to_storage[k][index] - res = self.maybe_remove_batch(res) + res = self._maybe_remove_batch(res) return res def __setitem__(self, item: TensorDictBase, value: TensorDictBase): - item, value = self.maybe_add_batch(item, value) + item, value = self._maybe_add_batch(item, value) - index = self.to_index(item) + index = self._to_index(item) for k in self.out_keys: self.key_to_storage[k][index] = value[k] @@ -404,9 +413,9 @@ def __len__(self): return len(next(iter(self.key_to_storage.values()))) def contain(self, item: TensorDictBase) -> torch.Tensor: - item, _ = self.maybe_add_batch(item, None) - index = self.to_index(item) + item, _ = self._maybe_add_batch(item, None) + index = self._to_index(item) res = next(iter(self.key_to_storage.values())).contain(index) - res = self.maybe_remove_batch(res) + res = self._maybe_remove_batch(res) return res From da61af843f56504aa2b99e61d44bd6f1f5e068f5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 26 Jun 2024 13:38:22 +0100 Subject: [PATCH 12/15] amend --- tensordict/nn/storage.py | 163 ++++++++++++++++++++++++++++++--------- test/test_storage.py | 6 +- 2 files changed, 130 insertions(+), 39 deletions(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index b4877fb40..8e2788426 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -41,19 +41,33 @@ def __len__(self) -> int: raise NotImplementedError @abstractmethod - def contain(self, item: K) -> torch.Tensor: + def contains(self, item: K) -> torch.Tensor: raise NotImplementedError + 
def __contains__(self, item): + return self.contains(item) + class DynamicStorage(TensorStorage[torch.Tensor, torch.Tensor]): """A Dynamic Tensor Storage. + Indices can be of any pytorch dtype. + This is a storage that save its tensors in cpu memories. It expands as necessary. + It is assumed that all values in the storage can be stacked together + using :func:`~torch.stack`. + + Args: + default_tensor (torch.Tensor): the default value to return when + an index cannot be found. This value will not be set in the + storage. + Examples: >>> storage = DynamicStorage(default_tensor=torch.zeros((1,))) >>> index = torch.randn((3,)) + >>> # set a value with a mismatching shape: it will be expanded to (3, 2, 1) shape >>> value = torch.rand((2, 1)) >>> storage[index] = value >>> assert len(storage) == 3 @@ -78,22 +92,27 @@ def __getitem__(self, indices: torch.Tensor) -> torch.Tensor: self._check_indices(indices) values: List[torch.Tensor] = [] for index in indices.tolist(): - value = self.tensor_dict.get(index) - if value is None: - value = self.default_tensor.clone() + value = self.tensor_dict.get(index, self.default_tensor) values.append(value) return torch.stack(values) def __setitem__(self, indices: torch.Tensor, values: torch.Tensor) -> None: self._check_indices(indices) + if not indices.ndim: + self.tensor_dict[indices.item()] = values + return + if not values.ndim: + values = values.expand(indices.shape[0]) + if values.shape[0] != indices.shape[0]: + values = values.expand(indices.shape[0], *values.shape) for index, value in zip(indices.tolist(), values.unbind(0)): self.tensor_dict[index] = value def __len__(self) -> None: return len(self.tensor_dict) - def contain(self, indices: torch.Tensor) -> torch.Tensor: + def contains(self, indices: torch.Tensor) -> torch.Tensor: self._check_indices(indices) res: List[bool] = [] for index in indices.tolist(): @@ -105,19 +124,27 @@ def contain(self, indices: torch.Tensor) -> torch.Tensor: class FixedStorage(nn.Module, TensorStorage[torch.Tensor, torch.Tensor]): """A Fixed Tensor Storage. + Indices must be of ``torch.long`` dtype. + This is storage that backed by nn.Embedding and hence can be in any device that nn.Embedding supports. The size of storage is fixed and cannot be extended. + Args: + embedding (torch.nn.Embedding): the embedding module, or equivalent. + init_fn (Callable[[torch.Tensor], torch.Tensor], optional): an init function + for the embedding weights. Defaults to + :func:`~torch.nn.init.normal_`, like `nn.Embedding`. + Examples: >>> embedding_storage = FixedStorage( ... torch.nn.Embedding(num_embeddings=10, embedding_dim=2), ... lambda x: torch.nn.init.constant_(x, 0), ... 
) - >>> index = torch.Tensor([1, 2]).long() + >>> index = torch.Tensor([1, 2], dtype=torch.long) >>> assert len(embedding_storage) == 0 >>> assert not (embedding_storage[index] == torch.ones(size=(2, 2))).all() >>> embedding_storage[index] = torch.ones(size=(2, 2)) - >>> assert torch.sum(embedding_storage.contain(index)).item() == 2 + >>> assert torch.sum(embedding_storage.contains(index)).item() == 2 >>> assert (embedding_storage[index] == torch.ones(size=(2, 2))).all() >>> assert len(embedding_storage) == 2 >>> embedding_storage.clear() @@ -126,12 +153,14 @@ class FixedStorage(nn.Module, TensorStorage[torch.Tensor, torch.Tensor]): """ def __init__( - self, embedding: nn.Embedding, init_fm: Callable[[torch.Tensor], torch.Tensor] + self, embedding: nn.Embedding, init_fm: Callable[[torch.Tensor], torch.Tensor]|None=None ): super().__init__() self.embedding = embedding self.num_embedding = embedding.num_embeddings self.flag = None + if init_fm is None: + init_fm = torch.nn.init.normal_ self.init_fm = init_fm self.clear() @@ -162,7 +191,7 @@ def __setitem__(self, item: torch.Tensor, value: torch.Tensor) -> None: def __len__(self) -> int: return torch.sum(self.flag).item() - def contain(self, item: torch.Tensor) -> torch.Tensor: + def contains(self, item: torch.Tensor) -> torch.Tensor: index = self._to_index(item) return self.flag[index] @@ -173,6 +202,18 @@ class BinaryToDecimal(torch.nn.Module): This is a utility class that allow to convert a binary encoding tensor (e.g. `1001`) to its decimal value (e.g. `9`) + Args: + num_bits (int): the number of bits to use for the bases table. + The number of bits must be lower or equal to the input length and the input length + must be divisible by ``num_bits``. If ``num_bits`` is lower than the number of + bits in the input, the end result will be aggregated on the last dimension using + :func:`~torch.sum`. + device (torch.device): the device where inputs and outputs are to be expected. + dtype (torch.dtype): the output dtype. + convert_to_binary (bool, optional): if ``True``, the input to the ``forward`` + method will be cast to a binary input using :func:`~torch.heavyside`. + Defaults to ``False``. + Examples: >>> binary_to_decimal = BinaryToDecimal( ... num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True @@ -188,7 +229,7 @@ def __init__( num_bits: int, device: torch.device, dtype: torch.dtype, - convert_to_binary: bool, + convert_to_binary: bool=False, ): super().__init__() self.convert_to_binary = convert_to_binary @@ -209,7 +250,7 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: else features ) feature_parts = binary_features.reshape(shape=(-1, self.num_bits)) - digits = torch.sum(self.bases * feature_parts, -1) + digits = torch.vmap(torch.dot, (None, 0))(self.bases, feature_parts.to(self.bases.dtype)) digits = digits.reshape(shape=(-1, features.shape[-1] // self.num_bits)) aggregated_digits = torch.sum(digits, dim=-1) return aggregated_digits @@ -220,20 +261,27 @@ class SipHash(torch.nn.Module): A hash function module based on SipHash implementation in python. + .. warning:: This module relies on the builtin ``hash`` function. + To get reproducible results across runs, the ``PYTHONHASHSEED`` environment + variable must be set before the code is run (changing this value during code + execution is without effect). 
+ Examples: - >>> from typing import cast - >>> a = torch.rand((3, 2)) + >>> # Assuming we set PYTHONHASHSEED=0 prior to running this code + >>> a = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) >>> b = a.clone() >>> hash_module = SipHash() - >>> hash_a = cast(torch.Tensor, hash_module(a)) - >>> hash_b = cast(torch.Tensor, hash_module(b)) + >>> hash_a = hash_module(a) + >>> hash_a + tensor([-4669941682990263259, -3778166555168484291, -9122128731510687521]) + >>> hash_b = hash_module(b) >>> assert (hash_a == hash_b).all() """ def forward(self, x: torch.Tensor) -> torch.Tensor: hash_values = [] - for x_i in torch.unbind(x): - hash_value = hash(x_i.detach().numpy().tobytes()) + for x_i in x.detach().cpu().view(torch.uint8).numpy(): + hash_value = hash(x_i.tobytes()) hash_values.append(hash_value) return torch.tensor(hash_values, dtype=torch.int64) @@ -245,6 +293,20 @@ class QueryModule(TensorDictModuleBase): A module that queries a storage and return required index of that storage. Currently, it only outputs integer indices (torch.int64). + Args: + in_keys (list of NestedKeys): keys of the input tensordict that + will be used to generate the hash value. + index_key (NestedKey): the output key where the hash value will be written. + + Keyword Args: + hash_module (nn.Module or Callable[[torch.Tensor], torch.Tensor]): a hash + module similar to :class:`~tensordict.nn.SipHash` (default). + aggregation_module (torch.nn.Module or Callable[[torch.Tensor], torch.Tensor]): a + method to aggregate the hash values. Defaults to the value of ``hash_module``. + If only one ``in_Keys`` is provided, this module will be ignored. + clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be + returned. Defaults to ``False``. + Examples: >>> query_module = QueryModule( ... in_keys=["key1", "key2"], @@ -255,52 +317,68 @@ class QueryModule(TensorDictModuleBase): ... { ... "key1": torch.Tensor([[1], [1], [1], [2]]), ... "key2": torch.Tensor([[3], [3], [2], [3]]), + ... "other": torch.randn(4), ... }, ... batch_size=(4,), ... ) >>> res = query_module(query) + >>> # The first two pairs of key1 and key2 match >>> assert res["index"][0] == res["index"][1] - >>> for i in range(1, 3): - >>> assert res["index"][i].item() != res["index"][i + 1].item(), ( - ... f"{i} = ({query[i]['key1']}, {query[i]['key2']}) s index and {i + 1} = ({query[i + 1]['key1']}, " - ... f"{query[i + 1]['key2']})'s index are the same!" - ... 
) + >>> # The last three pairs of key1 and key2 have at least one mismatching value + >>> assert res["index"][1] != res["index"][2] + >>> assert res["index"][2] != res["index"][3] """ def __init__( self, in_keys: List[NestedKey], index_key: NestedKey, - hash_module: torch.nn.Module, + *, + hash_module: torch.nn.Module | None = None, aggregation_module: torch.nn.Module | None = None, + clone: bool = False, ): self.in_keys = in_keys if isinstance(in_keys, List) else [in_keys] + if len(in_keys) == 0: + raise ValueError(f"`in_keys` cannot be empty.") self.out_keys = [index_key] super().__init__() + if hash_module is None: + hash_module = SipHash() + self.aggregation_module = ( aggregation_module if aggregation_module else hash_module ) self.hash_module = hash_module self.index_key = index_key + self.clone = clone def forward(self, tensordict: TensorDictBase) -> TensorDictBase: hash_values = [] - for k in self.in_keys: - hash_values.append(self.hash_module(tensordict[k])) + i = -1 # to make linter happy + for i, k in enumerate(self.in_keys): + hash_values.append(self.hash_module(tensordict.get(k))) - td_hash_value = self.aggregation_module( - torch.stack( - hash_values, - dim=-1, - ), - ) + if i > 0: + td_hash_value = self.aggregation_module( + torch.stack( + hash_values, + dim=-1, + ), + ) + else: + td_hash_value = hash_values[0] - output = tensordict.clone(False) - output[self.index_key] = td_hash_value + if self.clone: + output = tensordict.copy() + else: + output = tensordict + + output.set(self.index_key, td_hash_value) return output @@ -313,6 +391,13 @@ class TensorDictStorage( returns another tensordict as output similar to TensorDictModuleBase. However, it provides additional functionality like python map: + Args: + query_module (TensorDictModuleBase): a query module, typically an instance of + :class:`~tensordict.nn.QueryModule`, used to map a set of tensordict + entries to a hash key. + key_to_storage (Dict[NestedKey, TensorStorage[torch.Tensor, torch.Tensor]]): + a dictionary representing the map from an index key to a tensor storage. + Examples: >>> import torch >>> from tensordict import TensorDict @@ -341,7 +426,13 @@ class TensorDictStorage( ... 
) >>> tensor_dict_storage[index] = value >>> tensor_dict_storage[index] - >>> assert torch.sum(tensor_dict_storage.contain(index)).item() == 4 + TensorDict( + fields={ + index: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4]), + device=None, + is_shared=False) + >>> assert torch.sum(tensor_dict_storage.contains(index)).item() == 4 >>> new_index = index.clone(True) >>> new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) >>> retrieve_value = tensor_dict_storage[new_index] @@ -412,10 +503,10 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase): def __len__(self): return len(next(iter(self.key_to_storage.values()))) - def contain(self, item: TensorDictBase) -> torch.Tensor: + def contains(self, item: TensorDictBase) -> torch.Tensor: item, _ = self._maybe_add_batch(item, None) index = self._to_index(item) - res = next(iter(self.key_to_storage.values())).contain(index) + res = next(iter(self.key_to_storage.values())).contains(index) res = self._maybe_remove_batch(res) return res diff --git a/test/test_storage.py b/test/test_storage.py index e68fced38..a90ca3379 100644 --- a/test/test_storage.py +++ b/test/test_storage.py @@ -28,7 +28,7 @@ def test_embedding_memory(): assert not (embedding_storage[index] == torch.ones(size=(2, 2))).all() embedding_storage[index] = torch.ones(size=(2, 2)) - assert torch.sum(embedding_storage.contain(index)).item() == 2 + assert torch.sum(embedding_storage.contains(index)).item() == 2 assert (embedding_storage[index] == torch.ones(size=(2, 2))).all() @@ -120,7 +120,7 @@ def test_query_module(): ) tensor_dict_storage[index] = value - assert torch.sum(tensor_dict_storage.contain(index)).item() == 4 + assert torch.sum(tensor_dict_storage.contains(index)).item() == 4 new_index = index.clone(True) new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) @@ -156,7 +156,7 @@ def test_storage(): ) tensor_dict_storage[index] = value - assert torch.sum(tensor_dict_storage.contain(index)).item() == 4 + assert torch.sum(tensor_dict_storage.contains(index)).item() == 4 new_index = index.clone(True) new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) From cf901ead602c3e8ac89af19c72b9d45418530484 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 5 Jul 2024 07:43:10 +0100 Subject: [PATCH 13/15] amend --- tensordict/nn/storage.py | 57 +++++++++++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index 8e2788426..b4ffe3d2c 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -153,7 +153,9 @@ class FixedStorage(nn.Module, TensorStorage[torch.Tensor, torch.Tensor]): """ def __init__( - self, embedding: nn.Embedding, init_fm: Callable[[torch.Tensor], torch.Tensor]|None=None + self, + embedding: nn.Embedding, + init_fm: Callable[[torch.Tensor], torch.Tensor] | None = None, ): super().__init__() self.embedding = embedding @@ -229,7 +231,7 @@ def __init__( num_bits: int, device: torch.device, dtype: torch.dtype, - convert_to_binary: bool=False, + convert_to_binary: bool = False, ): super().__init__() self.convert_to_binary = convert_to_binary @@ -250,7 +252,9 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: else features ) feature_parts = binary_features.reshape(shape=(-1, self.num_bits)) - digits = torch.vmap(torch.dot, (None, 0))(self.bases, feature_parts.to(self.bases.dtype)) + digits = torch.vmap(torch.dot, (None, 0))( + self.bases, feature_parts.to(self.bases.dtype) + ) 
digits = digits.reshape(shape=(-1, features.shape[-1] // self.num_bits)) aggregated_digits = torch.sum(digits, dim=-1) return aggregated_digits @@ -280,13 +284,49 @@ class SipHash(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: hash_values = [] - for x_i in x.detach().cpu().view(torch.uint8).numpy(): + for x_i in x.detach().cpu().numpy(): hash_value = hash(x_i.tobytes()) hash_values.append(hash_value) return torch.tensor(hash_values, dtype=torch.int64) +class RandomProjectionHash(SipHash): + """A module that combines random projections with SipHash to get a low-dimensional tensor, easier to embed through SipHash. + + This module requires sklearn to be installed. + + """ + + def __init__( + self, + n_components=16, + projection_type: str = "gaussian", + dtype_cast=torch.float16, + **kwargs, + ): + super().__init__() + from sklearn.random_projection import ( + GaussianRandomProjection, + SparseRandomProjection, + ) + + self.dtype_cast = dtype_cast + if projection_type == "gaussian": + self.transform = GaussianRandomProjection( + n_components=n_components, **kwargs + ) + elif projection_type == "sparse_random": + self.transform = SparseRandomProjection(n_components=n_components, **kwargs) + else: + raise ValueError + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.transform.transform(x) + x = torch.as_tensor(x, dtype=self.dtype_cast) + return super().forward(x) + + class QueryModule(TensorDictModuleBase): """A Module to generate compatible indices for storage. @@ -340,7 +380,7 @@ def __init__( ): self.in_keys = in_keys if isinstance(in_keys, List) else [in_keys] if len(in_keys) == 0: - raise ValueError(f"`in_keys` cannot be empty.") + raise ValueError("`in_keys` cannot be empty.") self.out_keys = [index_key] super().__init__() @@ -360,7 +400,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: hash_values = [] i = -1 # to make linter happy - for i, k in enumerate(self.in_keys): + for k in self.in_keys: hash_values.append(self.hash_module(tensordict.get(k))) if i > 0: @@ -449,8 +489,6 @@ def __init__( super().__init__() - for k in self.out_keys: - assert k in key_to_storage, f"{k} has not been assigned to a memory" self.query_module = query_module self.index_key = query_module.index_key self.key_to_storage = key_to_storage @@ -488,7 +526,8 @@ def __getitem__(self, item: TensorDictBase) -> TensorDictBase: res = TensorDict({}, batch_size=item.batch_size) for k in self.out_keys: - res[k] = self.key_to_storage[k][index] + storage: FixedStorage = self.key_to_storage[k] + res[k] = storage[index] res = self._maybe_remove_batch(res) return res From 92e13b7bf1d91225bf8b637e355624315133b930 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 5 Jul 2024 14:16:30 +0100 Subject: [PATCH 14/15] amend --- tensordict/nn/storage.py | 145 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 139 insertions(+), 6 deletions(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index b4ffe3d2c..7d0eb43ca 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -121,6 +121,33 @@ def contains(self, indices: torch.Tensor) -> torch.Tensor: return torch.tensor(res, dtype=torch.bool) +class LazyDynamicStorage(DynamicStorage): + """A lazy version of DynamicStorage where the default tensor is assumed to be zeros_like(init_tensor). + + See :class:`~tensordict.nn.storage.DynamicStorage` for more info. 
+
+    """
+
+    def __init__(self, default_tensor: torch.Tensor | None = None) -> None:
+        if default_tensor is None:
+            self._init = False
+            default_tensor = torch.nn.UninitializedBuffer()
+        else:
+            self._init = True
+        super().__init__(default_tensor)
+
+    def __setitem__(self, indices: torch.Tensor, values: torch.Tensor) -> None:
+        if not self._init:
+            if len(indices):
+                val = values[0]
+            else:
+                val = values
+            self.default_tensor.materialize(
+                shape=val.shape, device=val.device, dtype=val.dtype
+            )
+            self._init = True
+        return super().__setitem__(indices, values)
+
+
 class FixedStorage(nn.Module, TensorStorage[torch.Tensor, torch.Tensor]):
     """A Fixed Tensor Storage.
 
@@ -296,32 +323,65 @@ class RandomProjectionHash(SipHash):
 
     This module requires sklearn to be installed.
 
+    Keyword Args:
+        n_components (int, optional): the number of components of the low-dimensional projection.
+            Defaults to 16.
+        projection_type (str, optional): the projection type to use.
+            Must be one of ``"gaussian"`` or ``"sparse_random"``. Defaults to ``"sparse_random"``.
+        dtype_cast (torch.dtype, optional): the dtype to cast the projection to.
+            Defaults to ``torch.float16``.
+        lazy (bool, optional): if ``True``, the random projection is fit on the first batch of data
+            received. Defaults to ``False``.
+
     """
 
+    _N_COMPONENTS_DEFAULT = 16
+
     def __init__(
         self,
-        n_components=16,
-        projection_type: str = "gaussian",
+        *,
+        n_components: int | None = None,
+        projection_type: str = "sparse_random",
         dtype_cast=torch.float16,
+        lazy: bool = False,
         **kwargs,
     ):
+        if n_components is None:
+            n_components = self._N_COMPONENTS_DEFAULT
+
         super().__init__()
         from sklearn.random_projection import (
             GaussianRandomProjection,
             SparseRandomProjection,
         )
 
+        self.lazy = lazy
+        self._init = not lazy
+
         self.dtype_cast = dtype_cast
-        if projection_type == "gaussian":
+        if projection_type.lower() == "gaussian":
             self.transform = GaussianRandomProjection(
                 n_components=n_components, **kwargs
             )
-        elif projection_type == "sparse_random":
+        elif projection_type.lower() in ("sparse_random", "sparse-random"):
             self.transform = SparseRandomProjection(n_components=n_components, **kwargs)
         else:
-            raise ValueError
+            raise ValueError(
+                f"Only 'gaussian' and 'sparse_random' projections are supported. Got projection_type={projection_type}."
+            )
+
+    def fit(self, x):
+        """Fits the random projection to the input data."""
+        self.transform.fit(x)
+        self._init = True
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.lazy and not self._init:
+            self.fit(x)
+        elif not self._init:
+            raise RuntimeError(
+                f"The {type(self).__name__} has not been initialized. Call fit before calling this method."
+            )
         x = self.transform.transform(x)
         x = torch.as_tensor(x, dtype=self.dtype_cast)
         return super().forward(x)
@@ -337,6 +397,7 @@ class QueryModule(TensorDictModuleBase):
         in_keys (list of NestedKeys): keys of the input tensordict that will be used to generate the hash value.
         index_key (NestedKey): the output key where the hash value will be written.
+            Defaults to ``"_index"``.

     Keyword Args:
         hash_module (nn.Module or Callable[[torch.Tensor], torch.Tensor]): a hash
@@ -372,7 +433,7 @@ class QueryModule(TensorDictModuleBase):
     def __init__(
         self,
         in_keys: List[NestedKey],
-        index_key: NestedKey,
+        index_key: NestedKey = "_index",
         *,
         hash_module: torch.nn.Module | None = None,
         aggregation_module: torch.nn.Module | None = None,
@@ -481,6 +542,7 @@ class TensorDictStorage(
 
     def __init__(
         self,
+        *,
         query_module: QueryModule,
         key_to_storage: Dict[NestedKey, TensorStorage[torch.Tensor, torch.Tensor]],
     ):
@@ -494,6 +556,77 @@ def __init__(
         self.key_to_storage = key_to_storage
         self.batch_added = False
 
+    @classmethod
+    def from_tensordict_pair(
+        cls,
+        source,
+        dest,
+        in_keys: List[NestedKey],
+        out_keys: List[NestedKey] | None = None,
+        storage_type: type = LazyDynamicStorage,
+        hash_module: Callable | None = None,
+    ):
+        """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
+
+        Args:
+            source (TensorDict): An example of a source tensordict, used as index in the storage.
+            dest (TensorDict): An example of a dest tensordict, used as data in the storage.
+            in_keys (List[NestedKey]): a list of keys to use in the map.
+            out_keys (List[NestedKey]): a list of keys to return in the output tensordict.
+                All keys absent from out_keys, even if present in ``dest``, will not be stored
+                in the storage. Defaults to ``None`` (all keys are registered).
+            storage_type (type, optional): a type of tensor storage.
+                Defaults to :class:`~tensordict.nn.storage.LazyDynamicStorage`.
+                Other options include :class:`~tensordict.nn.storage.FixedStorage`.
+            hash_module (Callable, optional): a hash function to use in the :class:`~tensordict.nn.storage.QueryModule`.
+                Defaults to :class:`~tensordict.nn.storage.SipHash` for low-dimensional inputs, and
+                :class:`~tensordict.nn.storage.RandomProjectionHash` for larger inputs.
+
+        Examples:
+            >>> # The following example requires torchrl and gymnasium to be installed
+            >>> from tensordict.nn.storage import TensorDictStorage, RandomProjectionHash
+            >>> from torchrl.envs import GymEnv
+            >>> env = GymEnv("CartPole-v1")
+            >>> rollout = env.rollout(100)
+            >>> source, dest = rollout.exclude("next"), rollout.get("next")
+            >>> storage = TensorDictStorage.from_tensordict_pair(
+            ...     source, dest,
+            ...     in_keys=["observation", "action"],
+            ... )
+            >>> # maps the (obs, action) tuple to a corresponding next state
+            >>> storage[source] = dest
+            >>> storage[source]
+            TensorDict(
+                fields={
+                    done: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    observation: Tensor(shape=torch.Size([35, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                    reward: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                    terminated: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([35]),
+                device=None,
+                is_shared=False)
+
+        """
+        # Build query module
+        if hash_module is None:
+            # Count the features; if there are more than RandomProjectionHash._N_COMPONENTS_DEFAULT,
+            # use that module to project them down to that dimensionality.
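+            # (Worked example, assuming CartPole-like shapes: "observation" contributes 4
+            # features and a one-hot "action" contributes 2, so n_feat = 6 <= 16 and the
+            # default SipHash is kept; a single 128-dimensional key would trigger
+            # RandomProjectionHash instead.)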
+ n_feat = 0 + for in_key in in_keys: + n_feat += source[in_key].shape[-1] + if n_feat > RandomProjectionHash._N_COMPONENTS_DEFAULT: + hash_module = RandomProjectionHash() + query_module = QueryModule(in_keys, hash_module=hash_module) + + # Build key_to_storage + if out_keys is None: + out_keys = list(dest.keys(True, True)) + key_to_storage = {} + for key in out_keys: + key_to_storage[key] = storage_type() + return cls(query_module=query_module, key_to_storage=key_to_storage) + def clear(self) -> None: for mem in self.key_to_storage.values(): mem.clear() From c3bf5acbc2b7a9b27ee42b201de87834d4b82c70 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Jul 2024 18:52:14 +0100 Subject: [PATCH 15/15] amend --- tensordict/nn/storage.py | 60 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 6 deletions(-) diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index 7d0eb43ca..f1c8ead52 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -179,6 +179,8 @@ class FixedStorage(nn.Module, TensorStorage[torch.Tensor, torch.Tensor]): >>> assert not (embedding_storage[index] == torch.ones(size=(2, 2))).all() """ + _initialized: bool + def __init__( self, embedding: nn.Embedding, @@ -186,7 +188,11 @@ def __init__( ): super().__init__() self.embedding = embedding - self.num_embedding = embedding.num_embeddings + if not hasattr(self, "num_embeddings"): + self.num_embeddings = embedding.num_embeddings + self._initialized = True + else: + self._initialized = False self.flag = None if init_fm is None: init_fm = torch.nn.init.normal_ @@ -194,19 +200,36 @@ def __init__( self.clear() def clear(self): - self.init_fm(self.embedding.weight) - self.flag = torch.zeros((self.embedding.num_embeddings, 1), dtype=torch.bool) + if self._initialized: + self.init_fm(self.embedding.weight) + self.flag = torch.zeros( + (self.embedding.num_embeddings, 1), dtype=torch.bool + ) + self._index_to_index = {} def _to_index(self, item: torch.Tensor) -> torch.Tensor: - return torch.remainder(item.to(torch.int64), self.num_embedding).to(torch.int64) + result = [] + for _item in item.tolist(): + result.append( + self._index_to_index.setdefault(_item, len(self._index_to_index)) + ) + return torch.tensor(result) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.embedding(self._to_index(x)) + def _init(self, value): + ... 
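+        # Note: `_init` is a lazy-initialization hook. Subclasses such as
+        # LazyFixedStorage (below) override it to build the embedding table from the
+        # first value they receive; for a regular FixedStorage it is a no-op.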
+
+    def __getitem__(self, item: torch.Tensor) -> torch.Tensor:
+        if not self._initialized:
+            raise RuntimeError("The module is not initialized yet.")
         return self.forward(item)
 
     def __setitem__(self, item: torch.Tensor, value: torch.Tensor) -> None:
+        if not self._initialized:
+            self._init(value)
+
         if value.shape[-1] != self.embedding.embedding_dim:
             raise ValueError(
                 "The shape value does not match with storage cell shape, "
@@ -220,11 +243,36 @@ def __setitem__(self, item: torch.Tensor, value: torch.Tensor) -> None:
     def __len__(self) -> int:
         return torch.sum(self.flag).item()
 
-    def contains(self, item: torch.Tensor) -> torch.Tensor:
-        index = self._to_index(item)
+    def contains(self, value: torch.Tensor) -> torch.Tensor:
+        index = self._to_index(value)
         return self.flag[index]
 
 
+class LazyFixedStorage(FixedStorage):
+    """A lazy version of FixedStorage whose embedding table is built from the first value it receives."""
+
+    # We don't really care about using UninitializedParameter here as these are not learnable params
+    embedding_constructor: type | Callable = torch.nn.Embedding
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        init_fm: Callable[[torch.Tensor], torch.Tensor] | None = None,
+    ) -> None:
+        self.num_embeddings = num_embeddings
+        self.flag = None
+        if init_fm is None:
+            init_fm = torch.nn.init.normal_
+        super().__init__(embedding=None, init_fm=init_fm)
+
+    def _init(self, val):
+        embedding_dim = val.shape[-1]
+        self.embedding = self.embedding_constructor(
+            embedding_dim=embedding_dim, num_embeddings=self.num_embeddings
+        ).to(val.dtype)
+        self._initialized = True
+        self.clear()
+
+
 class BinaryToDecimal(torch.nn.Module):
     """A Module to convert binaries encoded tensors to decimals.