diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 73cd447..bcaac53 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ --- repos: - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.22 + rev: v0.23 hooks: - id: validate-pyproject @@ -11,14 +11,14 @@ repos: - id: pyupgrade - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.2 + rev: v0.7.4 hooks: - id: ruff args: [--fix] - id: ruff-format - repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror - rev: 1.21.0 + rev: 1.21.1 hooks: - id: basedpyright diff --git a/config/_templates/dataset/carla.yaml b/config/_templates/dataset/carla.yaml index afd107d..5fbdb3d 100644 --- a/config/_templates/dataset/carla.yaml +++ b/config/_templates/dataset/carla.yaml @@ -108,9 +108,6 @@ inputs: #@ end sample_builder: - _target_: rbyte.sample.GreedySampleBuilder + _target_: rbyte.RollingWindowSampleBuilder index_column: _idx_ - length: 1 - stride: 1 - min_step: 1 - filter: !!null + period: 1i diff --git a/config/_templates/dataset/mimicgen.yaml b/config/_templates/dataset/mimicgen.yaml index 10cc0b4..fffc463 100644 --- a/config/_templates/dataset/mimicgen.yaml +++ b/config/_templates/dataset/mimicgen.yaml @@ -49,9 +49,6 @@ inputs: #@ end sample_builder: - _target_: rbyte.sample.GreedySampleBuilder + _target_: rbyte.RollingWindowSampleBuilder index_column: _idx_ - length: 1 - stride: 1 - min_step: 1 - filter: !!null + period: 1i diff --git a/config/_templates/dataset/nuscenes/mcap.yaml b/config/_templates/dataset/nuscenes/mcap.yaml index 4c15504..3097f15 100644 --- a/config/_templates/dataset/nuscenes/mcap.yaml +++ b/config/_templates/dataset/nuscenes/mcap.yaml @@ -49,11 +49,10 @@ inputs: fields: #@ for topic in camera_topics.values(): (@=topic@): + _idx_: log_time: _target_: polars.Datetime time_unit: ns - - _idx_: #@ end /odom: @@ -94,9 +93,6 @@ inputs: #@ end sample_builder: - _target_: rbyte.sample.GreedySampleBuilder + _target_: rbyte.RollingWindowSampleBuilder index_column: mcap/(@=camera_topics.values()[0]@)/_idx_ - length: 1 - stride: 1 - min_step: 1 - filter: !!null + period: 1i diff --git a/config/_templates/dataset/nuscenes/rrd.yaml b/config/_templates/dataset/nuscenes/rrd.yaml index 50a5dd5..c824b6e 100644 --- a/config/_templates/dataset/nuscenes/rrd.yaml +++ b/config/_templates/dataset/nuscenes/rrd.yaml @@ -84,10 +84,6 @@ inputs: #@ end sample_builder: - _target_: rbyte.sample.GreedySampleBuilder + _target_: rbyte.RollingWindowSampleBuilder index_column: rrd/(@=camera_entities.values()[0]@)/_idx_ - length: 1 - stride: 1 - min_step: 1 - filter: !!null - + period: 1i diff --git a/config/_templates/dataset/yaak.yaml b/config/_templates/dataset/yaak.yaml index ce1acbe..7cffab1 100644 --- a/config/_templates/dataset/yaak.yaml +++ b/config/_templates/dataset/yaak.yaml @@ -42,7 +42,7 @@ inputs: _target_: polars.Datetime time_unit: ns - frame_idx: polars.UInt32 + frame_idx: polars.Int32 camera_name: _target_: polars.Enum categories: @@ -130,10 +130,10 @@ inputs: #@ end sample_builder: - _target_: rbyte.sample.GreedySampleBuilder + _target_: rbyte.FixedWindowSampleBuilder index_column: meta/ImageMetadata.(@=cameras[0]@)/frame_idx - length: 1 - stride: 1 - min_step: 1 + every: 6i + period: 6i filter: | - array_mean(`meta/VehicleMotion/speed`) > 40 + array_length(`meta/ImageMetadata.(@=cameras[0]@)/time_stamp`) == 6 + and array_mean(`meta/VehicleMotion/speed`) > 40 diff --git a/config/_templates/dataset/zod.yaml b/config/_templates/dataset/zod.yaml index b1d050c..9e56288 100644 --- a/config/_templates/dataset/zod.yaml +++ b/config/_templates/dataset/zod.yaml @@ -9,7 +9,7 @@ inputs: index_column: camera_front_blur/timestamp source: _target_: rbyte.io.PathTensorSource - path: "${data_dir}/zod/sequences/000002_short/camera_front_blur/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.jpg" + path: "${data_dir}/sequences/000002_short/camera_front_blur/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.jpg" decoder: _target_: simplejpeg.decode_jpeg _partial_: true @@ -21,7 +21,7 @@ inputs: index_column: lidar_velodyne/timestamp source: _target_: rbyte.io.NumpyTensorSource - path: "${data_dir}/zod/sequences/000002_short/lidar_velodyne/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.npy" + path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.npy" select: ["x", "y", "z"] table_builder: @@ -29,7 +29,7 @@ inputs: _convert_: all readers: camera_front_blur: - path: "${data_dir}/zod/sequences/000002_short/camera_front_blur/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.jpg" + path: "${data_dir}/sequences/000002_short/camera_front_blur/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.jpg" reader: _target_: rbyte.io.PathTableReader _recursive_: false @@ -39,7 +39,7 @@ inputs: time_unit: ns lidar_velodyne: - path: "${data_dir}/zod/sequences/000002_short/lidar_velodyne/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.npy" + path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.npy" reader: _target_: rbyte.io.PathTableReader _recursive_: false @@ -49,7 +49,7 @@ inputs: time_unit: ns vehicle_data: - path: "${data_dir}/zod/sequences/000002_short/vehicle_data.hdf5" + path: "${data_dir}/sequences/000002_short/vehicle_data.hdf5" reader: _target_: rbyte.io.Hdf5TableReader _recursive_: false @@ -82,7 +82,7 @@ inputs: timestamp: method: asof strategy: nearest - tolerance: 50ms + tolerance: 100ms vehicle_data: ego_vehicle_controls: @@ -91,17 +91,17 @@ inputs: timestamp/nanoseconds/value: method: asof strategy: nearest - tolerance: 50ms + tolerance: 100ms acceleration_pedal/ratio/unitless/value: method: asof strategy: nearest - tolerance: 50ms + tolerance: 100ms steering_wheel_angle/angle/radians/value: method: asof strategy: nearest - tolerance: 50ms + tolerance: 100ms satellite: key: timestamp/nanoseconds/value @@ -110,5 +110,6 @@ inputs: method: interp sample_builder: - _target_: rbyte.sample.GreedySampleBuilder - length: 1 + _target_: rbyte.FixedWindowSampleBuilder + index_column: camera_front_blur/timestamp + every: 300ms diff --git a/justfile b/justfile index faf88e1..5efc81b 100644 --- a/justfile +++ b/justfile @@ -47,10 +47,10 @@ generate-config: --strict test *ARGS: generate-config - uv run pytest --capture=no {{ ARGS }} + uv run --all-extras pytest --capture=no {{ ARGS }} notebook FILE *ARGS: sync generate-config - uv run --with=jupyter,jupyterlab-vim,rerun-notebook jupyter lab {{ FILE }} {{ ARGS }} + uv run --all-extras --with=jupyter,jupyterlab-vim,rerun-notebook jupyter lab {{ FILE }} {{ ARGS }} [group('scripts')] visualize *ARGS: generate-config diff --git a/pyproject.toml b/pyproject.toml index 41a5158..26c1b9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rbyte" -version = "0.7.0" +version = "0.8.0" description = "Multimodal PyTorch dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] diff --git a/src/rbyte/__init__.py b/src/rbyte/__init__.py index fd81618..73f4243 100644 --- a/src/rbyte/__init__.py +++ b/src/rbyte/__init__.py @@ -1,7 +1,13 @@ from importlib.metadata import version from .dataset import Dataset +from .sample import FixedWindowSampleBuilder, RollingWindowSampleBuilder __version__ = version(__package__ or __name__) -__all__ = ["Dataset", "__version__"] +__all__ = [ + "Dataset", + "FixedWindowSampleBuilder", + "RollingWindowSampleBuilder", + "__version__", +] diff --git a/src/rbyte/dataset.py b/src/rbyte/dataset.py index ad90d21..97c106b 100644 --- a/src/rbyte/dataset.py +++ b/src/rbyte/dataset.py @@ -59,15 +59,15 @@ def __init__( super().__init__() _sample_builder = sample_builder.instantiate() - samples: Mapping[str, pl.LazyFrame] = {} + samples: Mapping[str, pl.DataFrame] = {} for input_id, input_cfg in inputs.items(): with bound_contextvars(input_id=input_id): - table = input_cfg.table_builder.instantiate().build().lazy() + table = input_cfg.table_builder.instantiate().build() samples[input_id] = _sample_builder.build(table) logger.debug( "built samples", - rows=table.select(pl.len()).collect().item(), - samples=samples[input_id].select(pl.len()).collect().item(), + rows=table.select(pl.len()).item(), + samples=samples[input_id].select(pl.len()).item(), ) input_id_enum = pl.Enum(sorted(samples)) @@ -85,12 +85,11 @@ def __init__( ) .sort(Column.input_id) .with_row_index(Column.sample_idx) - .collect() .rechunk() ) self._sources: pl.DataFrame = ( - pl.LazyFrame( + pl.DataFrame( [ { Column.input_id: input_id, @@ -112,7 +111,6 @@ def __init__( .explode(k) .unnest(k) .select(Column.input_id, pl.exclude(Column.input_id).name.prefix(f"{k}.")) - .collect() .rechunk() ) diff --git a/src/rbyte/sample/__init__.py b/src/rbyte/sample/__init__.py index d60e7c1..7f83bb6 100644 --- a/src/rbyte/sample/__init__.py +++ b/src/rbyte/sample/__init__.py @@ -1,3 +1,4 @@ -from .greedy_builder import GreedySampleBuilder +from .fixed_window import FixedWindowSampleBuilder +from .rolling_window import RollingWindowSampleBuilder -__all__ = ["GreedySampleBuilder"] +__all__ = ["FixedWindowSampleBuilder", "RollingWindowSampleBuilder"] diff --git a/src/rbyte/sample/base.py b/src/rbyte/sample/base.py index 67d3e07..eb44648 100644 --- a/src/rbyte/sample/base.py +++ b/src/rbyte/sample/base.py @@ -5,4 +5,4 @@ @runtime_checkable class SampleBuilder(Protocol): - def build(self, source: pl.LazyFrame) -> pl.LazyFrame: ... + def build(self, source: pl.DataFrame) -> pl.DataFrame: ... diff --git a/src/rbyte/sample/fixed_window.py b/src/rbyte/sample/fixed_window.py new file mode 100644 index 0000000..f82c3fc --- /dev/null +++ b/src/rbyte/sample/fixed_window.py @@ -0,0 +1,54 @@ +from datetime import timedelta +from typing import Literal, override +from uuid import uuid4 + +import polars as pl +from polars._typing import ClosedInterval +from pydantic import validate_call + +from .base import SampleBuilder + + +class FixedWindowSampleBuilder(SampleBuilder): + """ + Build samples using fixed (potentially overlapping) windows based on a temporal or + integer column. + + https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.group_by_dynamic + """ + + @validate_call + def __init__( + self, + *, + index_column: str, + every: str | timedelta, + period: str | timedelta | None = None, + closed: ClosedInterval = "left", + filter: str | None = None, # noqa: A002 + ) -> None: + self._index_column: pl.Expr = pl.col(index_column) + self._every: str | timedelta = every + self._period: str | timedelta | None = period + self._closed: ClosedInterval = closed + self._filter: str | Literal[True] = filter if filter is not None else True + + @override + def build(self, source: pl.DataFrame) -> pl.DataFrame: + return ( + source.sort(self._index_column) + .with_columns(self._index_column.alias(_index_column := uuid4().hex)) + .group_by_dynamic( + index_column=_index_column, + every=self._every, + period=self._period, + closed=self._closed, + label="datapoint", + start_by="datapoint", + ) + .agg(pl.all()) + .sql(f"select * from self where ({self._filter})") # noqa: S608 + .filter(self._index_column.list.len() > 0) + .sort(_index_column) + .drop(_index_column) + ) diff --git a/src/rbyte/sample/greedy_builder.py b/src/rbyte/sample/greedy_builder.py deleted file mode 100644 index eb4cc32..0000000 --- a/src/rbyte/sample/greedy_builder.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Annotated, override -from uuid import uuid4 - -import polars as pl -from pydantic import PositiveInt, StringConstraints, validate_call - -from .base import SampleBuilder - - -class GreedySampleBuilder(SampleBuilder): - @validate_call - def __init__( - self, - index_column: str | None = None, - length: PositiveInt = 1, - min_step: PositiveInt = 1, - stride: PositiveInt = 1, - filter: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)] # noqa: A002 - | None = None, - ) -> None: - super().__init__() - - self._index_column: str | None = index_column - self._length: int = length - self._min_step: int = min_step - self._stride: int = stride - self._filter: str | None = filter - - @override - def build(self, source: pl.LazyFrame) -> pl.LazyFrame: - if (idx_col := self._index_column) is None: - idx_col = uuid4().hex - source = source.with_row_index(idx_col) - - idx_dtype = source.select(idx_col).collect_schema()[idx_col] - - return ( - source.select( - pl.int_range( - pl.col(idx_col).min().fill_null(value=0), - pl.col(idx_col).max().fill_null(value=0) + 1, - step=self._min_step, - dtype=idx_dtype, # pyright: ignore[reportArgumentType] - ) - ) - .select( - pl.int_ranges( - pl.col(idx_col), - pl.col(idx_col) + self._length * self._stride, - self._stride, - dtype=idx_dtype, # pyright: ignore[reportArgumentType] - ) - ) - .with_row_index(sample_idx_col := uuid4().hex) - .explode(idx_col) - .join(source, on=idx_col, how="inner") - .group_by(sample_idx_col) - .all() - .filter(pl.col(idx_col).list.len() == self._length) - .sql(f"select * from self where ({self._filter or True})") # noqa: S608 - .sort(sample_idx_col) - .drop([sample_idx_col, *([idx_col] if self._index_column is None else [])]) - .select(pl.all().list.to_array(self._length)) - ) diff --git a/src/rbyte/sample/rolling_window.py b/src/rbyte/sample/rolling_window.py new file mode 100644 index 0000000..cdfee09 --- /dev/null +++ b/src/rbyte/sample/rolling_window.py @@ -0,0 +1,51 @@ +from datetime import timedelta +from typing import Literal, override +from uuid import uuid4 + +import polars as pl +from polars._typing import ClosedInterval +from pydantic import validate_call + +from .base import SampleBuilder + + +class RollingWindowSampleBuilder(SampleBuilder): + """ + Build samples using rolling windows based on a temporal or integer column. + + https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.rolling + """ + + @validate_call + def __init__( + self, + *, + index_column: str, + period: str | timedelta, + offset: str | timedelta | None = None, + closed: ClosedInterval = "right", + filter: str | None = None, # noqa: A002 + ) -> None: + self._index_column: pl.Expr = pl.col(index_column) + self._period: str | timedelta = period + self._offset: str | timedelta | None = offset + self._closed: ClosedInterval = closed + self._filter: str | Literal[True] = filter if filter is not None else True + + @override + def build(self, source: pl.DataFrame) -> pl.DataFrame: + return ( + source.sort(self._index_column) + .with_columns(self._index_column.alias(_index_column := uuid4().hex)) + .rolling( + index_column=_index_column, + period=self._period, + offset=self._offset, + closed=self._closed, + ) + .agg(pl.all()) + .sql(f"select * from self where ({self._filter})") # noqa: S608 + .filter(self._index_column.list.len() > 0) + .sort(_index_column) + .drop(_index_column) + ) diff --git a/tests/data/.gitattributes b/tests/data/.gitattributes new file mode 100644 index 0000000..9a2f202 --- /dev/null +++ b/tests/data/.gitattributes @@ -0,0 +1,8 @@ +*.hdf5 filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.log filter=lfs diff=lfs merge=lfs -text +*.mcap filter=lfs diff=lfs merge=lfs -text +*.md filter= diff= merge= text +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.rrd filter=lfs diff=lfs merge=lfs -text diff --git a/tests/data/mimicgen/.gitattributes b/tests/data/mimicgen/.gitattributes deleted file mode 100644 index 0820b3d..0000000 --- a/tests/data/mimicgen/.gitattributes +++ /dev/null @@ -1 +0,0 @@ -*.hdf5 filter=lfs diff=lfs merge=lfs -text diff --git a/tests/data/nuscenes/mcap/.gitattributes b/tests/data/nuscenes/mcap/.gitattributes deleted file mode 100644 index 91ccd51..0000000 --- a/tests/data/nuscenes/mcap/.gitattributes +++ /dev/null @@ -1 +0,0 @@ -*.mcap filter=lfs diff=lfs merge=lfs -text diff --git a/tests/data/nuscenes/rrd/.gitattributes b/tests/data/nuscenes/rrd/.gitattributes deleted file mode 100644 index ddfc47e..0000000 --- a/tests/data/nuscenes/rrd/.gitattributes +++ /dev/null @@ -1 +0,0 @@ -*.rrd filter=lfs diff=lfs merge=lfs -text diff --git a/tests/data/yaak/.gitattributes b/tests/data/yaak/.gitattributes deleted file mode 100644 index dc5353e..0000000 --- a/tests/data/yaak/.gitattributes +++ /dev/null @@ -1 +0,0 @@ -Niro098-HQ filter=lfs diff=lfs merge=lfs -text diff --git a/tests/data/zod/.gitattributes b/tests/data/zod/.gitattributes deleted file mode 100644 index 9d570c1..0000000 --- a/tests/data/zod/.gitattributes +++ /dev/null @@ -1 +0,0 @@ -sequences filter=lfs diff=lfs merge=lfs -text diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 6875172..545ce48 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -25,17 +25,15 @@ def test_mimicgen() -> None: dataloader = instantiate(cfg.dataloader) - c = SimpleNamespace( - B=cfg.dataloader.batch_size, S=cfg.dataloader.dataset.sample_builder.length - ) + c = SimpleNamespace(B=cfg.dataloader.batch_size) batch = next(iter(dataloader)) match batch.to_dict(): case { "data": { - "obs/agentview_image": Tensor(shape=[c.B, c.S, *_]), - "_idx_": Tensor(shape=[c.B, c.S]), - "obs/robot0_eef_pos": Tensor(shape=[c.B, c.S, *_]), + "obs/agentview_image": Tensor(shape=[c.B, _, *_]), + "_idx_": Tensor(shape=[c.B, _]), + "obs/robot0_eef_pos": Tensor(shape=[c.B, _, *_]), **data_rest, }, "meta": { @@ -73,28 +71,26 @@ def test_nuscenes_mcap() -> None: dataloader = instantiate(cfg.dataloader) - c = SimpleNamespace( - B=cfg.dataloader.batch_size, S=cfg.dataloader.dataset.sample_builder.length - ) + c = SimpleNamespace(B=cfg.dataloader.batch_size) batch = next(iter(dataloader)) match batch.to_dict(): case { "data": { - "CAM_FRONT": Tensor(shape=[c.B, c.S, *_]), - "CAM_FRONT_LEFT": Tensor(shape=[c.B, c.S, *_]), - "CAM_FRONT_RIGHT": Tensor(shape=[c.B, c.S, *_]), - "mcap//CAM_FRONT/image_rect_compressed/_idx_": Tensor(shape=[c.B, c.S]), + "CAM_FRONT": Tensor(shape=[c.B, _, *_]), + "CAM_FRONT_LEFT": Tensor(shape=[c.B, _, *_]), + "CAM_FRONT_RIGHT": Tensor(shape=[c.B, _, *_]), + "mcap//CAM_FRONT/image_rect_compressed/_idx_": Tensor(shape=[c.B, _]), "mcap//CAM_FRONT/image_rect_compressed/log_time": Tensor( - shape=[c.B, c.S] + shape=[c.B, _] ), "mcap//CAM_FRONT_LEFT/image_rect_compressed/_idx_": Tensor( - shape=[c.B, c.S] + shape=[c.B, _] ), "mcap//CAM_FRONT_RIGHT/image_rect_compressed/_idx_": Tensor( - shape=[c.B, c.S] + shape=[c.B, _] ), - "mcap//odom/vel.x": Tensor(shape=[c.B, c.S]), + "mcap//odom/vel.x": Tensor(shape=[c.B, _]), **data_rest, }, "meta": { @@ -132,29 +128,27 @@ def test_nuscenes_rrd() -> None: dataloader = instantiate(cfg.dataloader) - c = SimpleNamespace( - B=cfg.dataloader.batch_size, S=cfg.dataloader.dataset.sample_builder.length - ) + c = SimpleNamespace(B=cfg.dataloader.batch_size) batch = next(iter(dataloader)) match batch.to_dict(): case { "data": { - "CAM_FRONT": Tensor(shape=[c.B, c.S, *_]), - "CAM_FRONT_LEFT": Tensor(shape=[c.B, c.S, *_]), - "CAM_FRONT_RIGHT": Tensor(shape=[c.B, c.S, *_]), + "CAM_FRONT": Tensor(shape=[c.B, _, *_]), + "CAM_FRONT_LEFT": Tensor(shape=[c.B, _, *_]), + "CAM_FRONT_RIGHT": Tensor(shape=[c.B, _, *_]), "rrd//world/ego_vehicle/CAM_FRONT/timestamp": Tensor( - shape=[c.B, c.S, *_] + shape=[c.B, _, *_] ), - "rrd//world/ego_vehicle/CAM_FRONT/_idx_": Tensor(shape=[c.B, c.S, *_]), + "rrd//world/ego_vehicle/CAM_FRONT/_idx_": Tensor(shape=[c.B, _, *_]), "rrd//world/ego_vehicle/CAM_FRONT_LEFT/_idx_": Tensor( - shape=[c.B, c.S, *_] + shape=[c.B, _, *_] ), "rrd//world/ego_vehicle/CAM_FRONT_RIGHT/_idx_": Tensor( - shape=[c.B, c.S, *_] + shape=[c.B, _, *_] ), "rrd//world/ego_vehicle/LIDAR_TOP/Position3D": Tensor( - shape=[c.B, c.S, *_] + shape=[c.B, _, *_] ), **data_rest, }, @@ -193,31 +187,27 @@ def test_yaak() -> None: dataloader = instantiate(cfg.dataloader) - c = SimpleNamespace( - B=cfg.dataloader.batch_size, S=cfg.dataloader.dataset.sample_builder.length - ) + c = SimpleNamespace(B=cfg.dataloader.batch_size) batch = next(iter(dataloader)) match batch.to_dict(): case { "data": { - "cam_front_left": Tensor(shape=[c.B, c.S, *_]), - "cam_left_backward": Tensor(shape=[c.B, c.S, *_]), - "cam_right_backward": Tensor(shape=[c.B, c.S, *_]), - "meta/ImageMetadata.cam_front_left/frame_idx": Tensor(shape=[c.B, c.S]), - "meta/ImageMetadata.cam_front_left/time_stamp": Tensor( - shape=[c.B, c.S] - ), + "cam_front_left": Tensor(shape=[c.B, _, *_]), + "cam_left_backward": Tensor(shape=[c.B, _, *_]), + "cam_right_backward": Tensor(shape=[c.B, _, *_]), + "meta/ImageMetadata.cam_front_left/frame_idx": Tensor(shape=[c.B, _]), + "meta/ImageMetadata.cam_front_left/time_stamp": Tensor(shape=[c.B, _]), "meta/ImageMetadata.cam_left_backward/frame_idx": Tensor( - shape=[c.B, c.S] + shape=[c.B, _] ), "meta/ImageMetadata.cam_right_backward/frame_idx": Tensor( - shape=[c.B, c.S] + shape=[c.B, _] ), - "meta/VehicleMotion/gear": Tensor(shape=[c.B, c.S]), - "meta/VehicleMotion/speed": Tensor(shape=[c.B, c.S]), - "mcap//ai/safety_score/clip.end_timestamp": Tensor(shape=[c.B, c.S]), - "mcap//ai/safety_score/score": Tensor(shape=[c.B, c.S]), + "meta/VehicleMotion/gear": Tensor(shape=[c.B, _]), + "meta/VehicleMotion/speed": Tensor(shape=[c.B, _]), + "mcap//ai/safety_score/clip.end_timestamp": Tensor(shape=[c.B, _]), + "mcap//ai/safety_score/score": Tensor(shape=[c.B, _]), **data_rest, }, "meta": { @@ -246,34 +236,32 @@ def test_zod() -> None: with initialize(version_base=None, config_path=CONFIG_PATH): cfg = compose( "visualize", - overrides=["dataset=zod", "logger=rerun/zod", f"+data_dir={DATA_DIR}"], + overrides=["dataset=zod", "logger=rerun/zod", f"+data_dir={DATA_DIR}/zod"], ) dataloader = instantiate(cfg.dataloader) - c = SimpleNamespace( - B=cfg.dataloader.batch_size, S=cfg.dataloader.dataset.sample_builder.length - ) + c = SimpleNamespace(B=cfg.dataloader.batch_size) batch = next(iter(dataloader)) match batch.to_dict(): case { "data": { - "camera_front_blur": Tensor(shape=[c.B, c.S, *_]), - "camera_front_blur/timestamp": Tensor(shape=[c.B, c.S, *_]), - "lidar_velodyne": Tensor(shape=[c.B, c.S, *_]), - "lidar_velodyne/timestamp": Tensor(shape=[c.B, c.S, *_]), + "camera_front_blur": Tensor(shape=[c.B, _, *_]), + "camera_front_blur/timestamp": Tensor(shape=[c.B, _, *_]), + "lidar_velodyne": Tensor(shape=[c.B, _, *_]), + "lidar_velodyne/timestamp": Tensor(shape=[c.B, _, *_]), "vehicle_data/ego_vehicle_controls/acceleration_pedal/ratio/unitless/value": Tensor( # noqa: E501 - shape=[c.B, c.S, *_] + shape=[c.B, _, *_] ), "vehicle_data/ego_vehicle_controls/steering_wheel_angle/angle/radians/value": Tensor( # noqa: E501 - shape=[c.B, c.S, *_] + shape=[c.B, _, *_] ), "vehicle_data/ego_vehicle_controls/timestamp/nanoseconds/value": Tensor( - shape=[c.B, c.S, *_] + shape=[c.B, _, *_] ), "vehicle_data/satellite/speed/meters_per_second/value": Tensor( - shape=[c.B, c.S, *_] + shape=[c.B, _, *_] ), **data_rest, },