Skip to content

Commit

Permalink
[BE] Check ordering and exclusivity of tensorclass registers
Browse files Browse the repository at this point in the history
ghstack-source-id: becd6b07c03eccaab2733e604b3dfb21ec05ebb6
Pull Request resolved: #1176
  • Loading branch information
vmoens committed Jan 9, 2025
1 parent 02ab260 commit 1001c18
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def test_tensorclass_stub_methods():
]

if missing_methods:
raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}")
raise Exception(
f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}"
)
Expand Down Expand Up @@ -150,6 +149,32 @@ class X:
)


def test_sorted_methods():
from tensordict.tensorclass import (
_FALLBACK_METHOD_FROM_TD,
_FALLBACK_METHOD_FROM_TD_FORCE,
_FALLBACK_METHOD_FROM_TD_NOWRAP,
_METHOD_FROM_TD,
)

lists_to_check = [
_FALLBACK_METHOD_FROM_TD_NOWRAP,
_METHOD_FROM_TD,
_FALLBACK_METHOD_FROM_TD_FORCE,
_FALLBACK_METHOD_FROM_TD,
]
# Check that each list is sorted and has unique elements
for lst in lists_to_check:
assert lst == sorted(lst), f"List {lst} is not sorted"
assert len(lst) == len(set(lst)), f"List {lst} has duplicate elements"
# Check that no two lists share any elements
for i, lst1 in enumerate(lists_to_check):
for j, lst2 in enumerate(lists_to_check):
if i != j:
shared_elements = set(lst1) & set(lst2)
assert (
not shared_elements
), f"Lists {lst1} and {lst2} share elements: {shared_elements}"


def _make_data(shape):
Expand Down

0 comments on commit 1001c18

Please sign in to comment.