diff --git a/pyproject.toml b/pyproject.toml index 49735f2..b81a354 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,12 @@ [project] name = "rbyte" -version = "0.5.1" +version = "0.5.2" description = "Multimodal dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] dependencies = [ - "tensordict @ git+https://github.com/pytorch/tensordict.git@85b6b81", - "torch>=2.4.1", + "tensordict>=0.6.0", + "torch", "polars>=1.10.0", "pydantic>=2.9.2", "more-itertools>=10.5.0", diff --git a/src/rbyte/batch/batch.py b/src/rbyte/batch/batch.py index 9592708..8065461 100644 --- a/src/rbyte/batch/batch.py +++ b/src/rbyte/batch/batch.py @@ -1,15 +1,18 @@ -from jaxtyping import Int -from tensordict import NonTensorData, TensorDict, tensorclass +from tensordict import ( + NonTensorData, + TensorDict, + tensorclass, # pyright: ignore[reportUnknownVariableType] +) from torch import Tensor -@tensorclass # pyright: ignore[reportUntypedClassDecorator, reportArgumentType, reportCallIssue] +@tensorclass(autocast=True) # pyright: ignore[reportUntypedClassDecorator] class BatchMeta: - sample_idx: Int[Tensor, "b 1"] # pyright: ignore[reportUninitializedInstanceVariable] + sample_idx: Tensor # pyright: ignore[reportUninitializedInstanceVariable] input_id: NonTensorData # pyright: ignore[reportUninitializedInstanceVariable] -@tensorclass # pyright: ignore[reportUntypedClassDecorator, reportArgumentType, reportCallIssue] +@tensorclass(autocast=True) # pyright: ignore[reportUntypedClassDecorator] class Batch: meta: BatchMeta # pyright: ignore[reportUninitializedInstanceVariable] frame: TensorDict # pyright: ignore[reportUninitializedInstanceVariable] diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index f54b1d4..f4ebe27 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -41,7 +41,9 @@ def test_mimicgen() -> None: "sample_idx": Tensor(shape=[c.B]), **meta_rest, }, + **batch_rest, } if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any(( + batch_rest, frame_rest, table_rest, meta_rest, @@ -96,7 +98,9 @@ def test_nuscenes() -> None: "sample_idx": Tensor(shape=[c.B]), **meta_rest, }, + **batch_rest, } if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any(( + batch_rest, frame_rest, table_rest, meta_rest,