diff --git a/.github/unittest/linux/scripts/install.sh b/.github/unittest/linux/scripts/install.sh index 65904ffe1..07557287a 100755 --- a/.github/unittest/linux/scripts/install.sh +++ b/.github/unittest/linux/scripts/install.sh @@ -47,7 +47,10 @@ printf "* Installing tensordict\n" python setup.py develop # install torchsnapshot nightly -python -m pip install git+https://github.com/pytorch/torchsnapshot --no-build-isolation - +if [[ "$TORCH_VERSION" == "nightly" ]]; then + python -m pip install git+https://github.com/pytorch/torchsnapshot --no-build-isolation +elif [[ "$TORCH_VERSION" == "stable" ]]; then + python -m pip install torchsnapshot +fi # smoke test python -c "import functorch;import torchsnapshot" diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 7cfbe4105..cda4b72d8 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2283,8 +2283,8 @@ def test_as_tensor(self, td_name, device): ): td.as_tensor() else: - with pytest.raises(AttributeError): - td.as_tensor() + # checks that it runs + td.as_tensor() def test_items_values_keys(self, td_name, device): torch.manual_seed(1)