Skip to content

Commit

Permalink
chore: tensordict>=0.6.0 (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
egorchakov authored Oct 23, 2024
1 parent 5e30821 commit f4213e7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
[project]
name = "rbyte"
version = "0.5.1"
version = "0.5.2"
description = "Multimodal dataset library"
authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
dependencies = [
"tensordict @ git+https://github.com/pytorch/tensordict.git@85b6b81",
"torch>=2.4.1",
"tensordict>=0.6.0",
"torch",
"polars>=1.10.0",
"pydantic>=2.9.2",
"more-itertools>=10.5.0",
Expand Down
13 changes: 8 additions & 5 deletions src/rbyte/batch/batch.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from jaxtyping import Int
from tensordict import NonTensorData, TensorDict, tensorclass
from tensordict import (
NonTensorData,
TensorDict,
tensorclass, # pyright: ignore[reportUnknownVariableType]
)
from torch import Tensor


@tensorclass # pyright: ignore[reportUntypedClassDecorator, reportArgumentType, reportCallIssue]
@tensorclass(autocast=True) # pyright: ignore[reportUntypedClassDecorator]
class BatchMeta:
sample_idx: Int[Tensor, "b 1"] # pyright: ignore[reportUninitializedInstanceVariable]
sample_idx: Tensor # pyright: ignore[reportUninitializedInstanceVariable]
input_id: NonTensorData # pyright: ignore[reportUninitializedInstanceVariable]


@tensorclass # pyright: ignore[reportUntypedClassDecorator, reportArgumentType, reportCallIssue]
@tensorclass(autocast=True) # pyright: ignore[reportUntypedClassDecorator]
class Batch:
meta: BatchMeta # pyright: ignore[reportUninitializedInstanceVariable]
frame: TensorDict # pyright: ignore[reportUninitializedInstanceVariable]
Expand Down
4 changes: 4 additions & 0 deletions tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def test_mimicgen() -> None:
"sample_idx": Tensor(shape=[c.B]),
**meta_rest,
},
**batch_rest,
} if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any((
batch_rest,
frame_rest,
table_rest,
meta_rest,
Expand Down Expand Up @@ -96,7 +98,9 @@ def test_nuscenes() -> None:
"sample_idx": Tensor(shape=[c.B]),
**meta_rest,
},
**batch_rest,
} if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any((
batch_rest,
frame_rest,
table_rest,
meta_rest,
Expand Down

0 comments on commit f4213e7

Please sign in to comment.