diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index af3802c..7cf807a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: - id: validate-pyproject - repo: https://github.com/crate-ci/typos - rev: v1.24.6 + rev: v1.26.0 hooks: - id: typos @@ -24,14 +24,14 @@ repos: exclude: examples/config - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.8 + rev: v0.6.9 hooks: - id: ruff args: [--fix] - id: ruff-format - repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror - rev: 1.18.0 + rev: 1.18.3 hooks: - id: basedpyright @@ -40,7 +40,7 @@ repos: - id: just-format name: just-format language: system - stages: [commit] + stages: [pre-commit] entry: just --fmt --unstable pass_filenames: false always_run: true @@ -48,7 +48,7 @@ repos: - id: generate-example-config name: generate-example-config language: system - stages: [commit] + stages: [pre-commit] entry: just generate-example-config pass_filenames: false always_run: true diff --git a/README.md b/README.md index 01f6e39..fc773dc 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,7 @@ dataset: NuScenes-v1.0-mini-scene-0103: frame: /CAM_FRONT/image_rect_compressed: - index_column: /CAM_FRONT/image_rect_compressed/idx + index_column: /CAM_FRONT/image_rect_compressed/_idx_ reader: _target_: rbyte.io.frame.mcap.McapFrameReader path: data/NuScenes-v1.0-mini-scene-0103.mcap @@ -121,7 +121,7 @@ dataset: frame_decoder: ${frame_decoder} /CAM_FRONT_LEFT/image_rect_compressed: - index_column: /CAM_FRONT_LEFT/image_rect_compressed/idx + index_column: /CAM_FRONT_LEFT/image_rect_compressed/_idx_ reader: _target_: rbyte.io.frame.mcap.McapFrameReader path: data/NuScenes-v1.0-mini-scene-0103.mcap @@ -130,7 +130,7 @@ dataset: frame_decoder: ${frame_decoder} /CAM_FRONT_RIGHT/image_rect_compressed: - index_column: /CAM_FRONT_RIGHT/image_rect_compressed/idx + index_column: /CAM_FRONT_RIGHT/image_rect_compressed/_idx_ reader: _target_: rbyte.io.frame.mcap.McapFrameReader path: data/NuScenes-v1.0-mini-scene-0103.mcap @@ -151,22 +151,22 @@ dataset: - rbyte.utils.mcap.McapJsonDecoderFactory fields: /CAM_FRONT/image_rect_compressed: + _idx_: log_time: _target_: polars.Datetime time_unit: ns - idx: null /CAM_FRONT_LEFT/image_rect_compressed: + _idx_: log_time: _target_: polars.Datetime time_unit: ns - idx: null /CAM_FRONT_RIGHT/image_rect_compressed: + _idx_: log_time: _target_: polars.Datetime time_unit: ns - idx: null /odom: log_time: @@ -175,7 +175,7 @@ dataset: vel.x: null merger: - _target_: rbyte.io.table.TableMerger + _target_: rbyte.io.table.TableAligner separator: / merge: /CAM_FRONT/image_rect_compressed: @@ -183,20 +183,20 @@ dataset: method: ref /CAM_FRONT_LEFT/image_rect_compressed: - log_time: - method: ref - idx: + _idx_: method: asof tolerance: 10ms strategy: nearest - - /CAM_FRONT_RIGHT/image_rect_compressed: log_time: method: ref - idx: + + /CAM_FRONT_RIGHT/image_rect_compressed: + _idx_: method: asof tolerance: 10ms strategy: nearest + log_time: + method: ref /odom: log_time: @@ -211,7 +211,7 @@ dataset: sample_builder: _target_: rbyte.sample.builder.GreedySampleTableBuilder - index_column: /CAM_FRONT/image_rect_compressed/idx + index_column: /CAM_FRONT/image_rect_compressed/_idx_ frame_decoder: _target_: simplejpeg.decode_jpeg @@ -252,11 +252,11 @@ Batch( is_shared=False), table=TensorDict( fields={ - /CAM_FRONT/image_rect_compressed/idx: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), + /CAM_FRONT/image_rect_compressed/_idx_: 
Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), /CAM_FRONT/image_rect_compressed/log_time: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), - /CAM_FRONT_LEFT/image_rect_compressed/idx: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), + /CAM_FRONT_LEFT/image_rect_compressed/_idx_: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), /CAM_FRONT_LEFT/image_rect_compressed/log_time: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), - /CAM_FRONT_RIGHT/image_rect_compressed/idx: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), + /CAM_FRONT_RIGHT/image_rect_compressed/_idx_: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), /CAM_FRONT_RIGHT/image_rect_compressed/log_time: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), /odom/vel.x: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float64, is_shared=False)}, batch_size=torch.Size([1]), @@ -273,24 +273,24 @@ logger: _target_: rbyte.viz.loggers.RerunLogger schema: frame: - /CAM_FRONT/image_rect_compressed: - rerun.components.ImageBufferBatch: + /CAM_FRONT/image_rect_compressed: + Image: color_model: RGB /CAM_FRONT_LEFT/image_rect_compressed: - rerun.components.ImageBufferBatch: + Image: color_model: RGB /CAM_FRONT_RIGHT/image_rect_compressed: - rerun.components.ImageBufferBatch: + Image: color_model: RGB table: - /CAM_FRONT/image_rect_compressed/log_time: rerun.TimeNanosColumn - /CAM_FRONT/image_rect_compressed/idx: rerun.TimeSequenceColumn - /CAM_FRONT_LEFT/image_rect_compressed/idx: rerun.TimeSequenceColumn - /CAM_FRONT_RIGHT/image_rect_compressed/idx: rerun.TimeSequenceColumn - /odom/vel.x: rerun.components.ScalarBatch + /CAM_FRONT/image_rect_compressed/log_time: TimeNanosColumn + /CAM_FRONT/image_rect_compressed/_idx_: TimeSequenceColumn + /CAM_FRONT_LEFT/image_rect_compressed/_idx_: TimeSequenceColumn + /CAM_FRONT_RIGHT/image_rect_compressed/_idx_: TimeSequenceColumn + /odom/vel.x: Scalar ``` Visualize the dataset: diff --git a/examples/config_templates/build_table.yaml b/examples/config_templates/build_table.yaml index 4c743ff..70aaac3 100644 --- a/examples/config_templates/build_table.yaml +++ b/examples/config_templates/build_table.yaml @@ -1,13 +1,10 @@ --- defaults: - table_builder: !!null + - table_writer: !!null - _self_ path: ??? -writer: - _target_: polars.DataFrame.write_csv - _partial_: true - file: ??? hydra: output_subdir: !!null diff --git a/examples/config_templates/dataset/hdf5.yaml b/examples/config_templates/dataset/hdf5.yaml new file mode 100644 index 0000000..64054f6 --- /dev/null +++ b/examples/config_templates/dataset/hdf5.yaml @@ -0,0 +1,64 @@ +#! 
https://sites.google.com/view/il-for-mm/datasets#h.cq0r3rd5nr9m + +#@yaml/map-key-override +#@yaml/text-templated-strings + +#@ inputs = { +#@ "table_setup_from_dishwasher_sample": [ +#@ "/data/demo_0", +#@ "/data/demo_1", +#@ "/data/demo_10", +#@ "/data/demo_101", +#@ "/data/demo_102", +#@ ] +#@ } + +#@ frame_keys = [ +#@ '/obs/rgb', +#@ '/obs/depth', +#@ ] +--- +_target_: rbyte.Dataset +_convert_: all +_recursive_: false +inputs: + #@ for input_id, input_keys in inputs.items(): + #@ for input_key in input_keys: + (@=input_id@)(@=input_key@): + frame: + #@ for frame_key in frame_keys: + (@=frame_key@): + index_column: _idx_ + reader: + _target_: rbyte.io.frame.hdf5.Hdf5FrameReader + path: "${data_dir}/(@=input_id@).hdf5" + key: (@=input_key@)(@=frame_key@) + #@ end + + table: + path: "${data_dir}/(@=input_id@).hdf5" + builder: + _target_: rbyte.io.table.TableBuilder + _convert_: all + reader: + _target_: rbyte.io.table.hdf5.Hdf5TableReader + _recursive_: false + fields: + /data/demo_0: + _idx_: + obs/object: + task_successes: + + merger: + _target_: rbyte.io.table.TableConcater + separator: "/" + #@ end + #@ end + +sample_builder: + _target_: rbyte.sample.builder.GreedySampleTableBuilder + index_column: _idx_ + length: 1 + stride: 1 + min_step: 1 + filter: !!null diff --git a/examples/config_templates/dataset/mcap.yaml b/examples/config_templates/dataset/mcap.yaml index 6631a66..fe798b0 100644 --- a/examples/config_templates/dataset/mcap.yaml +++ b/examples/config_templates/dataset/mcap.yaml @@ -21,7 +21,7 @@ inputs: frame: #@ for topic in camera_topics: (@=topic@): - index_column: (@=topic@)/idx + index_column: (@=topic@)/_idx_ reader: _target_: rbyte.io.frame.mcap.McapFrameReader path: "${data_dir}/(@=input_id@).mcap" @@ -54,7 +54,7 @@ inputs: _target_: polars.Datetime time_unit: ns - idx: + _idx_: #@ end /odom: @@ -64,8 +64,8 @@ inputs: vel.x: merger: - _target_: rbyte.io.table.TableMerger - separator: / + _target_: rbyte.io.table.TableAligner + separator: "/" merge: (@=camera_topics[0]@): log_time: @@ -75,7 +75,7 @@ inputs: (@=topic@): log_time: method: ref - idx: + _idx_: method: asof tolerance: 10ms strategy: nearest @@ -98,7 +98,7 @@ inputs: sample_builder: _target_: rbyte.sample.builder.GreedySampleTableBuilder - index_column: (@=camera_topics[0]@)/idx + index_column: (@=camera_topics[0]@)/_idx_ length: 1 stride: 1 min_step: 1 diff --git a/examples/config_templates/dataset/yaak.yaml b/examples/config_templates/dataset/yaak.yaml index 37ce805..e81168f 100644 --- a/examples/config_templates/dataset/yaak.yaml +++ b/examples/config_templates/dataset/yaak.yaml @@ -64,7 +64,7 @@ inputs: categories: ["0", "1", "2", "3"] merger: - _target_: rbyte.io.table.TableMerger + _target_: rbyte.io.table.TableAligner separator: "." merge: ImageMetadata.(@=cameras[0]@): diff --git a/examples/config_templates/frame_reader/hdf5.yaml b/examples/config_templates/frame_reader/hdf5.yaml new file mode 100644 index 0000000..1a0a769 --- /dev/null +++ b/examples/config_templates/frame_reader/hdf5.yaml @@ -0,0 +1,5 @@ +--- +_target_: rbyte.io.frame.Hdf5FrameReader +_recursive_: true +path: ??? +key: ??? 
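The `frame_reader/hdf5.yaml` template above only fills in constructor arguments; for orientation, here is a minimal usage sketch of the new `Hdf5FrameReader` against the `FrameReader` interface it implements. The HDF5 path and dataset key are placeholders taken from the `dataset/hdf5.yaml` template, not files shipped with the repo.

```python
# Minimal sketch of using the new Hdf5FrameReader directly; the path and
# key below are placeholders, not files shipped with the repo.
from rbyte.io.frame.hdf5 import Hdf5FrameReader

reader = Hdf5FrameReader(
    path="data/table_setup_from_dishwasher_sample.hdf5",  # placeholder
    key="/data/demo_0/obs/rgb",                           # placeholder
)

indexes = list(reader.get_available_indexes())  # range(len(dataset))
frames = reader.read(indexes[:4])               # UInt8 tensor, shape (b, h, w, c)
print(frames.shape, frames.dtype)
```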
diff --git a/examples/config_templates/logger/rerun/carla.yaml b/examples/config_templates/logger/rerun/carla.yaml index c4eec72..d02d6b9 100644 --- a/examples/config_templates/logger/rerun/carla.yaml +++ b/examples/config_templates/logger/rerun/carla.yaml @@ -12,14 +12,14 @@ schema: frame: #@ for camera in cameras: (@=camera@): - rerun.components.ImageBufferBatch: + Image: color_model: RGB #@ end table: - frame_idx: rerun.TimeSequenceColumn - control.brake: rerun.components.ScalarBatch - control.steer: rerun.components.ScalarBatch - control.throttle: rerun.components.ScalarBatch - state.acceleration.value: rerun.components.ScalarBatch - state.velocity.value: rerun.components.ScalarBatch + frame_idx: TimeSequenceColumn + control.brake: Scalar + control.steer: Scalar + control.throttle: Scalar + state.acceleration.value: Scalar + state.velocity.value: Scalar diff --git a/examples/config_templates/logger/rerun/hdf5.yaml b/examples/config_templates/logger/rerun/hdf5.yaml new file mode 100644 index 0000000..f287149 --- /dev/null +++ b/examples/config_templates/logger/rerun/hdf5.yaml @@ -0,0 +1,23 @@ +#@yaml/text-templated-strings + +#@ camera_topics = [ +#@ '/CAM_FRONT/image_rect_compressed', +#@ '/CAM_FRONT_LEFT/image_rect_compressed', +#@ '/CAM_FRONT_RIGHT/image_rect_compressed', +#@ ] +--- +_target_: rbyte.viz.loggers.RerunLogger +schema: + frame: + /obs/rgb: + Image: + color_model: RGB + + /obs/depth: + DepthImage: + color_model: L + + table: + _idx_: TimeSequenceColumn + obs/object: Points3D + task_successes: Scalar diff --git a/examples/config_templates/logger/rerun/mcap.yaml b/examples/config_templates/logger/rerun/mcap.yaml index 86b53f9..ee33859 100644 --- a/examples/config_templates/logger/rerun/mcap.yaml +++ b/examples/config_templates/logger/rerun/mcap.yaml @@ -9,15 +9,15 @@ _target_: rbyte.viz.loggers.RerunLogger schema: frame: - #@ for topic in camera_topics: - (@=topic@): - rerun.components.ImageBufferBatch: + #@ for camera_topic in camera_topics: + (@=camera_topic@): + Image: color_model: RGB #@ end table: - (@=camera_topics[0]@)/log_time: rerun.TimeNanosColumn - #@ for topic in camera_topics: - (@=topic@)/idx: rerun.TimeSequenceColumn + (@=camera_topics[0]@)/log_time: TimeNanosColumn + #@ for camera_topic in camera_topics: + (@=camera_topic@)/_idx_: TimeSequenceColumn #@ end - /odom/vel.x: rerun.components.ScalarBatch + /odom/vel.x: Scalar diff --git a/examples/config_templates/logger/rerun/yaak.yaml b/examples/config_templates/logger/rerun/yaak.yaml index e4a3d7d..4bba368 100644 --- a/examples/config_templates/logger/rerun/yaak.yaml +++ b/examples/config_templates/logger/rerun/yaak.yaml @@ -12,14 +12,14 @@ schema: frame: #@ for camera in cameras: (@=camera@): - rerun.components.ImageBufferBatch: + Image: color_model: RGB #@ end table: #@ for camera in cameras: - ImageMetadata.(@=camera@).frame_idx: rerun.TimeSequenceColumn - ImageMetadata.(@=camera@).time_stamp: rerun.TimeNanosColumn + ImageMetadata.(@=camera@).frame_idx: TimeSequenceColumn + ImageMetadata.(@=camera@).time_stamp: TimeNanosColumn #@ end - VehicleMotion.time_stamp: rerun.TimeNanosColumn - VehicleMotion.speed: rerun.components.ScalarBatch + VehicleMotion.time_stamp: TimeNanosColumn + VehicleMotion.speed: Scalar diff --git a/examples/config_templates/read_frames.yaml b/examples/config_templates/read_frames.yaml index 2c1ac51..660dd26 100644 --- a/examples/config_templates/read_frames.yaml +++ b/examples/config_templates/read_frames.yaml @@ -4,7 +4,11 @@ defaults: - _self_ batch_size: 1 +application_id: 
rbyte-read-frames entity_path: ??? +frame_config: + Image: + color_model: RGB hydra: output_subdir: !!null diff --git a/examples/config_templates/table_builder/hdf5.yaml b/examples/config_templates/table_builder/hdf5.yaml new file mode 100644 index 0000000..17ad2e5 --- /dev/null +++ b/examples/config_templates/table_builder/hdf5.yaml @@ -0,0 +1,23 @@ +--- +_target_: rbyte.io.table.TableBuilder +_convert_: all +reader: + _target_: rbyte.io.table.hdf5.Hdf5TableReader + _recursive_: false + fields: + /data/demo_0: + _idx_: + actions: + dones: + obs/gt_nav: + obs/object: + obs/proprio: + obs/proprio_nav: + obs/scan: + rewards: + states: + task_successes: + +merger: + _target_: rbyte.io.table.TableConcater + separator: "/" diff --git a/examples/config_templates/table_builder/mcap.yaml b/examples/config_templates/table_builder/mcap.yaml index 8bfc867..32c5cc3 100644 --- a/examples/config_templates/table_builder/mcap.yaml +++ b/examples/config_templates/table_builder/mcap.yaml @@ -21,7 +21,7 @@ reader: _target_: polars.Datetime time_unit: ns - idx: + _idx_: #@ end /odom: @@ -31,7 +31,7 @@ reader: vel.x: merger: - _target_: rbyte.io.table.TableMerger + _target_: rbyte.io.table.TableAligner separator: / merge: (@=camera_topics[0]@): @@ -42,7 +42,7 @@ merger: (@=topic@): log_time: method: ref - idx: + _idx_: method: asof tolerance: 10ms strategy: nearest diff --git a/examples/config_templates/table_builder/yaak.yaml b/examples/config_templates/table_builder/yaak.yaml index 516dce6..4e8c4ca 100644 --- a/examples/config_templates/table_builder/yaak.yaml +++ b/examples/config_templates/table_builder/yaak.yaml @@ -41,7 +41,7 @@ reader: categories: ["0", "1", "2", "3"] merger: - _target_: rbyte.io.table.TableMerger + _target_: rbyte.io.table.TableAligner separator: "." merge: ImageMetadata.(@=cameras[0]@): diff --git a/examples/config_templates/table_writer/csv.yaml b/examples/config_templates/table_writer/csv.yaml new file mode 100644 index 0000000..8aca7e2 --- /dev/null +++ b/examples/config_templates/table_writer/csv.yaml @@ -0,0 +1,4 @@ +--- +_target_: polars.DataFrame.write_csv +_partial_: true +file: ??? diff --git a/examples/config_templates/table_writer/parquet.yaml b/examples/config_templates/table_writer/parquet.yaml new file mode 100644 index 0000000..a58f047 --- /dev/null +++ b/examples/config_templates/table_writer/parquet.yaml @@ -0,0 +1,4 @@ +--- +_target_: polars.DataFrame.write_parquet +_partial_: true +file: ??? 
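The new `table_writer` configs are thin Hydra partials over the corresponding `polars.DataFrame` writers; roughly, what `rbyte-build-table` ends up calling is sketched below (the output filename is a placeholder).

```python
# Rough Python equivalent of table_writer/parquet.yaml after Hydra
# instantiation with _partial_: true; the output filename is a placeholder.
from functools import partial

import polars as pl

table_writer = partial(pl.DataFrame.write_parquet, file="table.parquet")

# rbyte.scripts.build_table then invokes the partial with the built table:
table = pl.DataFrame({"_idx_": [0, 1, 2], "/odom/vel.x": [0.1, 0.2, 0.3]})
table_writer(table)  # same as table.write_parquet(file="table.parquet")
```

Swapping in `csv.yaml` only changes the `_target_` to `polars.DataFrame.write_csv`.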
diff --git a/pyproject.toml b/pyproject.toml index 602f285..dcb780d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,11 +7,11 @@ maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] dependencies = [ "tensordict @ git+https://github.com/pytorch/tensordict.git@85b6b81", "torch>=2.4.1", - "polars>=1.8.2", + "polars>=1.9.0", "pydantic>=2.9.2", "more-itertools>=10.5.0", "hydra-core>=1.3.2", - "optree>=0.12.1", + "optree>=0.13.0", "cachetools>=5.5.0", "diskcache>=5.6.3", "jaxtyping>=0.2.34", @@ -42,9 +42,10 @@ mcap = [ "mcap-ros2-support>=0.5.3", "python-box>=7.2.0", ] -yaak = ["protobuf", "ptars>=0.0.2rc2"] +yaak = ["protobuf", "ptars>=0.0.2"] jpeg = ["simplejpeg>=1.7.6"] video = ["video-reader-rs>=0.1.4"] +hdf5 = ["h5py>=3.12.1"] [project.scripts] rbyte-build-table = 'rbyte.scripts.build_table:main' @@ -64,6 +65,8 @@ dev-dependencies = [ "wat-inspector>=0.4.0", "lovely-tensors>=0.1.17", "pudb>=2024.1.2", + "ipython>=8.28.0", + "ipython-autoimport>=0.5", ] [tool.uv.sources] diff --git a/src/rbyte/config/base.py b/src/rbyte/config/base.py index 589a12e..3b5f15a 100644 --- a/src/rbyte/config/base.py +++ b/src/rbyte/config/base.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TypeVar +from typing import Literal, TypeVar from hydra.utils import instantiate from pydantic import BaseModel as _BaseModel @@ -22,6 +22,11 @@ class HydraConfig[T](BaseModel): model_config = ConfigDict(extra="allow") target: ImportString[type[T]] = Field(alias="_target_") + recursive: bool = Field(alias="_recursive_", default=True) + convert: Literal["none", "partial", "object", "all"] = Field( + alias="_convert_", default="none" + ) + partial: bool = Field(alias="_partial_", default=False) def instantiate(self, **kwargs: object) -> T: return instantiate(self.model_dump(by_alias=True), **kwargs) diff --git a/src/rbyte/io/frame/__init__.py b/src/rbyte/io/frame/__init__.py index 14cb1ec..97a88a5 100644 --- a/src/rbyte/io/frame/__init__.py +++ b/src/rbyte/io/frame/__init__.py @@ -15,3 +15,10 @@ pass else: __all__ += ["McapFrameReader"] + +try: + from .hdf5 import Hdf5FrameReader +except ImportError: + pass +else: + __all__ += ["Hdf5FrameReader"] diff --git a/src/rbyte/io/frame/hdf5/__init__.py b/src/rbyte/io/frame/hdf5/__init__.py new file mode 100644 index 0000000..125cf3d --- /dev/null +++ b/src/rbyte/io/frame/hdf5/__init__.py @@ -0,0 +1,3 @@ +from .reader import Hdf5FrameReader + +__all__ = ["Hdf5FrameReader"] diff --git a/src/rbyte/io/frame/hdf5/reader.py b/src/rbyte/io/frame/hdf5/reader.py new file mode 100644 index 0000000..3613642 --- /dev/null +++ b/src/rbyte/io/frame/hdf5/reader.py @@ -0,0 +1,25 @@ +from collections.abc import Iterable, Sequence +from typing import cast, override + +import h5py +import torch +from jaxtyping import UInt8 +from pydantic import FilePath, validate_call +from torch import Tensor + +from rbyte.io.frame.base import FrameReader + + +class Hdf5FrameReader(FrameReader): + @validate_call + def __init__(self, path: FilePath, key: str) -> None: + file = h5py.File(path) + self._dataset = cast(h5py.Dataset, file[key]) + + @override + def read(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: + return torch.from_numpy(self._dataset[indexes]) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] + + @override + def get_available_indexes(self) -> Sequence[int]: + return range(len(self._dataset)) diff --git a/src/rbyte/io/table/__init__.py b/src/rbyte/io/table/__init__.py index 3c89ca5..e5d6276 100644 --- 
a/src/rbyte/io/table/__init__.py +++ b/src/rbyte/io/table/__init__.py @@ -1,4 +1,5 @@ +from .aligner import TableAligner from .builder import TableBuilder -from .merger import TableMerger +from .concater import TableConcater -__all__ = ["TableBuilder", "TableMerger"] +__all__ = ["TableAligner", "TableBuilder", "TableConcater"] diff --git a/src/rbyte/io/table/merger.py b/src/rbyte/io/table/aligner.py similarity index 98% rename from src/rbyte/io/table/merger.py rename to src/rbyte/io/table/aligner.py index 3a010c0..3b1eb79 100644 --- a/src/rbyte/io/table/merger.py +++ b/src/rbyte/io/table/aligner.py @@ -39,7 +39,7 @@ class InterpColumnMergeConfig(BaseModel): class Config(BaseModel): merge: OrderedDict[str, Mapping[str, MergeConfig]] - separator: Annotated[str, StringConstraints(strip_whitespace=True)] = "." + separator: Annotated[str, StringConstraints(strip_whitespace=True)] = "/" @model_validator(mode="after") def validate_refs(self) -> Self: @@ -72,7 +72,7 @@ def ref_columns(self) -> Mapping[str, str]: } -class TableMerger(TableMergerBase, Hashable): +class TableAligner(TableMergerBase, Hashable): def __init__(self, **kwargs: object) -> None: self._config = Config.model_validate(kwargs) diff --git a/src/rbyte/io/table/concater.py b/src/rbyte/io/table/concater.py new file mode 100644 index 0000000..62c2ca2 --- /dev/null +++ b/src/rbyte/io/table/concater.py @@ -0,0 +1,33 @@ +import json +from collections.abc import Hashable, Mapping +from typing import Annotated, override + +import polars as pl +from polars._typing import ConcatMethod +from pydantic import StringConstraints +from xxhash import xxh3_64_intdigest as digest + +from rbyte.config import BaseModel +from rbyte.io.table.base import TableMergerBase + + +class Config(BaseModel): + separator: Annotated[str, StringConstraints(strip_whitespace=True)] = "/" + method: ConcatMethod = "horizontal" + + +class TableConcater(TableMergerBase, Hashable): + def __init__(self, **kwargs: object) -> None: + self._config = Config.model_validate(kwargs) + + @override + def merge(self, src: Mapping[str, pl.DataFrame]) -> pl.DataFrame: + return pl.concat(src.values(), how=self._config.method) + + @override + def __hash__(self) -> int: + config = self._config.model_dump_json() + # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 + config_str = json.dumps(json.loads(config), sort_keys=True) + + return digest(config_str) diff --git a/src/rbyte/io/table/hdf5/__init__.py b/src/rbyte/io/table/hdf5/__init__.py new file mode 100644 index 0000000..ccff9e3 --- /dev/null +++ b/src/rbyte/io/table/hdf5/__init__.py @@ -0,0 +1,3 @@ +from .reader import Hdf5TableReader + +__all__ = ["Hdf5TableReader"] diff --git a/src/rbyte/io/table/hdf5/reader.py b/src/rbyte/io/table/hdf5/reader.py new file mode 100644 index 0000000..bf01ec0 --- /dev/null +++ b/src/rbyte/io/table/hdf5/reader.py @@ -0,0 +1,102 @@ +import json +from collections.abc import Hashable, Mapping +from enum import StrEnum, unique +from functools import cached_property +from os import PathLike +from typing import Any, cast, override + +import numpy.typing as npt +import polars as pl +from h5py import Dataset, File, Group +from polars._typing import PolarsDataType +from polars.datatypes import ( + DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 + DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 +) +from pydantic import ConfigDict, ImportString +from xxhash import xxh3_64_intdigest as digest + +from rbyte.config import BaseModel +from 
rbyte.config.base import HydraConfig +from rbyte.io.table.base import TableReaderBase + + +class Config(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + fields: Mapping[ + str, + Mapping[str, HydraConfig[PolarsDataType] | ImportString[PolarsDataType] | None], + ] + + +@unique +class SpecialFields(StrEnum): + idx = "_idx_" + + +class Hdf5TableReader(TableReaderBase, Hashable): + def __init__(self, **kwargs: object) -> None: + self._config = Config.model_validate(kwargs) + + @override + def __hash__(self) -> int: + config = self._config.model_dump_json() + # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 + config_str = json.dumps(json.loads(config), sort_keys=True) + + return digest(config_str) + + @override + def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: + dfs: Mapping[str, pl.DataFrame] = {} + + with File(path) as f: + for group_key, schema in self.schemas.items(): + match group := f[group_key]: + case Group(): + series: list[pl.Series] = [] + for name, dtype in schema.items(): + match name: + case SpecialFields.idx: + pass + + case _: + match dataset := group[name]: + case Dataset(): + values = cast(npt.NDArray[Any], dataset[:]) + series.append( + pl.Series( + name=name, + values=values, + dtype=dtype, + ) + ) + + case _: + raise NotImplementedError + + df = pl.DataFrame(data=series) # pyright: ignore[reportGeneralTypeIssues] + if (idx_name := SpecialFields.idx) in schema: + df = df.with_row_index(idx_name).cast({ + idx_name: schema[idx_name] or pl.UInt32 + }) + + dfs[group_key] = df + + case _: + raise NotImplementedError + + return dfs + + @cached_property + def schemas(self) -> Mapping[str, Mapping[str, PolarsDataType | None]]: + return { + group_key: { + dataset_key: leaf.instantiate() + if isinstance(leaf, HydraConfig) + else leaf + for dataset_key, leaf in fields.items() + } + for group_key, fields in self._config.fields.items() + } diff --git a/src/rbyte/io/table/mcap/reader.py b/src/rbyte/io/table/mcap/reader.py index f6f97e9..7d2acfc 100644 --- a/src/rbyte/io/table/mcap/reader.py +++ b/src/rbyte/io/table/mcap/reader.py @@ -12,6 +12,11 @@ import polars as pl from mcap.decoder import DecoderFactory from mcap.reader import DecodedMessageTuple, SeekingReader +from polars._typing import PolarsDataType +from polars.datatypes import ( + DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 + DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 +) from pydantic import ( ConfigDict, ImportString, @@ -30,9 +35,6 @@ logger = get_logger(__name__) -PolarsDataType = pl.DataType | pl.DataTypeClass - - class Config(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -64,7 +66,7 @@ class RowValues(NamedTuple): class SpecialFields(StrEnum): log_time = "log_time" publish_time = "publish_time" - idx = "idx" + idx = "_idx_" class McapTableReader(TableReaderBase, Hashable): @@ -150,7 +152,7 @@ def __hash__(self) -> int: return digest(config_str) @cached_property - def schemas(self) -> dict[str, dict[str, PolarsDataType | None]]: + def schemas(self) -> Mapping[str, Mapping[str, PolarsDataType | None]]: return { topic: { path: leaf.instantiate() if isinstance(leaf, HydraConfig) else leaf diff --git a/src/rbyte/io/table/yaak/reader.py b/src/rbyte/io/table/yaak/reader.py index 9bfa1bb..0605950 100644 --- a/src/rbyte/io/table/yaak/reader.py +++ b/src/rbyte/io/table/yaak/reader.py @@ -11,6 +11,11 @@ import polars as pl from google.protobuf.message import Message from optree import 
tree_map +from polars._typing import PolarsDataType +from polars.datatypes import ( + DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 + DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 +) from ptars import HandlerPool from pydantic import ConfigDict, ImportString from structlog import get_logger @@ -26,9 +31,6 @@ logger = get_logger(__name__) -PolarsDataType = pl.DataType | pl.DataTypeClass - - class Config(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/src/rbyte/scripts/build_table.py b/src/rbyte/scripts/build_table.py index c33a14c..df8b10a 100644 --- a/src/rbyte/scripts/build_table.py +++ b/src/rbyte/scripts/build_table.py @@ -14,10 +14,10 @@ @hydra.main(version_base=None) def main(config: DictConfig) -> None: table_builder = cast(TableBuilderBase, instantiate(config.table_builder)) - writer = cast(Callable[[Table], None], instantiate(config.writer)) + table_writer = cast(Callable[[Table], None], instantiate(config.table_writer)) table = table_builder.build(config.path) - return writer(table) + return table_writer(table) if __name__ == "__main__": diff --git a/src/rbyte/scripts/read_frames.py b/src/rbyte/scripts/read_frames.py index dec71c2..bce80ac 100644 --- a/src/rbyte/scripts/read_frames.py +++ b/src/rbyte/scripts/read_frames.py @@ -1,15 +1,19 @@ -from typing import cast +from typing import Any, cast import hydra import more_itertools as mit +import numpy as np +import numpy.typing as npt import rerun as rr -import torch from hydra.utils import instantiate from omegaconf import DictConfig +from pydantic import TypeAdapter from structlog import get_logger +from structlog.contextvars import bound_contextvars from tqdm import tqdm from rbyte.io.frame.base import FrameReader +from rbyte.viz.loggers.rerun_logger import FrameConfig logger = get_logger(__name__) @@ -17,44 +21,72 @@ @hydra.main(version_base=None) def main(config: DictConfig) -> None: frame_reader = cast(FrameReader, instantiate(config.frame_reader)) - rr.init("rbyte", spawn=True) + frame_config = cast( + FrameConfig, TypeAdapter(FrameConfig).validate_python(config.frame_config) + ) + + rr.init(config.application_id, spawn=True) for frame_indexes in mit.chunked( tqdm(sorted(frame_reader.get_available_indexes())), config.batch_size, strict=False, ): - frames = frame_reader.read(frame_indexes) - match frames.shape, frames.dtype: - case ((_, height, width, 3), torch.uint8): - rr.log( - config.entity_path, - [ - rr.components.ImageFormat( + with bound_contextvars(frame_config=frame_config): + match frame_config: + case {rr.Image: image_format} | {rr.DepthImage: image_format}: + arr = cast( + npt.NDArray[Any], + frame_reader.read(frame_indexes).cpu().numpy(), # pyright: ignore[reportUnknownMemberType] + ) + with bound_contextvars(image_format=image_format, shape=arr.shape): + match ( + image_format.pixel_format, + image_format.color_model, + arr.shape, + ): + case None, rr.ColorModel(), (batch, height, width, _): + pass + + case rr.PixelFormat.NV12, None, (batch, dim, width): + height = int(dim / 1.5) + + case _: + logger.error("not implemented") + + raise NotImplementedError + + image_format = rr.components.ImageFormat( height=height, width=width, - color_model="RGB", - channel_datatype="U8", - ), - rr.Image.indicator(), - ], - static=True, - strict=True, - ) - - rr.send_columns( - config.entity_path, - times=[rr.TimeSequenceColumn("frame_index", frame_indexes)], - components=[ - rr.components.ImageBufferBatch( - frames.flatten(start_dim=1, end_dim=-1).numpy() # 
pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] + pixel_format=image_format.pixel_format, + color_model=image_format.color_model, + channel_datatype=rr.ChannelDatatype.from_np_dtype( + arr.dtype + ), ) - ], - strict=True, - ) - case _: - raise NotImplementedError + components = [ + mit.one(frame_config).indicator(), + rr.components.ImageFormatBatch([image_format] * batch), + rr.components.ImageBufferBatch( + arr.reshape(batch, -1).view(np.uint8) + ), + ] + + case _: + logger.error("not implemented") + + raise NotImplementedError + + times = [rr.TimeSequenceColumn("frame_index", frame_indexes)] + + rr.send_columns( + entity_path=config.entity_path, + times=times, + components=components, # pyright: ignore[reportArgumentType] + strict=True, + ) if __name__ == "__main__": diff --git a/src/rbyte/utils/dataframe/misc.py b/src/rbyte/utils/dataframe/misc.py index d9fc8b2..f857257 100644 --- a/src/rbyte/utils/dataframe/misc.py +++ b/src/rbyte/utils/dataframe/misc.py @@ -1,16 +1,16 @@ from collections.abc import Generator, Mapping import polars as pl -import polars._typing as plt +from polars._typing import PolarsDataType # TODO: https://github.com/pola-rs/polars/issues/12353 # noqa: FIX002 def unnest_all( - schema: Mapping[str, plt.PolarsDataType], separator: str = "." + schema: Mapping[str, PolarsDataType], separator: str = "." ) -> Generator[pl.Expr]: def _unnest( - schema: Mapping[str, plt.PolarsDataType], path: tuple[str, ...] = () - ) -> Generator[tuple[tuple[str, ...], plt.PolarsDataType]]: + schema: Mapping[str, PolarsDataType], path: tuple[str, ...] = () + ) -> Generator[tuple[tuple[str, ...], PolarsDataType]]: for name, dtype in schema.items(): match dtype: case pl.Struct(): diff --git a/src/rbyte/viz/loggers/rerun_logger.py b/src/rbyte/viz/loggers/rerun_logger.py index ecda909..852d468 100644 --- a/src/rbyte/viz/loggers/rerun_logger.py +++ b/src/rbyte/viz/loggers/rerun_logger.py @@ -11,29 +11,37 @@ runtime_checkable, ) +import more_itertools as mit +import numpy as np import numpy.typing as npt import rerun as rr from pydantic import ( BeforeValidator, ConfigDict, + Field, ImportString, model_validator, validate_call, ) -from rerun._baseclasses import ComponentBatchMixin # noqa: PLC2701 +from pydantic.types import AnyType +from rerun._baseclasses import Archetype # noqa: PLC2701 from rerun._send_columns import TimeColumnLike # noqa: PLC2701 +from structlog import get_logger +from structlog.contextvars import bound_contextvars from rbyte.batch import Batch from rbyte.config import BaseModel from .base import Logger +logger = get_logger(__name__) + @runtime_checkable class TimeColumn(TimeColumnLike, Protocol): ... 
-class ImageFormatConfig(BaseModel): +class ImageFormat(BaseModel): pixel_format: ( Annotated[rr.PixelFormat, BeforeValidator(rr.PixelFormat.auto)] | None ) = None @@ -51,19 +59,23 @@ def validate_model(self: Self) -> Self: return self -TableSchema = ( - ImportString[type[TimeColumn]] | ImportString[type[rr.components.ScalarBatch]] -) -FrameSchema = Mapping[ - ImportString[type[rr.components.ImageBufferBatch]], ImageFormatConfig +RerunImportString = Annotated[ + ImportString[AnyType], + BeforeValidator(lambda x: f"rerun.{x}" if not x.startswith("rerun.") else x), +] + +FrameConfig = Annotated[ + Mapping[RerunImportString[type[Archetype]], ImageFormat], Field(max_length=1) ] +TableConfig = RerunImportString[type[TimeColumn | Archetype]] + class Schema(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - frame: Mapping[str, FrameSchema] - table: Mapping[str, TableSchema] + frame: Mapping[str, FrameConfig] = Field(default_factory=dict) + table: Mapping[str, TableConfig] = Field(default_factory=dict) @cached_property def times(self) -> Mapping[tuple[Literal["table"], str], TimeColumn]: @@ -72,13 +84,9 @@ def times(self) -> Mapping[tuple[Literal["table"], str], TimeColumn]: } @cached_property - def components( - self, - ) -> Mapping[tuple[str, str], FrameSchema | type[ComponentBatchMixin]]: + def components(self) -> Mapping[tuple[str, str], FrameConfig | type[Archetype]]: return {("frame", k): v for k, v in self.frame.items()} | { - ("table", k): v - for k, v in self.table.items() - if issubclass(v, ComponentBatchMixin) + ("table", k): v for k, v in self.table.items() if issubclass(v, Archetype) } @@ -90,26 +98,17 @@ def __init__(self, schema: Schema) -> None: self._schema = schema @cache # noqa: B019 - def _get_recording(self, *, application_id: str) -> rr.RecordingStream: - with rr.new_recording( + def _get_recording(self, *, application_id: str) -> rr.RecordingStream: # noqa: PLR6301 + return rr.new_recording( application_id=application_id, spawn=True, make_default=True - ) as recording: - for k in self._schema.frame: - rr.log( - entity_path=f"frame/{k}", - entity=[rr.Image.indicator()], - static=True, - strict=True, - ) - - return recording + ) @override def log(self, batch_idx: int, batch: Batch) -> None: # NOTE: zip because batch.meta.input_id is NonTensorData and isn't indexed for input_id, sample in zip( # pyright: ignore[reportUnknownVariableType] - batch.get(k := ("meta", "input_id")), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportAttributeAccessIssue] - batch.exclude(k), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportAttributeAccessIssue] + batch.get(path := ("meta", "input_id")), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportAttributeAccessIssue] + batch.exclude(path), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportAttributeAccessIssue] strict=True, ): with self._get_recording(application_id=input_id): # pyright: ignore[reportUnknownArgumentType] @@ -118,51 +117,88 @@ def log(self, batch_idx: int, batch: Batch) -> None: for k, v in self._schema.times.items() ] - for k, v in self._schema.components.items(): - tensor = cast(npt.NDArray[Any], sample.get(k).cpu().numpy()) # pyright: ignore[reportUnknownMemberType] - match v: - case rr.components.ScalarBatch: - components = [v(tensor)] - - case { - rr.components.ImageBufferBatch: ImageFormatConfig( - pixel_format=pixel_format, color_model=color_model - ) - }: - match (pixel_format, color_model, 
tensor.shape): - case None, rr.ColorModel.RGB, ( - (b, h, w, 3) | (b, 3, h, w) + for path, schema in self._schema.components.items(): + with bound_contextvars(path=path, schema=schema): + arr = cast(npt.NDArray[Any], sample.get(path).cpu().numpy()) # pyright: ignore[reportUnknownMemberType] + match schema: + case rr.Scalar: + components = [ + schema.indicator(), + rr.components.ScalarBatch(arr), + ] + + case rr.Points3D: + components = [ + schema.indicator(), + rr.components.Position3DBatch(arr).partition( + arr.shape[0] + ), + ] + + case rr.Tensor: + components = [ + schema.indicator(), + rr.components.TensorDataBatch(arr), + ] + + case {rr.Image: image_format} | { + rr.DepthImage: image_format + }: + with bound_contextvars( + image_format=image_format, shape=arr.shape ): - image_format = rr.components.ImageFormat( - width=w, - height=h, - color_model=color_model, - channel_datatype=rr.ChannelDatatype.from_np_dtype( - tensor.dtype - ), - ) - - case rr.PixelFormat.NV12, None, (b, dim, w): - image_format = rr.components.ImageFormat( - width=w, - height=int(dim / 1.5), - pixel_format=pixel_format, - ) - - case _: - raise NotImplementedError - - components = [ - rr.components.ImageFormatBatch([image_format] * b), - rr.components.ImageBufferBatch(tensor.reshape(b, -1)), - ] - - case _: - raise NotImplementedError - - rr.send_columns( - entity_path="/".join(k), - times=times, - components=components, - strict=True, - ) + match ( + image_format.pixel_format, + image_format.color_model, + arr.shape, + ): + case None, rr.ColorModel(), ( + _batch, + height, + width, + _, + ): + pass + + case rr.PixelFormat.NV12, None, ( + _batch, + dim, + width, + ): + height = int(dim / 1.5) + + case _: + logger.error("not implemented") + + raise NotImplementedError + + image_format = rr.components.ImageFormat( + height=height, + width=width, + pixel_format=image_format.pixel_format, + color_model=image_format.color_model, + channel_datatype=rr.ChannelDatatype.from_np_dtype( + arr.dtype + ), + ) + components = [ + mit.one(schema).indicator(), + rr.components.ImageFormatBatch( + [image_format] * _batch + ), + rr.components.ImageBufferBatch( + arr.reshape(_batch, -1).view(np.uint8) + ), + ] + + case _: + logger.error("not implemented") + + raise NotImplementedError + + rr.send_columns( + entity_path="/".join(path), + times=times, + components=components, # pyright: ignore[reportArgumentType] + strict=True, + )
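For context on the `RerunLogger` schema change: archetype and column names can now be written without the `rerun.` prefix, since `RerunImportString` prepends it before import. A hedged sketch of how the README schema validates, with entity paths taken from the README example (`Schema` is the internal pydantic model in `rerun_logger.py`):

```python
# Sketch of how the shorthand schema names resolve; "Image", "Scalar" and
# "TimeSequenceColumn" are expanded to rerun.Image, rerun.Scalar and
# rerun.TimeSequenceColumn by the RerunImportString BeforeValidator.
from rbyte.viz.loggers.rerun_logger import Schema

schema = Schema.model_validate({
    "frame": {
        "/CAM_FRONT/image_rect_compressed": {"Image": {"color_model": "RGB"}}
    },
    "table": {
        "/CAM_FRONT/image_rect_compressed/_idx_": "TimeSequenceColumn",
        "/odom/vel.x": "Scalar",
    },
})

# TimeSequenceColumn entries end up in Schema.times; Image and Scalar
# entries end up in Schema.components, keyed by ("frame"/"table", path).
```

Both `frame` and `table` now default to empty mappings, so a frame-only or table-only schema is also valid.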