diff --git a/.gitignore b/.gitignore index 06f44dd..de04814 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,5 @@ examples/config # lockfiles uv.lock + +.basedpyright diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b227e2a..af3802c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ fail_fast: true repos: - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.19 + rev: v0.20.2 hooks: - id: validate-pyproject @@ -24,14 +24,14 @@ repos: exclude: examples/config - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.7 + rev: v0.6.8 hooks: - id: ruff args: [--fix] - id: ruff-format - repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror - rev: 1.17.5 + rev: 1.18.0 hooks: - id: basedpyright diff --git a/README.md b/README.md index 869874c..36b8dc7 100644 --- a/README.md +++ b/README.md @@ -25,18 +25,18 @@ See [examples/config_templates](examples/config_templates) ([`ytt`](https://carv
nuScenes x mcap -1. Setup a new project with [`uv`](https://docs.astral.sh/uv/) +Setup a new project with [`uv`](https://docs.astral.sh/uv/): ```shell uv init nuscenes_mcap cd nuscenes_mcap uv add hydra-core omegaconf -uv add https://github.com/yaak-ai/rbyte/releases/latest/download/rbyte-0.2.0-py3-none-any.whl --extra mcap --extra jpeg --extra visualize +uv add https://github.com/yaak-ai/rbyte/releases/latest/download/rbyte-0.3.0-py3-none-any.whl --extra mcap --extra jpeg --extra visualize mkdir data ``` -2. Follow the guide at [foxglove/nuscenes2mcap](https://github.com/foxglove/nuscenes2mcap) and move the resulting `.mcap` files under `data/`. In this example we're using a subset of topics from `NuScenes-v1.0-mini-scene-0103.mcap`: +Follow the guide at [foxglove/nuscenes2mcap](https://github.com/foxglove/nuscenes2mcap) and move the resulting `.mcap` files under `data/`. In this example we're using a subset of topics from `NuScenes-v1.0-mini-scene-0103.mcap`: ```shell mcap info data/NuScenes-v1.0-mini-scene-0103.mcap library: nuscenes2mcap @@ -93,13 +93,13 @@ attachments: 0 metadata: 1 ``` -3. Create a `config.yaml` with the following: +Create a `config.yaml` to extract frames from three cameras + velocity, aligning everything to the first camera's timestamp: ```yaml --- dataloader: _target_: torch.utils.data.DataLoader dataset: ${dataset} - batch_size: 32 + batch_size: 1 collate_fn: _target_: rbyte.utils.dataloader.collate_identity _partial_: true @@ -111,50 +111,68 @@ dataset: inputs: NuScenes-v1.0-mini-scene-0103: frame: - CAM_FRONT: - index_column: /CAM_FRONT/image_rect_compressed/frame_idx + /CAM_FRONT/image_rect_compressed: + index_column: /CAM_FRONT/image_rect_compressed/idx reader: _target_: rbyte.io.frame.mcap.McapFrameReader path: data/NuScenes-v1.0-mini-scene-0103.mcap topic: /CAM_FRONT/image_rect_compressed - message_decoder_factory: ${message_decoder_factory} - frame_decoder: ${jpeg_decoder} + decoder_factory: mcap_protobuf.decoder.DecoderFactory + frame_decoder: ${frame_decoder} - CAM_FRONT_LEFT: - index_column: /CAM_FRONT_LEFT/image_rect_compressed/frame_idx + /CAM_FRONT_LEFT/image_rect_compressed: + index_column: /CAM_FRONT_LEFT/image_rect_compressed/idx reader: _target_: rbyte.io.frame.mcap.McapFrameReader path: data/NuScenes-v1.0-mini-scene-0103.mcap topic: /CAM_FRONT_LEFT/image_rect_compressed - message_decoder_factory: ${message_decoder_factory} - frame_decoder: ${jpeg_decoder} + decoder_factory: mcap_protobuf.decoder.DecoderFactory + frame_decoder: ${frame_decoder} + + /CAM_FRONT_RIGHT/image_rect_compressed: + index_column: /CAM_FRONT_RIGHT/image_rect_compressed/idx + reader: + _target_: rbyte.io.frame.mcap.McapFrameReader + path: data/NuScenes-v1.0-mini-scene-0103.mcap + topic: /CAM_FRONT_RIGHT/image_rect_compressed + decoder_factory: mcap_protobuf.decoder.DecoderFactory + frame_decoder: ${frame_decoder} table: path: data/NuScenes-v1.0-mini-scene-0103.mcap builder: _target_: rbyte.io.table.TableBuilder + _convert_: all reader: - _target_: rbyte.io.table.mcap.McapProtobufTableReader + _target_: rbyte.io.table.mcap.McapTableReader _recursive_: false - _convert_: all + decoder_factories: + - mcap_protobuf.decoder.DecoderFactory + - rbyte.utils.mcap.McapJsonDecoderFactory fields: /CAM_FRONT/image_rect_compressed: log_time: _target_: polars.Datetime time_unit: ns + idx: null /CAM_FRONT_LEFT/image_rect_compressed: log_time: _target_: polars.Datetime time_unit: ns + idx: null - /gps: + /CAM_FRONT_RIGHT/image_rect_compressed: log_time: _target_: polars.Datetime time_unit: ns + idx: null - latitude: polars.Float64 - longitude: polars.Float64 + /odom: + log_time: + _target_: polars.Datetime + time_unit: ns + vel.x: null merger: _target_: rbyte.io.table.TableMerger @@ -167,40 +185,43 @@ dataset: /CAM_FRONT_LEFT/image_rect_compressed: log_time: method: ref - frame_idx: + idx: method: asof - tolerance: 100ms + tolerance: 10ms strategy: nearest - /gps: + /CAM_FRONT_RIGHT/image_rect_compressed: log_time: method: ref - latitude: - method: asof - tolerance: 1000ms - strategy: nearest - longitude: + idx: method: asof - tolerance: 1000ms + tolerance: 10ms strategy: nearest + /odom: + log_time: + method: ref + vel.x: + method: interp + + filter: | + `/odom/vel.x` >= 8.6 + + cache: !!null + sample_builder: _target_: rbyte.sample.builder.GreedySampleTableBuilder - index_column: /CAM_FRONT/image_rect_compressed/frame_idx - length: 1 - stride: 1 - min_step: 1 + index_column: /CAM_FRONT/image_rect_compressed/idx -jpeg_decoder: +frame_decoder: _target_: simplejpeg.decode_jpeg _partial_: true colorspace: rgb - -message_decoder_factory: - _target_: mcap_protobuf.decoder.DecoderFactory + fastdct: true + fastupsample: true ``` -3. Build a dataloader and inspect a batch: +Build a dataloader and print a batch: ```python from omegaconf import OmegaConf from hydra.utils import instantiate @@ -212,62 +233,58 @@ batch = next(iter(dataloader)) print(batch) ``` +Inspect the batch: ```python Batch( frame=TensorDict( fields={ - CAM_BACK: Tensor(shape=torch.Size([32, 1, 900, 1600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), - CAM_FRONT: Tensor(shape=torch.Size([32, 1, 900, 1600, 3]), device=cpu, dtype=torch.uint8, is_shared=False)}, - batch_size=torch.Size([32]), + /CAM_FRONT/image_rect_compressed: Tensor(shape=torch.Size([1, 1, 900, 1600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + /CAM_FRONT_LEFT/image_rect_compressed: Tensor(shape=torch.Size([1, 1, 900, 1600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + /CAM_FRONT_RIGHT/image_rect_compressed: Tensor(shape=torch.Size([1, 1, 900, 1600, 3]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([1]), device=None, is_shared=False), meta=BatchMeta( - input_id=NonTensorData(data=['NuScenes-v1.0-mini ... .0-mini-scene-0103'], batch_size=torch.Size([32]), device=None), - sample_idx=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False), - batch_size=torch.Size([32]), + input_id=NonTensorData(data=['NuScenes-v1.0-mini-scene-0103'], batch_size=torch.Size([1]), device=None), + sample_idx=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False), + batch_size=torch.Size([1]), device=None, is_shared=False), table=TensorDict( fields={ - /CAM_BACK/image_rect_compressed/frame_idx: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.int64, is_shared=False), - /CAM_BACK/image_rect_compressed/log_time: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.int64, is_shared=False), - /CAM_FRONT/image_rect_compressed/frame_idx: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.int64, is_shared=False), - /CAM_FRONT/image_rect_compressed/log_time: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.int64, is_shared=False), - /gps/latitude: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float64, is_shared=False), - /gps/log_time: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.int64, is_shared=False), - /gps/longitude: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float64, is_shared=False)}, - batch_size=torch.Size([32]), + /CAM_FRONT/image_rect_compressed/idx: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), + /CAM_FRONT/image_rect_compressed/log_time: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), + /CAM_FRONT_LEFT/image_rect_compressed/idx: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), + /CAM_FRONT_LEFT/image_rect_compressed/log_time: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), + /CAM_FRONT_RIGHT/image_rect_compressed/idx: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), + /CAM_FRONT_RIGHT/image_rect_compressed/log_time: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), + /odom/vel.x: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([1]), device=None, is_shared=False), - batch_size=torch.Size([32]), + batch_size=torch.Size([1]), device=None, is_shared=False) - ``` -
-(optional) rerun visualization - -4. Add a `logger` to `config.yaml`: +Append a `logger` to `config.yaml`: ```yaml -# ... - logger: _target_: rbyte.viz.loggers.RerunLogger schema: frame: - CAM_FRONT: rerun.components.ImageBufferBatch - CAM_FRONT_LEFT: rerun.components.ImageBufferBatch - + /CAM_FRONT/image_rect_compressed: rerun.components.ImageBufferBatch + /CAM_FRONT_LEFT/image_rect_compressed: rerun.components.ImageBufferBatch + /CAM_FRONT_RIGHT/image_rect_compressed: rerun.components.ImageBufferBatch table: /CAM_FRONT/image_rect_compressed/log_time: rerun.TimeNanosColumn - /CAM_FRONT_LEFT/image_rect_compressed/frame_idx: rerun.TimeSequenceColumn - /gps/log_time: rerun.TimeNanosColumn - /gps/latitude: rerun.components.ScalarBatch - /gps/longitude: rerun.components.ScalarBatch + /CAM_FRONT/image_rect_compressed/idx: rerun.TimeSequenceColumn + /CAM_FRONT_LEFT/image_rect_compressed/idx: rerun.TimeSequenceColumn + /CAM_FRONT_RIGHT/image_rect_compressed/idx: rerun.TimeSequenceColumn + /odom/vel.x: rerun.components.ScalarBatch ``` -5. Visualize the dataset: +Visualize the dataset: ```python from omegaconf import OmegaConf from hydra.utils import instantiate @@ -280,14 +297,10 @@ logger = instantiate(config.logger) for batch_idx, batch in enumerate(dataloader): logger.log(batch_idx, batch) ``` -image +rerun
-
- - - ## Development 1. Install required tools: diff --git a/examples/config_templates/dataset/mcap_protobuf.yaml b/examples/config_templates/dataset/mcap.yaml similarity index 55% rename from examples/config_templates/dataset/mcap_protobuf.yaml rename to examples/config_templates/dataset/mcap.yaml index f79b5e6..6631a66 100644 --- a/examples/config_templates/dataset/mcap_protobuf.yaml +++ b/examples/config_templates/dataset/mcap.yaml @@ -6,10 +6,10 @@ #@ 'NuScenes-v1.0-mini-scene-0103', #@ ] -#@ cameras = [ -#@ 'CAM_FRONT', -#@ 'CAM_FRONT_LEFT', -#@ 'CAM_FRONT_RIGHT', +#@ camera_topics = [ +#@ '/CAM_FRONT/image_rect_compressed', +#@ '/CAM_FRONT_LEFT/image_rect_compressed', +#@ '/CAM_FRONT_RIGHT/image_rect_compressed', #@ ] --- _target_: rbyte.Dataset @@ -19,15 +19,14 @@ inputs: #@ for input_id in inputs: (@=input_id@): frame: - #@ for source_id in cameras: - (@=source_id@): - index_column: /(@=source_id@)/image_rect_compressed/frame_idx + #@ for topic in camera_topics: + (@=topic@): + index_column: (@=topic@)/idx reader: _target_: rbyte.io.frame.mcap.McapFrameReader path: "${data_dir}/(@=input_id@).mcap" - topic: /(@=source_id@)/image_rect_compressed - message_decoder_factory: - _target_: mcap_protobuf.decoder.DecoderFactory + topic: (@=topic@) + decoder_factory: mcap_protobuf.decoder.DecoderFactory frame_decoder: _target_: simplejpeg.decode_jpeg _partial_: true @@ -40,62 +39,57 @@ inputs: path: "${data_dir}/(@=input_id@).mcap" builder: _target_: rbyte.io.table.TableBuilder + _convert_: all reader: - _target_: rbyte.io.table.mcap.McapProtobufTableReader + _target_: rbyte.io.table.mcap.McapTableReader _recursive_: false - _convert_: all + decoder_factories: + - mcap_protobuf.decoder.DecoderFactory + - rbyte.utils.mcap.McapJsonDecoderFactory + fields: - #@ for camera in cameras: - /(@=camera@)/image_rect_compressed: + #@ for topic in camera_topics: + (@=topic@): log_time: _target_: polars.Datetime time_unit: ns + + idx: #@ end - /gps: + /odom: log_time: _target_: polars.Datetime time_unit: ns - - latitude: polars.Float64 - longitude: polars.Float64 - altitude: polars.Float64 + vel.x: merger: _target_: rbyte.io.table.TableMerger separator: / merge: - /(@=cameras[0]@)/image_rect_compressed: + (@=camera_topics[0]@): log_time: method: ref - #@ for camera in cameras[1:]: - /(@=camera@)/image_rect_compressed: + #@ for topic in camera_topics[1:]: + (@=topic@): log_time: method: ref - frame_idx: + idx: method: asof tolerance: 10ms strategy: nearest #@ end - /gps: + /odom: log_time: method: ref - latitude: - method: asof - tolerance: 1000ms - strategy: nearest - longitude: - method: asof - tolerance: 1000ms - strategy: nearest - altitude: - method: asof - tolerance: 1000ms - strategy: nearest + vel.x: + method: interp + + filter: | + `/odom/vel.x` >= 8.6 - filter: !!null cache: _target_: rbyte.utils.dataframe.DataframeDiskCache directory: /tmp/rbyte-cache @@ -104,7 +98,7 @@ inputs: sample_builder: _target_: rbyte.sample.builder.GreedySampleTableBuilder - index_column: /(@=cameras[0]@)/image_rect_compressed/frame_idx + index_column: (@=camera_topics[0]@)/idx length: 1 stride: 1 min_step: 1 diff --git a/examples/config_templates/frame_reader/mcap.yaml b/examples/config_templates/frame_reader/mcap.yaml index 6e3e886..1dc73e5 100644 --- a/examples/config_templates/frame_reader/mcap.yaml +++ b/examples/config_templates/frame_reader/mcap.yaml @@ -3,9 +3,7 @@ _target_: rbyte.io.frame.mcap.McapFrameReader _recursive_: true path: ??? topic: ??? -message_decoder_factory: - _target_: mcap_protobuf.decoder.DecoderFactory - +decoder_factory: mcap_protobuf.decoder.DecoderFactory frame_decoder: _target_: simplejpeg.decode_jpeg _partial_: true diff --git a/examples/config_templates/logger/rerun/mcap.yaml b/examples/config_templates/logger/rerun/mcap.yaml new file mode 100644 index 0000000..1de6454 --- /dev/null +++ b/examples/config_templates/logger/rerun/mcap.yaml @@ -0,0 +1,21 @@ +#@yaml/text-templated-strings + +#@ camera_topics = [ +#@ '/CAM_FRONT/image_rect_compressed', +#@ '/CAM_FRONT_LEFT/image_rect_compressed', +#@ '/CAM_FRONT_RIGHT/image_rect_compressed', +#@ ] +--- +_target_: rbyte.viz.loggers.RerunLogger +schema: + frame: + #@ for topic in camera_topics: + (@=topic@): rerun.components.ImageBufferBatch + #@ end + + table: + (@=camera_topics[0]@)/log_time: rerun.TimeNanosColumn + #@ for topic in camera_topics: + (@=topic@)/idx: rerun.TimeSequenceColumn + #@ end + /odom/vel.x: rerun.components.ScalarBatch diff --git a/examples/config_templates/logger/rerun/mcap_protobuf.yaml b/examples/config_templates/logger/rerun/mcap_protobuf.yaml deleted file mode 100644 index 820391b..0000000 --- a/examples/config_templates/logger/rerun/mcap_protobuf.yaml +++ /dev/null @@ -1,24 +0,0 @@ -#@yaml/text-templated-strings - -#@ cameras = [ -#@ 'CAM_FRONT', -#@ 'CAM_FRONT_LEFT', -#@ 'CAM_FRONT_RIGHT', -#@ ] ---- -_target_: rbyte.viz.loggers.RerunLogger -schema: - frame: - #@ for camera in cameras: - (@=camera@): rerun.components.ImageBufferBatch - #@ end - - table: - /(@=cameras[0]@)/image_rect_compressed/log_time: rerun.TimeNanosColumn - #@ for camera in cameras: - /(@=camera@)/image_rect_compressed/frame_idx: rerun.TimeSequenceColumn - #@ end - /gps/log_time: rerun.TimeNanosColumn - /gps/latitude: rerun.components.ScalarBatch - /gps/longitude: rerun.components.ScalarBatch - /gps/altitude: rerun.components.ScalarBatch diff --git a/examples/config_templates/table_builder/mcap.yaml b/examples/config_templates/table_builder/mcap.yaml new file mode 100644 index 0000000..8bfc867 --- /dev/null +++ b/examples/config_templates/table_builder/mcap.yaml @@ -0,0 +1,61 @@ +#@yaml/text-templated-strings + +#@ camera_topics = [ +#@ '/CAM_FRONT_LEFT/image_rect_compressed', +#@ '/CAM_FRONT_RIGHT/image_rect_compressed', +#@ ] +--- +_target_: rbyte.io.table.TableBuilder +_convert_: all +reader: + _target_: rbyte.io.table.mcap.McapTableReader + _recursive_: false + decoder_factories: + - mcap_protobuf.decoder.DecoderFactory + - rbyte.utils.mcap.McapJsonDecoderFactory + + fields: + #@ for topic in camera_topics: + (@=topic@): + log_time: + _target_: polars.Datetime + time_unit: ns + + idx: + #@ end + + /odom: + log_time: + _target_: polars.Datetime + time_unit: ns + vel.x: + +merger: + _target_: rbyte.io.table.TableMerger + separator: / + merge: + (@=camera_topics[0]@): + log_time: + method: ref + + #@ for topic in camera_topics[1:]: + (@=topic@): + log_time: + method: ref + idx: + method: asof + tolerance: 10ms + strategy: nearest + #@ end + + /odom: + log_time: + method: ref + vel.x: + method: interp + +filter: !!null +cache: + _target_: rbyte.utils.dataframe.DataframeDiskCache + directory: /tmp/rbyte-cache + size_limit: 1GiB diff --git a/examples/config_templates/table_builder/mcap_protobuf.yaml b/examples/config_templates/table_builder/mcap_protobuf.yaml deleted file mode 100644 index d05d421..0000000 --- a/examples/config_templates/table_builder/mcap_protobuf.yaml +++ /dev/null @@ -1,69 +0,0 @@ -#@yaml/text-templated-strings - -#@ cameras = [ -#@ 'CAM_FRONT', -#@ 'CAM_FRONT_LEFT', -#@ 'CAM_FRONT_RIGHT', -#@ ] ---- -_target_: rbyte.io.table.TableBuilder -reader: - _target_: rbyte.io.table.mcap.McapProtobufTableReader - _recursive_: false - _convert_: all - fields: - #@ for camera in cameras: - /(@=camera@)/image_rect_compressed: - log_time: - _target_: polars.Datetime - time_unit: ns - #@ end - - /gps: - log_time: - _target_: polars.Datetime - time_unit: ns - - latitude: polars.Float64 - longitude: polars.Float64 - altitude: polars.Float64 - -merger: - _target_: rbyte.io.table.TableMerger - separator: / - merge: - /(@=cameras[0]@)/image_rect_compressed: - log_time: - method: ref - - #@ for camera in cameras[1:]: - /(@=camera@)/image_rect_compressed: - log_time: - method: ref - frame_idx: - method: asof - tolerance: 10ms - strategy: nearest - #@ end - - /gps: - log_time: - method: ref - latitude: - method: asof - tolerance: 1000ms - strategy: nearest - longitude: - method: asof - tolerance: 1000ms - strategy: nearest - altitude: - method: asof - tolerance: 1000ms - strategy: nearest - -filter: !!null -cache: - _target_: rbyte.utils.dataframe.DataframeDiskCache - directory: /tmp/rbyte-cache - size_limit: 1GiB diff --git a/examples/config_templates/table_builder/mcap_ros2.yaml b/examples/config_templates/table_builder/mcap_ros2.yaml deleted file mode 100644 index 3137b76..0000000 --- a/examples/config_templates/table_builder/mcap_ros2.yaml +++ /dev/null @@ -1,62 +0,0 @@ -#! https://jkk-research.github.io/dataset/jkk_dataset_02 - -#@yaml/text-templated-strings ---- -_target_: rbyte.io.table.TableBuilder -reader: - _target_: rbyte.io.table.mcap.McapRos2TableReader - _recursive_: false - _convert_: all - fields: - /nissan/vehicle_speed: - log_time: - _target_: polars.Datetime - time_unit: ns - - data: - - /nissan/gps/duro/current_pose: - log_time: - _target_: polars.Datetime - time_unit: ns - - pose.position.x: - pose.position.y: - pose.position.z: - pose.orientation.x: - pose.orientation.y: - pose.orientation.z: - pose.orientation.w: - -merger: - _target_: rbyte.io.table.TableMerger - separator: / - merge: - /nissan/vehicle_speed: - log_time: - method: ref - - /nissan/gps/duro/current_pose: - log_time: - method: ref - pose.position.x: - method: interp - pose.position.y: - method: interp - pose.position.z: - method: interp - pose.orientation.x: - method: interp - pose.orientation.y: - method: interp - pose.orientation.z: - method: interp - pose.orientation.w: - method: interp - -filter: !!null - -cache: - _target_: rbyte.utils.dataframe.DataframeDiskCache - directory: /tmp/rbyte-cache - size_limit: 1GiB diff --git a/pyproject.toml b/pyproject.toml index 6a80ff5..cacc921 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rbyte" -version = "0.2.0" +version = "0.3.0" description = "Multimodal dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] @@ -40,9 +40,7 @@ mcap = [ "mcap>=1.1.1", "mcap-protobuf-support>=0.5.1", "mcap-ros2-support>=0.5.3", - "protobuf", - "ptars>=0.0.2rc2", - "pyarrow-stubs>=17.6", + "python-box>=7.2.0", ] yaak = ["protobuf", "ptars>=0.0.2rc2"] jpeg = ["simplejpeg>=1.7.6"] diff --git a/src/rbyte/dataset.py b/src/rbyte/dataset.py index 8c1ddac..42b3bdf 100644 --- a/src/rbyte/dataset.py +++ b/src/rbyte/dataset.py @@ -126,9 +126,9 @@ def _build_table(cls, sources: SourcesConfig) -> pl.LazyFrame: logger.debug("building table") match sources: - case SourcesConfig(frame=frame_sources, table=None) if len( - frame_sources - ) == 1: + case SourcesConfig(frame=frame_sources, table=None) if ( + len(frame_sources) == 1 + ): frame_source = mit.one(frame_sources.values()) frame_reader = frame_source.reader.instantiate() frame_idxs = pl.Series( diff --git a/src/rbyte/io/frame/mcap/reader.py b/src/rbyte/io/frame/mcap/reader.py index 54827a5..57e51f6 100644 --- a/src/rbyte/io/frame/mcap/reader.py +++ b/src/rbyte/io/frame/mcap/reader.py @@ -14,7 +14,7 @@ from mcap.reader import SeekingReader from mcap.records import Chunk, ChunkIndex, Message from mcap.stream_reader import get_chunk_data_stream -from pydantic import ConfigDict, FilePath, validate_call +from pydantic import ConfigDict, FilePath, ImportString, validate_call from structlog import get_logger from structlog.contextvars import bound_contextvars from torch import Tensor @@ -37,16 +37,14 @@ def __init__( self, path: FilePath, topic: str, - message_decoder_factory: DecoderFactory, + decoder_factory: ImportString[type[DecoderFactory]], frame_decoder: Callable[[bytes], npt.ArrayLike], validate_crcs: bool = False, # noqa: FBT001, FBT002 ) -> None: super().__init__() with bound_contextvars( - path=path.as_posix(), - topic=topic, - message_decoder_factory=message_decoder_factory, + path=path.as_posix(), topic=topic, message_decoder_factory=decoder_factory ): self._path = path self._validate_crcs = validate_crcs @@ -65,7 +63,7 @@ def __init__( if channel.topic == topic ) - message_decoder = message_decoder_factory.decoder_for( + message_decoder = decoder_factory().decoder_for( message_encoding=self._channel.message_encoding, schema=summary.schemas[self._channel.schema_id], ) diff --git a/src/rbyte/io/table/mcap/__init__.py b/src/rbyte/io/table/mcap/__init__.py index e95590c..fef5bc6 100644 --- a/src/rbyte/io/table/mcap/__init__.py +++ b/src/rbyte/io/table/mcap/__init__.py @@ -1,15 +1,3 @@ -__all__: list[str] = [] +from .reader import McapTableReader -try: - from .protobuf_reader import McapProtobufTableReader -except ImportError: - pass -else: - __all__ += ["McapProtobufTableReader"] - -try: - from .ros2_reader import McapRos2TableReader -except ImportError: - pass -else: - __all__ += ["McapRos2TableReader"] +__all__ = ["McapTableReader"] diff --git a/src/rbyte/io/table/mcap/config.py b/src/rbyte/io/table/mcap/config.py deleted file mode 100644 index b99593f..0000000 --- a/src/rbyte/io/table/mcap/config.py +++ /dev/null @@ -1,19 +0,0 @@ -from collections.abc import Mapping - -import polars as pl -from pydantic import ConfigDict, ImportString - -from rbyte.config.base import BaseModel, HydraConfig - -PolarsDataType = pl.DataType | pl.DataTypeClass - - -class Config(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - fields: Mapping[ - str, - Mapping[str, HydraConfig[PolarsDataType] | ImportString[PolarsDataType] | None], - ] - - validate_crcs: bool = False diff --git a/src/rbyte/io/table/mcap/protobuf_reader.py b/src/rbyte/io/table/mcap/protobuf_reader.py deleted file mode 100644 index 54ad8f9..0000000 --- a/src/rbyte/io/table/mcap/protobuf_reader.py +++ /dev/null @@ -1,173 +0,0 @@ -import json -from collections.abc import Hashable, Mapping -from functools import cached_property -from mmap import ACCESS_READ, mmap -from os import PathLike -from pathlib import Path -from typing import Literal, cast, override - -import more_itertools as mit -import polars as pl -import polars._typing as plt -import pyarrow as pa -from cachetools import cached -from google.protobuf.descriptor_pb2 import FileDescriptorProto, FileDescriptorSet -from google.protobuf.descriptor_pool import DescriptorPool -from google.protobuf.message import Message -from google.protobuf.message_factory import GetMessageClassesForFiles -from mcap.reader import SeekingReader -from mcap.records import Schema -from ptars import HandlerPool -from structlog import get_logger -from structlog.contextvars import bound_contextvars -from tqdm import tqdm -from xxhash import xxh3_64_intdigest as digest - -from rbyte.config.base import HydraConfig -from rbyte.io.table.base import TableReaderBase -from rbyte.utils.dataframe import unnest_all - -from .config import Config - -logger = get_logger(__name__) - - -class McapProtobufTableReader(TableReaderBase, Hashable): - FRAME_INDEX_COLUMN_NAME: Literal["frame_idx"] = "frame_idx" - - def __init__(self, **kwargs: object) -> None: - self._config = Config.model_validate(kwargs) - - @override - def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: - with ( - bound_contextvars(path=str(path)), - Path(path).open("rb") as _f, - mmap(fileno=_f.fileno(), length=0, access=ACCESS_READ) as f, - ): - reader = SeekingReader(f, validate_crcs=self._config.validate_crcs) # pyright: ignore[reportArgumentType] - summary = reader.get_summary() - if summary is None: - logger.error(msg := "missing summary") - raise ValueError(msg) - - topics = self.schemas.keys() - if missing_topics := topics - ( - available_topics := {ch.topic for ch in summary.channels.values()} - ): - with bound_contextvars( - missing_topics=sorted(missing_topics), - available_topics=sorted(available_topics), - ): - logger.error(msg := "missing topics") - raise ValueError(msg) - - schemas = { - channel.topic: summary.schemas[channel.schema_id] - for channel in summary.channels.values() - if channel.topic in topics - } - - message_counts = ( - { - channel.topic: stats.channel_message_counts[channel_id] - for channel_id, channel in summary.channels.items() - if channel.topic in topics - } - if (stats := summary.statistics) is not None - else {} - ) - - messages = mit.bucket( - reader.iter_messages(topics), - key=lambda x: x[1].topic, - validator=lambda k: k in topics, - ) - - dfs: Mapping[str, pl.DataFrame] = {} - handler_pool = HandlerPool() - - for topic, schema in self.schemas.items(): - log_time, publish_time, data = mit.unzip( - (msg.log_time, msg.publish_time, msg.data) - for (*_, msg) in tqdm( - messages[topic], - total=message_counts[topic], - postfix={"topic": topic}, - ) - ) - - msg_type = self._get_message_type(schemas[topic]) - handler = handler_pool.get_for_message(msg_type.DESCRIPTOR) # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - - record_batch = cast( - pa.RecordBatch, - handler.list_to_record_batch(list(data)), # pyright: ignore[reportUnknownMemberType] - ) - df = cast(pl.DataFrame, pl.from_arrow(record_batch)) # pyright: ignore[reportUnknownMemberType] - - dfs[topic] = ( - (df.select(unnest_all(df.collect_schema()))) - .hstack([ - pl.Series("log_time", log_time, pl.UInt64), - pl.Series("publish_time", publish_time, pl.UInt64), - ]) - .select(schema) - .cast({k: v for k, v in schema.items() if v is not None}) - ) - - for topic, df in dfs.items(): - match schemas[topic]: - case Schema( - encoding="protobuf", - name="foxglove.CompressedImage" - | "foxglove.CompressedVideo" - | "foxglove.RawImage", - ): - dfs[topic] = df.with_row_index(self.FRAME_INDEX_COLUMN_NAME) - - case _: - pass - - return dfs - - @override - def __hash__(self) -> int: - config = self._config.model_dump_json() - # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 - config_str = json.dumps(json.loads(config), sort_keys=True) - - return digest(config_str) - - @cached(cache={}, key=lambda _, schema: schema.name) # pyright: ignore[reportUnknownLambdaType, reportUnknownArgumentType, reportUnknownMemberType] - def _get_message_type(self, schema: Schema) -> type[Message]: # noqa: PLR6301 - # inspired by https://github.com/foxglove/mcap/blob/e591defaa95186cef27e37c49fa7e1f0c9f2e8a6/python/mcap-protobuf-support/mcap_protobuf/decoder.py#L29 - fds = FileDescriptorSet.FromString(schema.data) - descriptors = {fd.name: fd for fd in fds.file} - pool = DescriptorPool() - - def _add(fd: FileDescriptorProto) -> None: - for dependency in fd.dependency: - if dependency in descriptors: - _add(descriptors.pop(dependency)) - - pool.Add(fd) # pyright: ignore[reportUnknownMemberType] - - while descriptors: - _add(descriptors.popitem()[1]) - - messages = GetMessageClassesForFiles([fd.name for fd in fds.file], pool) - - return mit.one( - msg_type for name, msg_type in messages.items() if name == schema.name - ) - - @cached_property - def schemas(self) -> dict[str, dict[str, plt.PolarsDataType | None]]: - return { - topic: { - path: leaf.instantiate() if isinstance(leaf, HydraConfig) else leaf - for path, leaf in fields.items() - } - for topic, fields in self._config.fields.items() - } diff --git a/src/rbyte/io/table/mcap/reader.py b/src/rbyte/io/table/mcap/reader.py new file mode 100644 index 0000000..f6f97e9 --- /dev/null +++ b/src/rbyte/io/table/mcap/reader.py @@ -0,0 +1,176 @@ +import json +from collections.abc import Hashable, Iterable, Mapping, Sequence +from enum import StrEnum, unique +from functools import cached_property +from mmap import ACCESS_READ, mmap +from operator import attrgetter +from os import PathLike +from pathlib import Path +from typing import Any, NamedTuple, override + +import more_itertools as mit +import polars as pl +from mcap.decoder import DecoderFactory +from mcap.reader import DecodedMessageTuple, SeekingReader +from pydantic import ( + ConfigDict, + ImportString, + SerializationInfo, + SerializerFunctionWrapHandler, + field_serializer, +) +from structlog import get_logger +from structlog.contextvars import bound_contextvars +from tqdm import tqdm +from xxhash import xxh3_64_intdigest as digest + +from rbyte.config.base import BaseModel, HydraConfig +from rbyte.io.table.base import TableReaderBase + +logger = get_logger(__name__) + + +PolarsDataType = pl.DataType | pl.DataTypeClass + + +class Config(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + decoder_factories: Sequence[ImportString[type[DecoderFactory]]] + + fields: Mapping[ + str, + Mapping[str, HydraConfig[PolarsDataType] | ImportString[PolarsDataType] | None], + ] + + validate_crcs: bool = False + + @field_serializer("decoder_factories", when_used="json", mode="wrap") + @staticmethod + def serialize_decoder_factories( + value: frozenset[ImportString[type[DecoderFactory]]], + nxt: SerializerFunctionWrapHandler, + _info: SerializationInfo, + ) -> Sequence[str]: + return sorted(nxt(value)) + + +class RowValues(NamedTuple): + topic: str + values: Iterable[Any] + + +@unique +class SpecialFields(StrEnum): + log_time = "log_time" + publish_time = "publish_time" + idx = "idx" + + +class McapTableReader(TableReaderBase, Hashable): + def __init__(self, **kwargs: object) -> None: + self._config = Config.model_validate(kwargs) + + @override + def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: + with ( + bound_contextvars(path=str(path)), + Path(path).open("rb") as _f, + mmap(fileno=_f.fileno(), length=0, access=ACCESS_READ) as f, + ): + reader = SeekingReader( + f, # pyright: ignore[reportArgumentType] + validate_crcs=self._config.validate_crcs, + decoder_factories=[f() for f in self._config.decoder_factories], + ) + summary = reader.get_summary() + if summary is None: + logger.error(msg := "missing summary") + raise ValueError(msg) + + topics = self.schemas.keys() + if missing_topics := topics - ( + available_topics := {ch.topic for ch in summary.channels.values()} + ): + with bound_contextvars( + missing_topics=sorted(missing_topics), + available_topics=sorted(available_topics), + ): + logger.error(msg := "missing topics") + raise ValueError(msg) + + message_count = ( + sum( + stats.channel_message_counts[channel.id] + for channel in summary.channels.values() + if channel.topic in topics + ) + if (stats := summary.statistics) is not None + else None + ) + + row_values = ( + RowValues( + dmt.channel.topic, + self._get_values(dmt, self.schemas[dmt.channel.topic]), + ) + for dmt in tqdm( + reader.iter_decoded_messages(topics), + desc="messages", + total=message_count, + ) + ) + + row_values_by_topic = mit.bucket(row_values, key=lambda rd: rd.topic) + + dfs: Mapping[str, pl.DataFrame] = {} + for topic, schema in self.schemas.items(): + df_schema = {k: v for k, v in schema.items() if k != SpecialFields.idx} + df = pl.DataFrame( + data=(tuple(x.values) for x in row_values_by_topic[topic]), + schema=df_schema, # pyright: ignore[reportArgumentType] + orient="row", + ) + + if (idx_name := SpecialFields.idx) in schema: + df = df.with_row_index(idx_name).cast({ + idx_name: schema[idx_name] or pl.UInt32 + }) + + dfs[topic] = df + + return dfs + + @override + def __hash__(self) -> int: + config = self._config.model_dump_json() + # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 + config_str = json.dumps(json.loads(config), sort_keys=True) + + return digest(config_str) + + @cached_property + def schemas(self) -> dict[str, dict[str, PolarsDataType | None]]: + return { + topic: { + path: leaf.instantiate() if isinstance(leaf, HydraConfig) else leaf + for path, leaf in fields.items() + } + for topic, fields in self._config.fields.items() + } + + @staticmethod + def _get_values(dmt: DecodedMessageTuple, fields: Iterable[str]) -> Iterable[Any]: + for field in fields: + match field: + case SpecialFields.log_time: + yield dmt.message.log_time + + case SpecialFields.publish_time: + yield dmt.message.publish_time + + case SpecialFields.idx: + pass # added later + + case _: + yield attrgetter(field)(dmt.decoded_message) diff --git a/src/rbyte/io/table/mcap/ros2_reader.py b/src/rbyte/io/table/mcap/ros2_reader.py deleted file mode 100644 index 2bb5650..0000000 --- a/src/rbyte/io/table/mcap/ros2_reader.py +++ /dev/null @@ -1,105 +0,0 @@ -import json -from collections.abc import Hashable, Mapping -from functools import cached_property -from mmap import ACCESS_READ, mmap -from operator import attrgetter -from os import PathLike -from pathlib import Path -from typing import Any, Literal, override - -import polars as pl -import polars._typing as plt -from mcap.reader import SeekingReader -from mcap_ros2.decoder import DecoderFactory -from structlog import get_logger -from structlog.contextvars import bound_contextvars -from xxhash import xxh3_64_intdigest as digest - -from rbyte.config.base import HydraConfig -from rbyte.io.table.base import TableReaderBase - -from .config import Config - -logger = get_logger(__name__) - - -class McapRos2TableReader(TableReaderBase, Hashable): - FRAME_INDEX_COLUMN_NAME: Literal["frame_idx"] = "frame_idx" - - def __init__(self, **kwargs: object) -> None: - self._config = Config.model_validate(kwargs) - - @override - def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: - with ( - bound_contextvars(path=str(path)), - Path(path).open("rb") as _f, - mmap(fileno=_f.fileno(), length=0, access=ACCESS_READ) as f, - ): - reader = SeekingReader( - f, # pyright: ignore[reportArgumentType] - validate_crcs=self._config.validate_crcs, - decoder_factories=[DecoderFactory()], - ) - summary = reader.get_summary() - if summary is None: - logger.error(msg := "missing summary") - raise ValueError(msg) - - topics = self._config.fields.keys() - if missing_topics := topics - ( - available_topics := {ch.topic for ch in summary.channels.values()} - ): - with bound_contextvars( - missing_topics=sorted(missing_topics), - available_topics=sorted(available_topics), - ): - logger.error(msg := "missing topics") - raise ValueError(msg) - - getters = { - topic: [attrgetter(field) for field in fields] - for topic, fields in self.schemas.items() - } - - rows: Mapping[str, list[list[Any]]] = {topic: [] for topic in self.schemas} - - for _, channel, msg, msg_decoded in reader.iter_decoded_messages(topics): - topic = channel.topic - row: list[Any] = [] - for getter in getters[topic]: - try: - attr = getter(msg_decoded) - except AttributeError: - attr = getter(msg) - - row.append(attr) - - rows[topic].append(row) - - return { - topic: pl.DataFrame( - data=rows[topic], - schema=self.schemas[topic], # pyright: ignore[reportArgumentType] - orient="row", - ) - for topic in topics - } - - @override - def __hash__(self) -> int: - config = self._config.model_dump_json() - # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 - config_str = json.dumps(json.loads(config), sort_keys=True) - - return digest(config_str) - - @cached_property - def schemas(self) -> dict[str, dict[str, plt.PolarsDataType | None]]: - return { - topic: { - path: leaf.instantiate() if isinstance(leaf, HydraConfig) else leaf - for path, leaf in fields.items() - } - for topic, fields in self._config.fields.items() - } diff --git a/src/rbyte/utils/mcap/__init__.py b/src/rbyte/utils/mcap/__init__.py new file mode 100644 index 0000000..8a476ed --- /dev/null +++ b/src/rbyte/utils/mcap/__init__.py @@ -0,0 +1,3 @@ +from .json_decoder_factory import McapJsonDecoderFactory + +__all__ = ["McapJsonDecoderFactory"] diff --git a/src/rbyte/utils/mcap/json_decoder_factory.py b/src/rbyte/utils/mcap/json_decoder_factory.py new file mode 100644 index 0000000..51f093b --- /dev/null +++ b/src/rbyte/utils/mcap/json_decoder_factory.py @@ -0,0 +1,27 @@ +import json +from collections.abc import Callable +from typing import override + +from box import Box +from mcap.decoder import DecoderFactory as McapDecoderFactory +from mcap.records import Schema +from structlog import get_logger + +logger = get_logger(__name__) + + +class McapJsonDecoderFactory(McapDecoderFactory): + @override + def decoder_for( + self, message_encoding: str, schema: Schema | None + ) -> Callable[[bytes], Box] | None: + match message_encoding, getattr(schema, "encoding", None): + case "json", "jsonschema": + return self._decoder + + case _: + return None + + @staticmethod + def _decoder(data: bytes) -> Box: + return Box(json.loads(data))