[BugFix] Allowing for auto-nested tensordict #119
Conversation
This is looking good, thanks for working on it!
I've left a few comments inline.
The other main thing we should add is some tests. We can add a couple of tests for specific functionality, such as checking for the presence of "Auto-nested" in the repr, but as @vmoens suggested in #106 it would be nice to run all the tests on a tensordict with auto-nesting and check that we haven't inadvertently broken anything.
Check the class that gets used in the tests: you can add a method for `auto_nested_td` or similar, and then add it to the tests here. In principle it should work for anything that `nested_td` works for, but in any case it will be interesting to see what breaks!
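For illustration, such a repr test might look roughly like this (a sketch only: the import path and the exact "Auto-nested" marker are assumptions based on this thread, not the repo's actual test code):

```python
import torch
from tensordict import TensorDict  # import path assumed

def test_auto_nested_repr():
    td = TensorDict({"a": torch.randn(4, 3)}, batch_size=[4])
    td["self"] = td  # create the auto-nested reference
    # With this PR's behaviour, the repr should flag the cycle
    # rather than recursing forever.
    assert "Auto-nested" in repr(td)
```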
I added "Auto-nested" as you recommended, but because of the nature of auto-nesting I think we can't test it using existing test cases. For instance we have this test: def test_items_values_keys(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
keys = list(td.keys())
values = list(td.values())
items = list(td.items())
# Test td.items()
constructed_td1 = TensorDict({}, batch_size=td.shape)
for key, value in items:
constructed_td1.set(key, value)
> assert (td == constructed_td1).all() if we use this td: def auto_nested_td(self, device):
tensordict = TensorDict({
"a": torch.randn(4, 3, 2, 1, 5),
"b": torch.randn(4, 3, 2, 1, 10),
},
batch_size=[4, 3, 2, 1],
device=device,)
tensordict["self"] = tensordict
return tensordict when we can't construct |
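To spell out the problem: `constructed_td1` receives the same self-referencing "self" entry, so evaluating `td == constructed_td1` has to compare that entry, which again contains "self", and so on with no base case. A plain-dict sketch of the same failure mode (hypothetical code, not tensordict's implementation):

```python
import torch

def naive_eq(d1, d2):
    # Recursive equality with no cycle detection, mirroring the problem above.
    if d1.keys() != d2.keys():
        return False
    for key, value in d1.items():
        if isinstance(value, dict):
            if not naive_eq(value, d2[key]):  # never terminates on "self"
                return False
        elif not torch.equal(value, d2[key]):
            return False
    return True

d = {"a": torch.randn(3)}
d["self"] = d
# naive_eq(d, dict(d))  # would raise RecursionError
```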
After #121, it became easier to solve this auto-nesting bug, and I've updated the pull request. But it still has the problem with adding tests that I mentioned in #119 (comment).
Regarding the tests, I'm happy with these tests being independent of the others, although I wonder whether the fact that a call to `__eq__` is prohibited shouldn't be addressed by this PR.
TBH I did not quite get what the issue was. Can you elaborate a bit more?
You can also code it and push the code with the tests breaking; that would give me some visibility into what's happening.
tensordict/tensordict.py (outdated)

```
@@ -1830,10 +1844,14 @@ def flatten_keys(
    inner_tensordict = self.get(key).flatten_keys(
        separator=separator, inplace=inplace
    )
    for inner_key, inner_item in inner_tensordict.items():
        tensordict_out.set(separator.join([key, inner_key]), inner_item)
    if inner_tensordict is not self.get(key):
```
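For context, one way to express the idea behind this guard in plain Python (a hypothetical sketch, not tensordict's implementation) is to avoid recursing into a value that is the container itself:

```python
def flatten_keys(d, separator="."):
    # Flatten nested dicts into joined keys, leaving self-references intact.
    out = {}
    for key, value in d.items():
        if value is d:
            out[key] = value  # self-nested entry: don't recurse into it
        elif isinstance(value, dict):
            for inner_key, inner_value in flatten_keys(value, separator).items():
                out[separator.join([key, inner_key])] = inner_value
        else:
            out[key] = value
    return out
```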
I wonder if this is compatible with lazy tensordicts (e.g. `LazyStackedTensorDict`), where two calls to `get(key)` return items whose ids differ (e.g. because they are the results of a call to `stack`).
That's an interesting point. I didn't research `LazyStackedTensorDict`'s behaviour, but if `get(key)` returns a new id every time, how can we detect the self-loop?
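To make the concern concrete, here is a toy illustration (hypothetical classes, not tensordict's API) of how identity-based loop detection breaks down once `get(key)` returns a fresh object per call:

```python
import copy

class Node:
    def __init__(self):
        self.children = {}

    def get(self, key):
        return self.children[key]

class LazyNode(Node):
    def get(self, key):
        # Returns a new wrapper on each call, as a lazy stack might.
        return copy.copy(self.children[key])

n = Node()
n.children["self"] = n
assert n.get("self") is n       # identity check finds the loop

m = LazyNode()
m.children["self"] = m
assert m.get("self") is not m   # identity check misses the loop
```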
This issue was about infinite recursion during `print`, `flatten_keys`, and `list(keys())` (issue #106). Here is an example of code to verify this PR:

```python
tensordict = TensorDict({
    "key 1": torch.ones(3, 4, 5),
    "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
}, batch_size=[3, 4])
tensordict2 = TensorDict({
    "super key 1": torch.ones(3, 4, 6),
    "super key 2": torch.zeros(3, 4, 7, dtype=torch.bool),
}, batch_size=[3, 4])
tensordict["innerTD"] = tensordict2
tensordict2["self 2"] = tensordict2
tensordict["self"] = tensordict
tensordict["key 3"] = torch.zeros(3, 4, 7, dtype=torch.bool)

f_keys = tensordict.flatten_keys()  # Ok
print(f_keys)
td_list = list(tensordict.keys(include_nested=True))  # Ok
print(td_list)
print(tensordict)  # Ok
```
Thanks
@vmoens I've added tests for `keys()` and `repr()`. Unfortunately, I didn't manage to add tests via

```python
@pytest.mark.parametrize(
    "td_name",
    [
        ...,
        auto_nested_td,
    ],
)
```

because many other tests would break.
We can break the tests; I'd rather have that and solve the bug. Otherwise it's hard to see what the problem is.
I've added `auto_nested_td` into the tests.
Some benchmarking first:
On main: *(benchmark output not preserved)*
So we're good on that side. Functional: *(benchmark output not preserved)*
Other benchmarks seem to be roughly similar.
@Zooll I'm on it. Thanks for your hard work! I tried to solve the … I'll keep you posted!
Closing, as this has been addressed in #201 (Solving infinite recursion).
Description
Blocks self-nested behaviour from causing infinite recursion during `repr` or any other recursive calls.
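For context, the standard way to block that kind of recursion is to track visited objects by id and emit a placeholder (such as the "Auto-nested" marker discussed above) on the second visit. A plain-Python sketch of the pattern, not the code in this PR:

```python
def safe_repr(d, _visited=None):
    visited = set() if _visited is None else _visited
    if id(d) in visited:
        return "Auto-nested"  # placeholder instead of recursing again
    visited.add(id(d))
    entries = ", ".join(
        f"{key}: {safe_repr(value, visited) if isinstance(value, dict) else value!r}"
        for key, value in d.items()
    )
    return "{" + entries + "}"

d = {"a": 1}
d["self"] = d
print(safe_repr(d))  # {a: 1, self: Auto-nested}
```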
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
[BUG] Auto-nested tensordict bugs #106
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Bug fix (non-breaking change which fixes an issue)
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!