From 9a1e5943e7e4af4a5f8ea9c9cbf8e939be3b2fa8 Mon Sep 17 00:00:00 2001 From: Evgenii Gorchakov Date: Mon, 23 Sep 2024 18:55:22 +0200 Subject: [PATCH] feat: mcap support (#7) - add mcap readers (protobuf/ros2) - abstract tabular source processing - slimmer configs - add `read-frames` script closes #1 --- .github/workflows/build.yaml | 2 +- .github/workflows/ci.yaml | 18 +- .pre-commit-config.yaml | 6 +- README.md | 2 +- examples/config_templates/build_table.yaml | 15 + examples/config_templates/dataset/carla.yaml | 93 ++-- .../dataset/mcap_protobuf.yaml | 111 +++++ examples/config_templates/dataset/yaak.yaml | 231 +++++----- .../frame_reader/directory.yaml | 10 + .../config_templates/frame_reader/mcap.yaml | 16 + .../config_templates/logger/rerun/carla.yaml | 23 + .../logger/rerun/mcap_protobuf.yaml | 24 + .../config_templates/logger/rerun/yaak.yaml | 23 + .../config_templates/logger/rerun_carla.yaml | 16 - .../config_templates/logger/rerun_yaak.yaml | 41 -- examples/config_templates/read_frames.yaml | 12 + examples/config_templates/table.yaml | 85 ---- .../config_templates/table_builder/carla.yaml | 18 + .../table_builder/mcap_protobuf.yaml | 69 +++ .../table_builder/mcap_ros2.yaml | 62 +++ .../config_templates/table_builder/yaak.yaml | 77 ++++ examples/config_templates/visualize.yaml | 11 +- justfile | 20 +- pyproject.toml | 54 ++- src/rbyte/__init__.py | 8 +- src/rbyte/batch/batch.py | 8 +- src/rbyte/config/base.py | 14 +- src/rbyte/{dataset => }/dataset.py | 138 +++--- src/rbyte/dataset/__init__.py | 4 - src/rbyte/dataset/config.py | 71 --- src/rbyte/io/frame/__init__.py | 3 + src/rbyte/io/frame/base.py | 4 +- src/rbyte/io/frame/directory/__init__.py | 3 + src/rbyte/io/frame/directory/reader.py | 47 ++ src/rbyte/io/frame/jpeg/__init__.py | 3 - src/rbyte/io/frame/jpeg/reader.py | 45 -- src/rbyte/io/frame/mcap/__init__.py | 3 + src/rbyte/io/frame/mcap/reader.py | 169 ++++++++ src/rbyte/io/table/__init__.py | 4 + src/rbyte/io/table/base.py | 24 +- src/rbyte/io/table/builder.py | 78 ++++ src/rbyte/io/table/carla/builder.py | 48 +- src/rbyte/io/table/mcap/__init__.py | 15 + src/rbyte/io/table/mcap/config.py | 19 + src/rbyte/io/table/mcap/protobuf_reader.py | 173 ++++++++ src/rbyte/io/table/mcap/ros2_reader.py | 105 +++++ src/rbyte/io/table/merger.py | 160 +++++++ .../io/table/transforms/fps_resampler.py | 11 +- src/rbyte/io/table/yaak/__init__.py | 4 +- src/rbyte/io/table/yaak/builder.py | 410 ------------------ src/rbyte/io/table/yaak/idl-repo | 2 +- src/rbyte/io/table/yaak/message_iterator.py | 28 +- src/rbyte/io/table/yaak/reader.py | 99 +++++ src/rbyte/sample/builder.py | 48 +- src/rbyte/scripts/build_table.py | 20 +- src/rbyte/scripts/read_frames.py | 61 +++ src/rbyte/scripts/visualize.py | 11 +- src/rbyte/utils/__init__.py | 3 - src/rbyte/utils/dataframe/__init__.py | 4 + .../cache.py} | 16 +- .../utils/{dataframe.py => dataframe/misc.py} | 9 +- src/rbyte/viz/loggers/rerun_logger.py | 93 ++-- 62 files changed, 1896 insertions(+), 1108 deletions(-) create mode 100644 examples/config_templates/build_table.yaml create mode 100644 examples/config_templates/dataset/mcap_protobuf.yaml create mode 100644 examples/config_templates/frame_reader/directory.yaml create mode 100644 examples/config_templates/frame_reader/mcap.yaml create mode 100644 examples/config_templates/logger/rerun/carla.yaml create mode 100644 examples/config_templates/logger/rerun/mcap_protobuf.yaml create mode 100644 examples/config_templates/logger/rerun/yaak.yaml delete mode 100644 examples/config_templates/logger/rerun_carla.yaml delete mode 100644 examples/config_templates/logger/rerun_yaak.yaml create mode 100644 examples/config_templates/read_frames.yaml delete mode 100644 examples/config_templates/table.yaml create mode 100644 examples/config_templates/table_builder/carla.yaml create mode 100644 examples/config_templates/table_builder/mcap_protobuf.yaml create mode 100644 examples/config_templates/table_builder/mcap_ros2.yaml create mode 100644 examples/config_templates/table_builder/yaak.yaml rename src/rbyte/{dataset => }/dataset.py (52%) delete mode 100644 src/rbyte/dataset/__init__.py delete mode 100644 src/rbyte/dataset/config.py create mode 100644 src/rbyte/io/frame/directory/__init__.py create mode 100644 src/rbyte/io/frame/directory/reader.py delete mode 100644 src/rbyte/io/frame/jpeg/__init__.py delete mode 100644 src/rbyte/io/frame/jpeg/reader.py create mode 100644 src/rbyte/io/frame/mcap/__init__.py create mode 100644 src/rbyte/io/frame/mcap/reader.py create mode 100644 src/rbyte/io/table/builder.py create mode 100644 src/rbyte/io/table/mcap/__init__.py create mode 100644 src/rbyte/io/table/mcap/config.py create mode 100644 src/rbyte/io/table/mcap/protobuf_reader.py create mode 100644 src/rbyte/io/table/mcap/ros2_reader.py create mode 100644 src/rbyte/io/table/merger.py delete mode 100644 src/rbyte/io/table/yaak/builder.py create mode 100644 src/rbyte/io/table/yaak/reader.py create mode 100644 src/rbyte/scripts/read_frames.py create mode 100644 src/rbyte/utils/dataframe/__init__.py rename src/rbyte/utils/{dataframe_cache.py => dataframe/cache.py} (73%) rename src/rbyte/utils/{dataframe.py => dataframe/misc.py} (66%) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 18f1692..f4cb227 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -26,4 +26,4 @@ jobs: submodules: recursive persist-credentials: false - - uses: hynek/build-and-inspect-python-package@v2.8.0 + - uses: hynek/build-and-inspect-python-package@v2.9.0 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index df6bed4..f99ec58 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -12,8 +12,6 @@ on: jobs: pre-commit: runs-on: ubuntu-latest - env: - UV_CACHE_DIR: /tmp/.uv-cache steps: - name: setup ssh @@ -33,16 +31,11 @@ jobs: python-version-file: "pyproject.toml" - name: install uv - run: curl -LsSf https://astral.sh/uv/install.sh | sh - - - name: restore uv cache - uses: actions/cache@v4 + uses: astral-sh/setup-uv@v3 with: - path: /tmp/.uv-cache - key: uv-${{ runner.os }}-${{ hashFiles('uv.lock') }} - restore-keys: | - uv-${{ runner.os }}-${{ hashFiles('uv.lock') }} - uv-${{ runner.os }} + version: "latest" + enable-cache: true + cache-dependency-glob: "**/pyproject.toml" - name: uv sync run: uv sync --all-extras --dev @@ -63,6 +56,3 @@ jobs: run: uvx pre-commit run --all-files --color=always env: SKIP: just-format,generate-example-config - - - name: uv cache prune - run: uv cache prune --ci diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aeba157..b2221ee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: - id: validate-pyproject - repo: https://github.com/crate-ci/typos - rev: v1.23.7 + rev: v1.24.6 hooks: - id: typos @@ -24,14 +24,14 @@ repos: exclude: examples/config - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.2 + rev: v0.6.7 hooks: - id: ruff args: [--fix] - id: ruff-format - repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror - rev: 1.17.0 + rev: 1.17.5 hooks: - id: basedpyright diff --git a/README.md b/README.md index 16810a0..9d13e48 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Multimodal dataset library. ## Installation ```bash -uv add https://github.com/yaak-ai/rbyte/releases/latest/download/rbyte-X.Y.Z-py3-none-any.whl [--extra visualize] +uv add https://github.com/yaak-ai/rbyte/releases/latest/download/rbyte-X.Y.Z-py3-none-any.whl [--extra visualize] [--extra jpeg] [--extra mcap] ``` ## Usage diff --git a/examples/config_templates/build_table.yaml b/examples/config_templates/build_table.yaml new file mode 100644 index 0000000..4c743ff --- /dev/null +++ b/examples/config_templates/build_table.yaml @@ -0,0 +1,15 @@ +--- +defaults: + - table_builder: !!null + - _self_ + +path: ??? +writer: + _target_: polars.DataFrame.write_csv + _partial_: true + file: ??? + +hydra: + output_subdir: !!null + run: + dir: . diff --git a/examples/config_templates/dataset/carla.yaml b/examples/config_templates/dataset/carla.yaml index 341a205..60dac24 100644 --- a/examples/config_templates/dataset/carla.yaml +++ b/examples/config_templates/dataset/carla.yaml @@ -1,44 +1,61 @@ +#@yaml/text-templated-strings + +#@ drives = [ +#@ 'carla_id1', +#@ ] + +#@ cameras = [ +#@ 'cam_front_left', +#@ ] --- _target_: rbyte.Dataset _recursive_: false -config: - inputs: - - id: carla_id1 - sources: - frame: - - id: cam_front_left - index_column: frame_idx - reader: - _target_: rbyte.io.frame.jpeg.JpegFrameReader - path: "${data_dir}/${.....id}/frames/${..id}.defish.mp4/576x324/{:09d}.jpg" - - table: - path: ${data_dir}/${...id}/ego_logs.json - builder: - _target_: rbyte.io.table.carla.CarlaRecordsTableBuilder - _recursive_: false - config: - index_column: frame_idx - select: - - control.brake - - control.throttle - - control.steer - - state.velocity.value - - state.acceleration.value +_convert_: all +inputs: + #@ for input_id in drives: + (@=input_id@): + frame: + #@ for source_id in cameras: + (@=source_id@): + index_column: frame_idx + reader: + _target_: rbyte.io.frame.DirectoryFrameReader + path: "${data_dir}/(@=input_id@)/frames/(@=source_id@).defish.mp4/576x324/{:09d}.jpg" + frame_decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true - filter: | - `control.throttle` > 0.5 + #@ end - transforms: - - _target_: rbyte.io.table.transforms.FpsResampler - source_fps: 20 - target_fps: 30 - samples: - builder: - _target_: rbyte.sample.builder.GreedySampleTableBuilder - config: + table: + path: ${data_dir}/(@=input_id@)/ego_logs.json + builder: + _target_: rbyte.io.table.carla.CarlaRecordsTableBuilder + _convert_: all index_column: frame_idx - length: 1 - stride: 1 - min_step: 1 - filter: !!null + select: + - control.brake + - control.throttle + - control.steer + - state.velocity.value + - state.acceleration.value + + filter: | + `control.throttle` > 0.5 + + transforms: + - _target_: rbyte.io.table.transforms.FpsResampler + source_fps: 20 + target_fps: 30 + #@ end + +sample_builder: + _target_: rbyte.sample.builder.GreedySampleTableBuilder + index_column: frame_idx + length: 1 + stride: 1 + min_step: 1 + filter: !!null diff --git a/examples/config_templates/dataset/mcap_protobuf.yaml b/examples/config_templates/dataset/mcap_protobuf.yaml new file mode 100644 index 0000000..f79b5e6 --- /dev/null +++ b/examples/config_templates/dataset/mcap_protobuf.yaml @@ -0,0 +1,111 @@ +#! https://github.com/foxglove/nuscenes2mcap + +#@yaml/text-templated-strings + +#@ inputs = [ +#@ 'NuScenes-v1.0-mini-scene-0103', +#@ ] + +#@ cameras = [ +#@ 'CAM_FRONT', +#@ 'CAM_FRONT_LEFT', +#@ 'CAM_FRONT_RIGHT', +#@ ] +--- +_target_: rbyte.Dataset +_convert_: all +_recursive_: false +inputs: + #@ for input_id in inputs: + (@=input_id@): + frame: + #@ for source_id in cameras: + (@=source_id@): + index_column: /(@=source_id@)/image_rect_compressed/frame_idx + reader: + _target_: rbyte.io.frame.mcap.McapFrameReader + path: "${data_dir}/(@=input_id@).mcap" + topic: /(@=source_id@)/image_rect_compressed + message_decoder_factory: + _target_: mcap_protobuf.decoder.DecoderFactory + frame_decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true + #@ end + + table: + path: "${data_dir}/(@=input_id@).mcap" + builder: + _target_: rbyte.io.table.TableBuilder + reader: + _target_: rbyte.io.table.mcap.McapProtobufTableReader + _recursive_: false + _convert_: all + fields: + #@ for camera in cameras: + /(@=camera@)/image_rect_compressed: + log_time: + _target_: polars.Datetime + time_unit: ns + #@ end + + /gps: + log_time: + _target_: polars.Datetime + time_unit: ns + + latitude: polars.Float64 + longitude: polars.Float64 + altitude: polars.Float64 + + merger: + _target_: rbyte.io.table.TableMerger + separator: / + merge: + /(@=cameras[0]@)/image_rect_compressed: + log_time: + method: ref + + #@ for camera in cameras[1:]: + /(@=camera@)/image_rect_compressed: + log_time: + method: ref + frame_idx: + method: asof + tolerance: 10ms + strategy: nearest + #@ end + + /gps: + log_time: + method: ref + latitude: + method: asof + tolerance: 1000ms + strategy: nearest + longitude: + method: asof + tolerance: 1000ms + strategy: nearest + altitude: + method: asof + tolerance: 1000ms + strategy: nearest + + filter: !!null + cache: + _target_: rbyte.utils.dataframe.DataframeDiskCache + directory: /tmp/rbyte-cache + size_limit: 1GiB + #@ end + +sample_builder: + _target_: rbyte.sample.builder.GreedySampleTableBuilder + index_column: /(@=cameras[0]@)/image_rect_compressed/frame_idx + length: 1 + stride: 1 + min_step: 1 + filter: !!null diff --git a/examples/config_templates/dataset/yaak.yaml b/examples/config_templates/dataset/yaak.yaml index db46453..8e3b953 100644 --- a/examples/config_templates/dataset/yaak.yaml +++ b/examples/config_templates/dataset/yaak.yaml @@ -1,130 +1,115 @@ ---- +#@yaml/text-templated-strings + #@ drives = [ -#@ 'Niro096-HQ/2023-01-11--13-47-36', -#@ 'Niro101-HQ/2022-12-25--09-58-33', -#@ 'Niro101-HQ/2022-12-30--09-23-51', -#@ 'Niro101-HQ/2022-12-31--09-57-24', -#@ 'Niro101-HQ/2023-01-01--09-31-43', -#@ 'Niro101-HQ/2023-01-01--12-01-47', -#@ 'Niro101-HQ/2023-03-27--14-58-37', -#@ 'Niro101-HQ/2023-04-02--10-08-57', -#@ 'Niro101-HQ/2023-04-02--13-27-47', -#@ 'Niro101-HQ/2023-04-06--14-43-08', +#@ 'Niro102-HQ/2023-05-08--13-59-22', #@ ] -#@ cameras = ['cam_front_left', 'cam_left_forward', 'cam_right_forward'] - +#@ cameras = [ +#@ 'cam_front_left', +#@ 'cam_left_forward', +#@ 'cam_right_forward', +#@ ] +--- _target_: rbyte.Dataset _recursive_: false -config: - inputs: - #@ for drive in drives: - - id: #@ drive - sources: - frame: - #@ for camera in cameras: - - id: #@ camera - index_column: "image.${.id}.frame_idx" - reader: - _target_: rbyte.io.frame.jpeg.JpegFrameReader - path: "${data_dir}/${.....id}/frames/${..id}.defish.mp4/576x324/{:09d}.jpg" +_convert_: all +inputs: + #@ for input_id in drives: + (@=input_id@): + frame: + #@ for source_id in cameras: + (@=source_id@): + index_column: "ImageMetadata.(@=source_id@).frame_idx" + reader: + _target_: rbyte.io.frame.DirectoryFrameReader + path: "${data_dir}/(@=input_id@)/frames/(@=source_id@).defish.mp4/576x324/{:09d}.jpg" + frame_decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true + #@ end + + table: + path: ${data_dir}/(@=input_id@)/metadata.log + builder: + _target_: rbyte.io.table.TableBuilder + reader: + _target_: rbyte.io.table.yaak.YaakMetadataTableReader + _recursive_: false + _convert_: all + fields: + rbyte.io.table.yaak.proto.sensor_pb2.ImageMetadata: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + frame_idx: polars.UInt32 + camera_name: + _target_: polars.Enum + categories: + - cam_front_center + - cam_front_left + - cam_front_right + - cam_left_forward + - cam_right_forward + - cam_left_backward + - cam_right_backward + - cam_rear + + rbyte.io.table.yaak.proto.can_pb2.VehicleMotion: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + speed: polars.Float32 + gear: + _target_: polars.Enum + categories: ["0", "1", "2", "3"] + + merger: + _target_: rbyte.io.table.TableMerger + separator: "." + merge: + ImageMetadata.(@=cameras[0]@): + time_stamp: + method: ref + + #@ for camera in cameras[1:]: + ImageMetadata.(@=camera@): + time_stamp: + method: ref + + frame_idx: + method: asof + tolerance: 10ms + strategy: nearest #@ end - table: - path: ${data_dir}/${...id}/metadata.log - builder: - _target_: rbyte.io.table.yaak.YaakMetadataTableBuilder - _recursive_: false - config: - cameras: #@ cameras - select: - - message: - type: rbyte.io.table.yaak.proto.sensor_pb2.ImageMetadata - alias: image - fields: - - name: time_stamp - - - name: frame_idx - merge: - method: asof - tolerance: 30ms - - - name: camera_name - partition: true - - - message: - type: rbyte.io.table.yaak.proto.can_pb2.VehicleState - alias: vehicle_state - fields: - - name: time_stamp - - - name: turn_signal - merge: - method: asof - tolerance: 100ms - - - message: - type: rbyte.io.table.yaak.proto.can_pb2.VehicleMotion - alias: vehicle_motion - fields: - - name: time_stamp - - - name: speed - merge: - method: interp - - - name: gas_pedal_normalized - merge: - method: interp - - - name: brake_pedal_normalized - merge: - method: interp - - - name: steering_angle_normalized - merge: - method: interp - - - name: gear - merge: - method: asof - tolerance: 100ms - - - message: - type: rbyte.io.table.yaak.proto.sensor_pb2.Gnss - alias: gnss - fields: - - name: time_stamp - - - name: heading - merge: - method: asof - tolerance: 100ms - - merge: - reference: - key: - [ - rbyte.io.table.yaak.proto.sensor_pb2.ImageMetadata, - cam_front_left, - ] - column: time_stamp - - filter: | - `vehicle_motion.gear` == 3 - - cache: - _target_: rbyte.utils.DataframeDiskCache - directory: /tmp/rbyte-cache - size_limit: "10 GiB" - #@ end - - samples: - builder: - _target_: rbyte.sample.builder.GreedySampleTableBuilder - config: - index_column: image.cam_front_left.frame_idx - length: 1 - stride: 1 - min_step: 1 - filter: !!null + VehicleMotion: + time_stamp: + method: ref + speed: + method: interp + gear: + method: asof + tolerance: 100ms + + filter: | + `VehicleMotion.gear` == '3' + + cache: + _target_: rbyte.utils.dataframe.DataframeDiskCache + directory: /tmp/rbyte-cache + size_limit: 1GiB + #@ end + +sample_builder: + _target_: rbyte.sample.builder.GreedySampleTableBuilder + index_column: ImageMetadata.(@=cameras[0]@).frame_idx + length: 6 + stride: 1 + min_step: 6 + filter: | + array_mean(`VehicleMotion.speed`) > 47 diff --git a/examples/config_templates/frame_reader/directory.yaml b/examples/config_templates/frame_reader/directory.yaml new file mode 100644 index 0000000..7e02999 --- /dev/null +++ b/examples/config_templates/frame_reader/directory.yaml @@ -0,0 +1,10 @@ +--- +_target_: rbyte.io.frame.DirectoryFrameReader +_recursive_: true +path: ??? +frame_decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true diff --git a/examples/config_templates/frame_reader/mcap.yaml b/examples/config_templates/frame_reader/mcap.yaml new file mode 100644 index 0000000..6e3e886 --- /dev/null +++ b/examples/config_templates/frame_reader/mcap.yaml @@ -0,0 +1,16 @@ +--- +_target_: rbyte.io.frame.mcap.McapFrameReader +_recursive_: true +path: ??? +topic: ??? +message_decoder_factory: + _target_: mcap_protobuf.decoder.DecoderFactory + +frame_decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true + +validate_crcs: false diff --git a/examples/config_templates/logger/rerun/carla.yaml b/examples/config_templates/logger/rerun/carla.yaml new file mode 100644 index 0000000..858d31f --- /dev/null +++ b/examples/config_templates/logger/rerun/carla.yaml @@ -0,0 +1,23 @@ +#@yaml/text-templated-strings + +#@ cameras = [ +#@ 'cam_front_left', +#@ ] + +--- +_target_: rbyte.viz.loggers.RerunLogger +_recursive_: true +_convert_: all +schema: + frame: + #@ for camera in cameras: + (@=camera@): rerun.components.ImageBufferBatch + #@ end + + table: + frame_idx: rerun.TimeSequenceColumn + control.brake: rerun.components.ScalarBatch + control.steer: rerun.components.ScalarBatch + control.throttle: rerun.components.ScalarBatch + state.acceleration.value: rerun.components.ScalarBatch + state.velocity.value: rerun.components.ScalarBatch diff --git a/examples/config_templates/logger/rerun/mcap_protobuf.yaml b/examples/config_templates/logger/rerun/mcap_protobuf.yaml new file mode 100644 index 0000000..820391b --- /dev/null +++ b/examples/config_templates/logger/rerun/mcap_protobuf.yaml @@ -0,0 +1,24 @@ +#@yaml/text-templated-strings + +#@ cameras = [ +#@ 'CAM_FRONT', +#@ 'CAM_FRONT_LEFT', +#@ 'CAM_FRONT_RIGHT', +#@ ] +--- +_target_: rbyte.viz.loggers.RerunLogger +schema: + frame: + #@ for camera in cameras: + (@=camera@): rerun.components.ImageBufferBatch + #@ end + + table: + /(@=cameras[0]@)/image_rect_compressed/log_time: rerun.TimeNanosColumn + #@ for camera in cameras: + /(@=camera@)/image_rect_compressed/frame_idx: rerun.TimeSequenceColumn + #@ end + /gps/log_time: rerun.TimeNanosColumn + /gps/latitude: rerun.components.ScalarBatch + /gps/longitude: rerun.components.ScalarBatch + /gps/altitude: rerun.components.ScalarBatch diff --git a/examples/config_templates/logger/rerun/yaak.yaml b/examples/config_templates/logger/rerun/yaak.yaml new file mode 100644 index 0000000..845d287 --- /dev/null +++ b/examples/config_templates/logger/rerun/yaak.yaml @@ -0,0 +1,23 @@ +#@yaml/text-templated-strings + +#@ cameras = [ +#@ 'cam_front_left', +#@ 'cam_left_forward', +#@ 'cam_right_forward', +#@ ] + +--- +_target_: rbyte.viz.loggers.RerunLogger +schema: + frame: + #@ for camera in cameras: + (@=camera@): rerun.components.ImageBufferBatch + #@ end + + table: + #@ for camera in cameras: + ImageMetadata.(@=camera@).frame_idx: rerun.TimeSequenceColumn + ImageMetadata.(@=camera@).time_stamp: rerun.TimeNanosColumn + #@ end + VehicleMotion.time_stamp: rerun.TimeNanosColumn + VehicleMotion.speed: rerun.components.ScalarBatch diff --git a/examples/config_templates/logger/rerun_carla.yaml b/examples/config_templates/logger/rerun_carla.yaml deleted file mode 100644 index d8b0d02..0000000 --- a/examples/config_templates/logger/rerun_carla.yaml +++ /dev/null @@ -1,16 +0,0 @@ ---- -_target_: rbyte.viz.loggers.RerunLogger -_recursive_: true -_convert_: all -config: - log_schema: - frame: - cam_front_left: rerun.components.ImageBufferBatch - - table: - frame_idx: rerun.TimeSequenceColumn - control.brake: rerun.components.ScalarBatch - control.steer: rerun.components.ScalarBatch - control.throttle: rerun.components.ScalarBatch - state.acceleration.value: rerun.components.ScalarBatch - state.velocity.value: rerun.components.ScalarBatch diff --git a/examples/config_templates/logger/rerun_yaak.yaml b/examples/config_templates/logger/rerun_yaak.yaml deleted file mode 100644 index 307c8a1..0000000 --- a/examples/config_templates/logger/rerun_yaak.yaml +++ /dev/null @@ -1,41 +0,0 @@ ---- -_target_: rbyte.viz.loggers.RerunLogger -_convert_: all -_recursive_: true -config: - transforms: - - select: - - [table, image.cam_front_left.time_stamp] - - [table, image.cam_left_forward.time_stamp] - - [table, image.cam_right_forward.time_stamp] - - [table, vehicle_motion.time_stamp] - - [table, vehicle_state.time_stamp] - - [table, gnss.time_stamp] - apply: - _target_: torch.mul - _partial_: true - other: 1000 - - log_schema: - frame: - cam_front_left: rerun.components.ImageBufferBatch - cam_left_forward: rerun.components.ImageBufferBatch - cam_right_forward: rerun.components.ImageBufferBatch - - table: - gnss.heading: rerun.components.ScalarBatch - gnss.time_stamp: rerun.TimeNanosColumn - image.cam_front_left.frame_idx: rerun.TimeSequenceColumn - image.cam_front_left.time_stamp: rerun.TimeNanosColumn - image.cam_left_forward.frame_idx: rerun.TimeSequenceColumn - image.cam_left_forward.time_stamp: rerun.TimeNanosColumn - image.cam_right_forward.frame_idx: rerun.TimeSequenceColumn - image.cam_right_forward.time_stamp: rerun.TimeNanosColumn - vehicle_motion.brake_pedal_normalized: rerun.components.ScalarBatch - vehicle_motion.gas_pedal_normalized: rerun.components.ScalarBatch - vehicle_motion.gear: rerun.components.ScalarBatch - vehicle_motion.speed: rerun.components.ScalarBatch - vehicle_motion.steering_angle_normalized: rerun.components.ScalarBatch - vehicle_motion.time_stamp: rerun.TimeNanosColumn - vehicle_state.time_stamp: rerun.TimeNanosColumn - vehicle_state.turn_signal: rerun.components.ScalarBatch diff --git a/examples/config_templates/read_frames.yaml b/examples/config_templates/read_frames.yaml new file mode 100644 index 0000000..2c1ac51 --- /dev/null +++ b/examples/config_templates/read_frames.yaml @@ -0,0 +1,12 @@ +--- +defaults: + - frame_reader: !!null + - _self_ + +batch_size: 1 +entity_path: ??? + +hydra: + output_subdir: !!null + run: + dir: . diff --git a/examples/config_templates/table.yaml b/examples/config_templates/table.yaml deleted file mode 100644 index 0bc94d9..0000000 --- a/examples/config_templates/table.yaml +++ /dev/null @@ -1,85 +0,0 @@ ---- -path: ??? -builder: - _target_: rbyte.io.table.yaak.YaakMetadataTableBuilder - _recursive_: false - config: - cameras: [cam_front_left] - select: - - message: - type: rbyte.io.table.yaak.proto.sensor_pb2.ImageMetadata - alias: image - fields: - - name: time_stamp - - - name: frame_idx - merge: - method: asof - tolerance: 30ms - - - name: camera_name - partition: true - - - message: - type: rbyte.io.table.yaak.proto.can_pb2.VehicleState - alias: vehicle_state - fields: - - name: time_stamp - - - name: turn_signal - merge: - method: asof - tolerance: 100ms - - - message: - type: rbyte.io.table.yaak.proto.can_pb2.VehicleMotion - alias: vehicle_motion - fields: - - name: time_stamp - - - name: speed - merge: - method: interp - - - name: gas_pedal_normalized - merge: - method: interp - - - name: brake_pedal_normalized - merge: - method: interp - - - name: steering_angle_normalized - merge: - method: interp - - - name: gear - merge: - method: asof - tolerance: 100ms - - - message: - type: rbyte.io.table.yaak.proto.sensor_pb2.Gnss - alias: gnss - fields: - - name: time_stamp - - - name: heading - merge: - method: asof - tolerance: 100ms - - merge: - reference: - key: - - rbyte.io.table.yaak.proto.sensor_pb2.ImageMetadata - - cam_front_left - column: time_stamp - - filter: !!null - cache: !!null - -writer: - _target_: polars.DataFrame.write_csv - _partial_: true - file: ??? diff --git a/examples/config_templates/table_builder/carla.yaml b/examples/config_templates/table_builder/carla.yaml new file mode 100644 index 0000000..6ee4a92 --- /dev/null +++ b/examples/config_templates/table_builder/carla.yaml @@ -0,0 +1,18 @@ +--- +_target_: rbyte.io.table.carla.CarlaRecordsTableBuilder +_convert_: all +index_column: frame_idx +select: + - control.brake + - control.throttle + - control.steer + - state.velocity.value + - state.acceleration.value + +filter: | + `control.throttle` > 0.5 + +transforms: + - _target_: rbyte.io.table.transforms.FpsResampler + source_fps: 20 + target_fps: 30 diff --git a/examples/config_templates/table_builder/mcap_protobuf.yaml b/examples/config_templates/table_builder/mcap_protobuf.yaml new file mode 100644 index 0000000..d05d421 --- /dev/null +++ b/examples/config_templates/table_builder/mcap_protobuf.yaml @@ -0,0 +1,69 @@ +#@yaml/text-templated-strings + +#@ cameras = [ +#@ 'CAM_FRONT', +#@ 'CAM_FRONT_LEFT', +#@ 'CAM_FRONT_RIGHT', +#@ ] +--- +_target_: rbyte.io.table.TableBuilder +reader: + _target_: rbyte.io.table.mcap.McapProtobufTableReader + _recursive_: false + _convert_: all + fields: + #@ for camera in cameras: + /(@=camera@)/image_rect_compressed: + log_time: + _target_: polars.Datetime + time_unit: ns + #@ end + + /gps: + log_time: + _target_: polars.Datetime + time_unit: ns + + latitude: polars.Float64 + longitude: polars.Float64 + altitude: polars.Float64 + +merger: + _target_: rbyte.io.table.TableMerger + separator: / + merge: + /(@=cameras[0]@)/image_rect_compressed: + log_time: + method: ref + + #@ for camera in cameras[1:]: + /(@=camera@)/image_rect_compressed: + log_time: + method: ref + frame_idx: + method: asof + tolerance: 10ms + strategy: nearest + #@ end + + /gps: + log_time: + method: ref + latitude: + method: asof + tolerance: 1000ms + strategy: nearest + longitude: + method: asof + tolerance: 1000ms + strategy: nearest + altitude: + method: asof + tolerance: 1000ms + strategy: nearest + +filter: !!null +cache: + _target_: rbyte.utils.dataframe.DataframeDiskCache + directory: /tmp/rbyte-cache + size_limit: 1GiB diff --git a/examples/config_templates/table_builder/mcap_ros2.yaml b/examples/config_templates/table_builder/mcap_ros2.yaml new file mode 100644 index 0000000..3137b76 --- /dev/null +++ b/examples/config_templates/table_builder/mcap_ros2.yaml @@ -0,0 +1,62 @@ +#! https://jkk-research.github.io/dataset/jkk_dataset_02 + +#@yaml/text-templated-strings +--- +_target_: rbyte.io.table.TableBuilder +reader: + _target_: rbyte.io.table.mcap.McapRos2TableReader + _recursive_: false + _convert_: all + fields: + /nissan/vehicle_speed: + log_time: + _target_: polars.Datetime + time_unit: ns + + data: + + /nissan/gps/duro/current_pose: + log_time: + _target_: polars.Datetime + time_unit: ns + + pose.position.x: + pose.position.y: + pose.position.z: + pose.orientation.x: + pose.orientation.y: + pose.orientation.z: + pose.orientation.w: + +merger: + _target_: rbyte.io.table.TableMerger + separator: / + merge: + /nissan/vehicle_speed: + log_time: + method: ref + + /nissan/gps/duro/current_pose: + log_time: + method: ref + pose.position.x: + method: interp + pose.position.y: + method: interp + pose.position.z: + method: interp + pose.orientation.x: + method: interp + pose.orientation.y: + method: interp + pose.orientation.z: + method: interp + pose.orientation.w: + method: interp + +filter: !!null + +cache: + _target_: rbyte.utils.dataframe.DataframeDiskCache + directory: /tmp/rbyte-cache + size_limit: 1GiB diff --git a/examples/config_templates/table_builder/yaak.yaml b/examples/config_templates/table_builder/yaak.yaml new file mode 100644 index 0000000..516dce6 --- /dev/null +++ b/examples/config_templates/table_builder/yaak.yaml @@ -0,0 +1,77 @@ +#@yaml/text-templated-strings + +#@ cameras = [ +#@ 'cam_front_left', +#@ 'cam_left_forward', +#@ 'cam_right_forward', +#@ ] +--- +_target_: rbyte.io.table.TableBuilder +reader: + _target_: rbyte.io.table.yaak.YaakMetadataTableReader + _recursive_: false + _convert_: all + fields: + rbyte.io.table.yaak.proto.sensor_pb2.ImageMetadata: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + frame_idx: polars.UInt32 + camera_name: + _target_: polars.Enum + categories: + - cam_front_center + - cam_front_left + - cam_front_right + - cam_left_forward + - cam_right_forward + - cam_left_backward + - cam_right_backward + - cam_rear + + rbyte.io.table.yaak.proto.can_pb2.VehicleMotion: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + speed: polars.Float32 + gear: + _target_: polars.Enum + categories: ["0", "1", "2", "3"] + +merger: + _target_: rbyte.io.table.TableMerger + separator: "." + merge: + ImageMetadata.(@=cameras[0]@): + time_stamp: + method: ref + + #@ for camera in cameras[1:]: + ImageMetadata.(@=camera@): + time_stamp: + method: ref + + frame_idx: + method: asof + tolerance: 10ms + strategy: nearest + #@ end + + VehicleMotion: + time_stamp: + method: ref + speed: + method: interp + gear: + method: asof + tolerance: 100ms + +filter: | + `VehicleMotion.gear` == '3' + +cache: + _target_: rbyte.utils.dataframe.DataframeDiskCache + directory: /tmp/rbyte-cache + size_limit: 1GiB diff --git a/examples/config_templates/visualize.yaml b/examples/config_templates/visualize.yaml index e5185a4..d0fdabd 100644 --- a/examples/config_templates/visualize.yaml +++ b/examples/config_templates/visualize.yaml @@ -1,7 +1,7 @@ --- defaults: - - /dataset: yaak - - /logger: console + - /dataset: !!null + - /logger: !!null - _self_ data_dir: ??? @@ -9,7 +9,7 @@ dataloader: _target_: torch.utils.data.DataLoader dataset: ${dataset} shuffle: false - batch_size: 1 + batch_size: 4 collate_fn: _target_: rbyte.utils.dataloader.collate_identity _partial_: true @@ -18,3 +18,8 @@ dataloader: pin_memory: false persistent_workers: true multiprocessing_context: forkserver + +hydra: + output_subdir: !!null + run: + dir: . diff --git a/justfile b/justfile index acd44e0..2efdc64 100644 --- a/justfile +++ b/justfile @@ -1,4 +1,3 @@ -export HYDRA_FULL_ERROR := "1" export PYTHONOPTIMIZE := "1" export HATCH_BUILD_CLEAN := "1" @@ -20,8 +19,8 @@ build: build-protos: uvx --from hatch hatch build --clean --hooks-only --target sdist -pre-commit: - uvx pre-commit run --all-files --color=always +pre-commit *ARGS: + uvx pre-commit run --all-files --color=always {{ ARGS }} generate-example-config: ytt --ignore-unknown-comments \ @@ -35,13 +34,26 @@ visualize *ARGS: generate-example-config uv run rbyte-visualize \ --config-path {{ justfile_directory() }}/examples/config \ --config-name visualize.yaml \ + hydra/hydra_logging=disabled \ + hydra/job_logging=disabled \ {{ ARGS }} [group('scripts')] build-table *ARGS: generate-example-config uv run rbyte-build-table \ --config-path {{ justfile_directory() }}/examples/config \ - --config-name table.yaml \ + --config-name build_table.yaml \ + hydra/hydra_logging=disabled \ + hydra/job_logging=disabled \ + {{ ARGS }} + +[group('scripts')] +read-frames *ARGS: generate-example-config + uv run rbyte-read-frames \ + --config-path {{ justfile_directory() }}/examples/config \ + --config-name read_frames.yaml \ + hydra/hydra_logging=disabled \ + hydra/job_logging=disabled \ {{ ARGS }} # rerun server and viewer diff --git a/pyproject.toml b/pyproject.toml index 9908ef7..622e498 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,22 +1,24 @@ [project] name = "rbyte" -version = "0.1.0" +version = "0.2.0" description = "Multimodal dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] dependencies = [ - "torch>=2.4.0", - "tensordict @ git+https://github.com/pytorch/tensordict.git", - "more-itertools", - "protobuf", - "pydantic>=2.8.2", - "diskcache", - "polars>=1.5.0", - "parse", - "hydra-core", - "simplejpeg", - "jaxtyping", - "structlog", + "tensordict @ git+https://github.com/pytorch/tensordict.git@85b6b81", + "torch>=2.4.1", + "polars>=1.8.0", + "pydantic>=2.9.2", + "more-itertools>=10.5.0", + "hydra-core>=1.3.2", + "optree>=0.12.1", + "cachetools>=5.5.0", + "diskcache>=5.6.3", + "jaxtyping>=0.2.34", + "parse>=1.20.2", + "structlog>=24.4.0", + "xxhash>=3.5.0", + "tqdm>=4.66.5", ] readme = "README.md" requires-python = ">=3.12" @@ -32,11 +34,27 @@ classifiers = [ ] [project.optional-dependencies] -build = ["hatchling>=1.25.0", "grpcio-tools==1.62.2", "protoletariat==3.2.19"] -visualize = ["rerun-sdk>=0.18.0", "tqdm"] +build = [ + "hatchling>=1.25.0", + "grpcio-tools", + "protoletariat==3.2.19", + "protobuf", +] +visualize = ["rerun-sdk>=0.18.2"] +mcap = [ + "mcap>=1.1.1", + "mcap-protobuf-support>=0.5.1", + "mcap-ros2-support>=0.5.3", + "protobuf", + "ptars>=0.0.2rc2", + "pyarrow-stubs>=17.6", +] +yaak = ["protobuf", "ptars>=0.0.2rc2"] +jpeg = ["simplejpeg>=1.7.6"] [project.scripts] rbyte-build-table = 'rbyte.scripts.build_table:main' +rbyte-read-frames = 'rbyte.scripts.read_frames:main' rbyte-visualize = 'rbyte.scripts.visualize:main' [build-system] @@ -45,11 +63,13 @@ build-backend = "hatchling.build" [tool.uv] dev-dependencies = [ - "wat-inspector>=0.3.3", - "lovely-tensors>=0.1.16", + "wat-inspector>=0.4.0", + "lovely-tensors>=0.1.17", "pudb>=2024.1.2", ] +[tool.uv.sources] + [tool.hatch.metadata] allow-direct-references = true diff --git a/src/rbyte/__init__.py b/src/rbyte/__init__.py index f5ff626..fd81618 100644 --- a/src/rbyte/__init__.py +++ b/src/rbyte/__init__.py @@ -1,3 +1,7 @@ -from .dataset import Dataset, DatasetConfig +from importlib.metadata import version -__all__ = ["Dataset", "DatasetConfig"] +from .dataset import Dataset + +__version__ = version(__package__ or __name__) + +__all__ = ["Dataset", "__version__"] diff --git a/src/rbyte/batch/batch.py b/src/rbyte/batch/batch.py index 33704f9..9592708 100644 --- a/src/rbyte/batch/batch.py +++ b/src/rbyte/batch/batch.py @@ -3,14 +3,14 @@ from torch import Tensor -@tensorclass # pyright: ignore[reportArgumentType] +@tensorclass # pyright: ignore[reportUntypedClassDecorator, reportArgumentType, reportCallIssue] class BatchMeta: sample_idx: Int[Tensor, "b 1"] # pyright: ignore[reportUninitializedInstanceVariable] - input_id: NonTensorData # pyright: ignore[reportGeneralTypeIssues, reportUninitializedInstanceVariable] + input_id: NonTensorData # pyright: ignore[reportUninitializedInstanceVariable] -@tensorclass # pyright: ignore[reportArgumentType] +@tensorclass # pyright: ignore[reportUntypedClassDecorator, reportArgumentType, reportCallIssue] class Batch: - meta: BatchMeta # pyright: ignore[reportUninitializedInstanceVariable, reportGeneralTypeIssues] + meta: BatchMeta # pyright: ignore[reportUninitializedInstanceVariable] frame: TensorDict # pyright: ignore[reportUninitializedInstanceVariable] table: TensorDict # pyright: ignore[reportUninitializedInstanceVariable] diff --git a/src/rbyte/config/base.py b/src/rbyte/config/base.py index 4b6ecc7..589a12e 100644 --- a/src/rbyte/config/base.py +++ b/src/rbyte/config/base.py @@ -7,19 +7,19 @@ class BaseModel(_BaseModel): - class Config: - frozen = True - extra = "forbid" - validate_assignment = True + model_config = ConfigDict( + frozen=True, + extra="forbid", + validate_assignment=True, + ignored_types=(cached_property,), + ) T = TypeVar("T") class HydraConfig[T](BaseModel): - model_config = ConfigDict( - frozen=True, extra="allow", ignored_types=(cached_property,) - ) + model_config = ConfigDict(extra="allow") target: ImportString[type[T]] = Field(alias="_target_") diff --git a/src/rbyte/dataset/dataset.py b/src/rbyte/dataset.py similarity index 52% rename from src/rbyte/dataset/dataset.py rename to src/rbyte/dataset.py index f207b6a..8c1ddac 100644 --- a/src/rbyte/dataset/dataset.py +++ b/src/rbyte/dataset.py @@ -1,25 +1,46 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from enum import StrEnum, unique +from functools import cache +from typing import Annotated +import more_itertools as mit import polars as pl import torch -from hydra.utils import instantiate -from polars._utils.getitem import ( - _select_rows_by_index, # pyright: ignore[reportPrivateUsage] # noqa: PLC2701 -) +from pydantic import ConfigDict, Field, FilePath, StringConstraints, validate_call from structlog import get_logger from structlog.contextvars import bound_contextvars from tensordict import TensorDict from torch.utils.data import Dataset as TorchDataset from rbyte.batch import Batch, BatchMeta - -from .config import DatasetConfig, FrameSourceConfig, SourcesConfig, TableSourceConfig +from rbyte.config.base import BaseModel, HydraConfig +from rbyte.io.frame.base import FrameReader +from rbyte.io.table.base import TableBuilderBase +from rbyte.sample.base import SampleTableBuilder __all__ = ["Dataset"] logger = get_logger(__name__) +type Id = Annotated[ + str, StringConstraints(strip_whitespace=True, pattern=r"^[\x00-\x7F]+$") +] + + +class FrameSourceConfig(BaseModel): + reader: HydraConfig[FrameReader] + index_column: str + + +class TableSourceConfig(BaseModel): + path: FilePath + builder: HydraConfig[TableBuilderBase] + + +class SourcesConfig(BaseModel): + frame: Mapping[Id, FrameSourceConfig] = Field(min_length=1) + table: TableSourceConfig | None = None + @unique class Column(StrEnum): @@ -32,22 +53,26 @@ class Column(StrEnum): class Dataset(TorchDataset[TensorDict]): - def __init__(self, config: object) -> None: - super().__init__() - - config = DatasetConfig.model_validate(config) + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__( + self, + inputs: Annotated[Mapping[Id, SourcesConfig], Field(min_length=1)], + sample_builder: HydraConfig[SampleTableBuilder], + ) -> None: + logger.debug("initializing dataset") - samples: dict[str, pl.LazyFrame] = {} - sample_builder = config.samples.builder.instantiate() - for _input in config.inputs: - with bound_contextvars(input=_input.id): - table = self._build_table(_input.sources) - samples[_input.id] = sample_builder.build(table) + super().__init__() + _sample_builder = sample_builder.instantiate() + samples: Mapping[str, pl.LazyFrame] = {} + for input_id, input_cfg in inputs.items(): + with bound_contextvars(input_id=input_id): + table = self._build_table(input_cfg) + samples[input_id] = _sample_builder.build(table) logger.debug( - "processed", + "built samples", rows=table.select(pl.len()).collect().item(), - samples=samples[_input.id].select(pl.len()).collect().item(), + samples=samples[input_id].select(pl.len()).collect().item(), ) input_id_enum = pl.Enum(sorted(samples)) @@ -73,13 +98,19 @@ def __init__(self, config: object) -> None: pl.LazyFrame( [ { - Column.input_id: _input.id, + Column.input_id: input_id, (k := "source"): [ - source.model_dump(by_alias=True) - for source in _input.sources.frame + source_cfg.model_dump(exclude={"reader"}) + | { + "id": source_id, + "reader": source_cfg.reader.model_dump_json( + by_alias=True + ), + } + for source_id, source_cfg in input_cfg.frame.items() ], } - for _input in config.inputs + for input_id, input_cfg in inputs.items() ], schema_overrides={Column.input_id: input_id_enum}, ) @@ -92,17 +123,19 @@ def __init__(self, config: object) -> None: @classmethod def _build_table(cls, sources: SourcesConfig) -> pl.LazyFrame: + logger.debug("building table") + match sources: - case SourcesConfig( - frame=[FrameSourceConfig(reader=reader, index_column=index_column)], - table=None, - ): - frame_reader = reader.instantiate() + case SourcesConfig(frame=frame_sources, table=None) if len( + frame_sources + ) == 1: + frame_source = mit.one(frame_sources.values()) + frame_reader = frame_source.reader.instantiate() frame_idxs = pl.Series( - name=index_column, - values=frame_reader.get_available_indices(), + name=frame_source.index_column, + values=frame_reader.get_available_indexes(), dtype=pl.UInt32, - ) + ).sort() return pl.LazyFrame(frame_idxs) @@ -110,21 +143,23 @@ def _build_table(cls, sources: SourcesConfig) -> pl.LazyFrame: frame=frame_sources, table=TableSourceConfig(path=path, builder=builder) ): table_builder = builder.instantiate() - table_df = table_builder.build(path).lazy() - schema = table_df.collect_schema() + table = table_builder.build(path).lazy() + schema = table.collect_schema() - for frame_source in frame_sources: + for frame_source_id, frame_source in frame_sources.items(): + logger.debug("pruning table", frame_source=frame_source_id) frame_reader = frame_source.reader.instantiate() frame_idxs = pl.Series( name=(col := frame_source.index_column), - values=frame_reader.get_available_indices(), + values=frame_reader.get_available_indexes(), dtype=schema[col], - ) - table_df = table_df.join( + ).sort() + + table = table.join( pl.LazyFrame(frame_idxs), on=frame_idxs.name, how="semi" ) - return table_df + return table case _: raise NotImplementedError @@ -137,14 +172,18 @@ def samples(self) -> pl.DataFrame: def frame_sources(self) -> pl.DataFrame: return self._frame_sources - def __getitems__(self, idxs: Sequence[int]) -> Batch: # pyright: ignore[reportGeneralTypeIssues, reportUnknownParameterType] # noqa: PLW3201 - samples = _select_rows_by_index( - self.samples, pl.Series(values=idxs, dtype=pl.UInt32) - ) + @cache # noqa: B019 + def _get_frame_reader(self, reader_json: str) -> FrameReader: # noqa: PLR6301 + return HydraConfig[FrameReader].model_validate_json(reader_json).instantiate() + + def __getitems__(self, indexes: Sequence[int]) -> Batch: # noqa: PLW3201 + samples = self.samples[indexes] + batch_size = [samples.height] - meta = BatchMeta( # pyright: ignore[reportCallIssue, reportUnknownVariableType] + meta = BatchMeta( sample_idx=samples[Column.sample_idx].to_torch(), # pyright: ignore[reportCallIssue] input_id=samples[Column.input_id].to_list(), # pyright: ignore[reportCallIssue] + batch_size=batch_size, # pyright: ignore[reportCallIssue] ) frame_source_idx_cols = self._frame_sources[Column.source_index_column].unique() @@ -165,25 +204,24 @@ def __getitems__(self, idxs: Sequence[int]) -> Batch: # pyright: ignore[reportG frames = TensorDict( { row[Column.source_id]: torch.stack([ - instantiate(reader).read(frame_idxs) + self._get_frame_reader(reader).read(frame_idxs) for (reader, frame_idxs) in zip( row[Column.source_reader], row[Column.frame_idx], strict=True ) ]) for row in frame_sources.collect().iter_rows(named=True) }, - batch_size=[len(idxs)], + batch_size=batch_size, ) table = TensorDict( samples.select( # pyright: ignore[reportArgumentType] - pl.exclude(Column.sample_idx, Column.input_id) - .arr.to_list() - .to_physical() - ).to_dict() + pl.exclude(Column.sample_idx, Column.input_id).to_physical() + ).to_dict(as_series=False), + batch_size=batch_size, ) - return Batch(meta=meta, frame=frames, table=table).auto_batch_size_(1) # pyright: ignore[reportCallIssue, reportUnknownVariableType, reportUnknownMemberType] + return Batch(meta=meta, frame=frames, table=table, batch_size=batch_size) # pyright: ignore[reportCallIssue] def __len__(self) -> int: return len(self.samples) diff --git a/src/rbyte/dataset/__init__.py b/src/rbyte/dataset/__init__.py deleted file mode 100644 index 62c7e2a..0000000 --- a/src/rbyte/dataset/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .config import DatasetConfig -from .dataset import Dataset - -__all__ = ["Dataset", "DatasetConfig"] diff --git a/src/rbyte/dataset/config.py b/src/rbyte/dataset/config.py deleted file mode 100644 index 8584d4d..0000000 --- a/src/rbyte/dataset/config.py +++ /dev/null @@ -1,71 +0,0 @@ -from collections.abc import Mapping -from functools import cached_property -from operator import attrgetter -from typing import Annotated - -from more_itertools import all_unique -from pydantic import Field, FilePath, StringConstraints, computed_field, field_validator - -from rbyte.config.base import BaseModel, HydraConfig -from rbyte.io.frame.base import FrameReader -from rbyte.io.table.base import TableBuilder -from rbyte.sample.base import SampleTableBuilder - - -class FrameSourceConfig(BaseModel): - id: str - reader: HydraConfig[FrameReader] - index_column: str - - -class TableSourceConfig(BaseModel): - path: FilePath - builder: HydraConfig[TableBuilder] - - -class SourcesConfig(BaseModel): - frame: tuple[FrameSourceConfig, ...] = Field(min_length=1) - table: TableSourceConfig | None = None - - @field_validator("frame", mode="after") - @classmethod - def validate_frame_sources( - cls, sources: tuple[FrameSourceConfig, ...] - ) -> tuple[FrameSourceConfig, ...]: - if not all_unique(sources, key=attrgetter("id")): - msg = "frame source ids not unique" - raise ValueError(msg) - - return sources - - -class InputConfig(BaseModel): - id: Annotated[ - str, StringConstraints(strip_whitespace=True, pattern=r"^[\x00-\x7F]+$") - ] - sources: SourcesConfig - - -class SamplesConfig(BaseModel): - builder: HydraConfig[SampleTableBuilder] - - -class DatasetConfig(BaseModel): - inputs: tuple[InputConfig, ...] = Field(min_length=1) - samples: SamplesConfig - - @field_validator("inputs", mode="after") - @classmethod - def validate_inputs( - cls, inputs: tuple[InputConfig, ...] - ) -> tuple[InputConfig, ...]: - if not all_unique(inputs, key=attrgetter("id")): - msg = "input ids not unique" - raise ValueError(msg) - - return inputs - - @computed_field - @cached_property - def inputs_by_id(self) -> Mapping[str, InputConfig]: - return {x.id: x for x in self.inputs} diff --git a/src/rbyte/io/frame/__init__.py b/src/rbyte/io/frame/__init__.py index e69de29..ef1af94 100644 --- a/src/rbyte/io/frame/__init__.py +++ b/src/rbyte/io/frame/__init__.py @@ -0,0 +1,3 @@ +from .directory import DirectoryFrameReader + +__all__ = ["DirectoryFrameReader"] diff --git a/src/rbyte/io/frame/base.py b/src/rbyte/io/frame/base.py index d8d73c7..11cc562 100644 --- a/src/rbyte/io/frame/base.py +++ b/src/rbyte/io/frame/base.py @@ -7,5 +7,5 @@ @runtime_checkable class FrameReader(Protocol): - def read(self, idxs: Iterable[int]) -> Shaped[Tensor, "b h w c"]: ... - def get_available_indices(self) -> Sequence[int]: ... + def read(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]: ... + def get_available_indexes(self) -> Sequence[int]: ... diff --git a/src/rbyte/io/frame/directory/__init__.py b/src/rbyte/io/frame/directory/__init__.py new file mode 100644 index 0000000..08a85f4 --- /dev/null +++ b/src/rbyte/io/frame/directory/__init__.py @@ -0,0 +1,3 @@ +from .reader import DirectoryFrameReader + +__all__ = ["DirectoryFrameReader"] diff --git a/src/rbyte/io/frame/directory/reader.py b/src/rbyte/io/frame/directory/reader.py new file mode 100644 index 0000000..aabd624 --- /dev/null +++ b/src/rbyte/io/frame/directory/reader.py @@ -0,0 +1,47 @@ +from collections.abc import Callable, Iterable, Sequence +from functools import cached_property +from os import PathLike +from pathlib import Path +from typing import override + +import numpy.typing as npt +import parse +import torch +from jaxtyping import UInt8 +from pydantic import validate_call +from torch import Tensor + +from rbyte.io.frame.base import FrameReader + + +class DirectoryFrameReader(FrameReader): + @validate_call + def __init__( + self, path: PathLike[str], frame_decoder: Callable[[bytes], npt.ArrayLike] + ) -> None: + super().__init__() + + self._path = Path(path) + self._frame_decoder = frame_decoder + + @cached_property + def _path_posix(self) -> str: + return self._path.as_posix() + + def _decode(self, path: str) -> npt.ArrayLike: + with Path(path).open("rb") as f: + return self._frame_decoder(f.read()) + + @override + def read(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: + paths = map(self._path_posix.format, indexes) + frames_np = map(self._decode, paths) + frames_tch = map(torch.from_numpy, frames_np) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] + + return torch.stack(list(frames_tch)) + + @override + def get_available_indexes(self) -> Sequence[int]: + parser = parse.compile(self._path.name) # pyright: ignore[reportUnknownMemberType] + filenames = (path.name for path in self._path.parent.iterdir()) + return [res[0] for res in map(parser.parse, filenames) if res] # pyright: ignore[reportUnknownVariableType, reportIndexIssue, reportUnknownArgumentType, reportUnknownMemberType] diff --git a/src/rbyte/io/frame/jpeg/__init__.py b/src/rbyte/io/frame/jpeg/__init__.py deleted file mode 100644 index 751bfba..0000000 --- a/src/rbyte/io/frame/jpeg/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .reader import JpegFrameReader - -__all__ = ["JpegFrameReader"] diff --git a/src/rbyte/io/frame/jpeg/reader.py b/src/rbyte/io/frame/jpeg/reader.py deleted file mode 100644 index 5a1f5cc..0000000 --- a/src/rbyte/io/frame/jpeg/reader.py +++ /dev/null @@ -1,45 +0,0 @@ -from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property -from os import PathLike -from pathlib import Path -from typing import Any, override - -import parse -import torch -from jaxtyping import UInt8 -from simplejpeg import decode_jpeg # pyright: ignore[reportUnknownVariableType] -from torch import Tensor - -from rbyte.io.frame.base import FrameReader - - -class JpegFrameReader(FrameReader): - def __init__( - self, path: PathLike[str], decode_kwargs: Mapping[str, Any] | None = None - ) -> None: - super().__init__() - - self.path = Path(path) - self.decode_kwargs = decode_kwargs or {} - - @cached_property - def _path_posix(self) -> str: - return self.path.as_posix() - - def _decode_jpeg(self, path: str) -> object: - with Path(path).open("rb") as f: - return decode_jpeg(f.read(), **self.decode_kwargs) # pyright: ignore[reportUnknownVariableType] - - @override - def read(self, idxs: Iterable[int]) -> UInt8[Tensor, "b h w c"]: - paths = map(self._path_posix.format, idxs) - frames_np = map(self._decode_jpeg, paths) - frames_tch = map(torch.from_numpy, frames_np) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] - - return torch.stack(list(frames_tch)) - - @override - def get_available_indices(self) -> Sequence[int]: - parser = parse.compile(self.path.name) # pyright: ignore[reportUnknownMemberType] - filenames = (path.name for path in self.path.parent.iterdir()) - return [res[0] for res in map(parser.parse, filenames) if res] # pyright: ignore[reportUnknownVariableType, reportIndexIssue, reportUnknownArgumentType, reportUnknownMemberType] diff --git a/src/rbyte/io/frame/mcap/__init__.py b/src/rbyte/io/frame/mcap/__init__.py new file mode 100644 index 0000000..1f57037 --- /dev/null +++ b/src/rbyte/io/frame/mcap/__init__.py @@ -0,0 +1,3 @@ +from .reader import McapFrameReader + +__all__ = ["McapFrameReader"] diff --git a/src/rbyte/io/frame/mcap/reader.py b/src/rbyte/io/frame/mcap/reader.py new file mode 100644 index 0000000..54827a5 --- /dev/null +++ b/src/rbyte/io/frame/mcap/reader.py @@ -0,0 +1,169 @@ +from collections.abc import Callable, Iterable, Mapping, Sequence +from dataclasses import dataclass +from functools import cached_property +from mmap import ACCESS_READ, mmap +from typing import IO, override + +import more_itertools as mit +import numpy.typing as npt +import torch +from jaxtyping import Shaped +from mcap.data_stream import ReadDataStream +from mcap.decoder import DecoderFactory +from mcap.opcode import Opcode +from mcap.reader import SeekingReader +from mcap.records import Chunk, ChunkIndex, Message +from mcap.stream_reader import get_chunk_data_stream +from pydantic import ConfigDict, FilePath, validate_call +from structlog import get_logger +from structlog.contextvars import bound_contextvars +from torch import Tensor + +from rbyte.io.frame.base import FrameReader + +logger = get_logger(__name__) + + +@dataclass(frozen=True) +class MessageIndex: + chunk_start_offset: int + message_start_offset: int + message_length: int + + +class McapFrameReader(FrameReader): + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__( + self, + path: FilePath, + topic: str, + message_decoder_factory: DecoderFactory, + frame_decoder: Callable[[bytes], npt.ArrayLike], + validate_crcs: bool = False, # noqa: FBT001, FBT002 + ) -> None: + super().__init__() + + with bound_contextvars( + path=path.as_posix(), + topic=topic, + message_decoder_factory=message_decoder_factory, + ): + self._path = path + self._validate_crcs = validate_crcs + + summary = SeekingReader( + stream=self._file, validate_crcs=self._validate_crcs + ).get_summary() + + if summary is None: + logger.error(msg := "missing summary") + raise ValueError(msg) + + self._channel = mit.one( + channel + for channel in summary.channels.values() + if channel.topic == topic + ) + + message_decoder = message_decoder_factory.decoder_for( + message_encoding=self._channel.message_encoding, + schema=summary.schemas[self._channel.schema_id], + ) + + if message_decoder is None: + logger.error(msg := "missing message decoder") + raise RuntimeError(msg) + + self._message_decoder = message_decoder + self._chunk_indexes = tuple( + chunk_index + for chunk_index in summary.chunk_indexes + if self._channel.id in chunk_index.message_index_offsets + ) + self._frame_decoder = frame_decoder + + @property + def _file(self) -> IO[bytes]: + match getattr(self, "_mmap", None): + case mmap(closed=False): + pass + + case None | mmap(closed=True): + with self._path.open("rb") as f: + self._mmap = mmap(fileno=f.fileno(), length=0, access=ACCESS_READ) + + case _: + raise RuntimeError + + return self._mmap # pyright: ignore[reportReturnType] + + @override + def read(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]: + frames: Mapping[int, npt.ArrayLike] = {} + + message_indexes_by_chunk_start_offset: Mapping[ + int, Iterable[tuple[int, MessageIndex]] + ] = mit.map_reduce( + zip(indexes, (self._message_indexes[idx] for idx in indexes), strict=True), + keyfunc=lambda x: x[1].chunk_start_offset, + ) + + for ( + chunk_start_offset, + chunk_message_indexes, + ) in message_indexes_by_chunk_start_offset.items(): + self._file.seek(chunk_start_offset + 1 + 8) # pyright: ignore[reportUnusedCallResult] + chunk = Chunk.read(ReadDataStream(self._file)) + stream, _ = get_chunk_data_stream(chunk, validate_crc=self._validate_crcs) + for frame_index, message_index in sorted( + chunk_message_indexes, key=lambda x: x[1].message_start_offset + ): + stream.read(message_index.message_start_offset - stream.count) # pyright: ignore[reportUnusedCallResult] + message = Message.read(stream, message_index.message_length) + decoded_message = self._message_decoder(message.data) + frames[frame_index] = self._frame_decoder(decoded_message.data) + + return torch.stack([torch.from_numpy(frames[idx]) for idx in indexes]) # pyright: ignore[reportUnknownMemberType] + + @override + def get_available_indexes(self) -> Sequence[int]: + return range(len(self._message_indexes)) + + @cached_property + def _message_indexes(self) -> Sequence[MessageIndex]: + return tuple( + self._build_message_indexes( + self._file, + chunk_indexes=self._chunk_indexes, + channel_id=self._channel.id, + validate_crc=self._validate_crcs, + ) + ) + + @staticmethod + def _build_message_indexes( + f: IO[bytes], + *, + chunk_indexes: Iterable[ChunkIndex], + channel_id: int, + validate_crc: bool, + ) -> Iterable[MessageIndex]: + for chunk_index in chunk_indexes: + f.seek(chunk_index.chunk_start_offset + 1 + 8) # pyright: ignore[reportUnusedCallResult] + chunk = Chunk.read(ReadDataStream(f)) + stream, stream_length = get_chunk_data_stream(chunk, validate_crc) + + while stream.count < stream_length: + opcode = stream.read1() + length = stream.read8() + match opcode: + case Opcode.MESSAGE: + if Message.read(stream, length).channel_id == channel_id: + yield MessageIndex( + chunk_start_offset=chunk_index.chunk_start_offset, + message_start_offset=stream.count - length, + message_length=length, + ) + + case _: + stream.read(length) # pyright: ignore[reportUnusedCallResult] diff --git a/src/rbyte/io/table/__init__.py b/src/rbyte/io/table/__init__.py index e69de29..3c89ca5 100644 --- a/src/rbyte/io/table/__init__.py +++ b/src/rbyte/io/table/__init__.py @@ -0,0 +1,4 @@ +from .builder import TableBuilder +from .merger import TableMerger + +__all__ = ["TableBuilder", "TableMerger"] diff --git a/src/rbyte/io/table/base.py b/src/rbyte/io/table/base.py index 6e0b8c4..7680822 100644 --- a/src/rbyte/io/table/base.py +++ b/src/rbyte/io/table/base.py @@ -1,9 +1,29 @@ +from collections.abc import Hashable, Mapping from os import PathLike from typing import Protocol, runtime_checkable import polars as pl +type Table = pl.DataFrame + + +@runtime_checkable +class TableBuilderBase(Protocol): + def build(self, path: PathLike[str]) -> Table: ... + + +@runtime_checkable +class TableReaderBase(Hashable, Protocol): + def read(self, path: PathLike[str]) -> Mapping[str, Table]: ... + + +@runtime_checkable +class TableMergerBase(Hashable, Protocol): + def merge(self, src: Mapping[str, Table]) -> Table: ... + @runtime_checkable -class TableBuilder(Protocol): - def build(self, path: PathLike[str]) -> pl.DataFrame: ... +class TableCacheBase(Protocol): + def __contains__(self, key: Hashable) -> bool: ... + def get(self, key: Hashable) -> Table | None: ... + def set(self, key: Hashable, value: Table) -> bool: ... diff --git a/src/rbyte/io/table/builder.py b/src/rbyte/io/table/builder.py new file mode 100644 index 0000000..21da615 --- /dev/null +++ b/src/rbyte/io/table/builder.py @@ -0,0 +1,78 @@ +from collections.abc import Hashable +from mmap import ACCESS_READ, mmap +from os import PathLike +from pathlib import Path +from typing import Annotated, override + +import polars as pl +from pydantic import ConfigDict, StringConstraints, validate_call +from structlog import get_logger +from structlog.contextvars import bound_contextvars +from xxhash import xxh3_64_intdigest as digest + +from rbyte.io.table.base import ( + TableBuilderBase, + TableCacheBase, + TableMergerBase, + TableReaderBase, +) + +logger = get_logger(__name__) + + +class TableBuilder(TableBuilderBase): + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__( + self, + reader: TableReaderBase, + merger: TableMergerBase, + filter: Annotated[str, StringConstraints(strip_whitespace=True)] | None = None, # noqa: A002 + cache: TableCacheBase | None = None, + ) -> None: + super().__init__() + + self._reader = reader + self._merger = merger + self._filter = filter + self._cache = cache + + def _build_cache_key(self, path: PathLike[str]) -> Hashable: + from rbyte import __version__ as rbyte_version # noqa: PLC0415 + + key = [rbyte_version, hash(self._reader), hash(self._merger)] + + if self._filter is not None: + key.append(digest(self._filter)) + + with Path(path).open("rb") as _f, mmap(_f.fileno(), 0, access=ACCESS_READ) as f: + key.append(digest(f)) # pyright: ignore[reportArgumentType] + + return tuple(key) + + @override + def build(self, path: PathLike[str]) -> pl.DataFrame: + with bound_contextvars(path=str(path)): + match self._cache: + case TableCacheBase(): + key = self._build_cache_key(path) + if key in self._cache: + logger.debug("reading table from cache") + df = self._cache.get(key) + if df is None: + raise RuntimeError + + return df + + df = self._build(path) + if not self._cache.set(key, df): + logger.warning("failed to cache table") + + return df + + case None: + return self._build(path) + + def _build(self, path: PathLike[str]) -> pl.DataFrame: + dfs = self._reader.read(path) + df = self._merger.merge(dfs) + return df.sql(f"select * from self where ({self._filter or True})") # noqa: S608 diff --git a/src/rbyte/io/table/carla/builder.py b/src/rbyte/io/table/carla/builder.py index e788a0b..0b85fd9 100644 --- a/src/rbyte/io/table/carla/builder.py +++ b/src/rbyte/io/table/carla/builder.py @@ -1,58 +1,54 @@ from collections.abc import Sequence -from functools import cached_property from os import PathLike from pathlib import Path from typing import Annotated, override import polars as pl from polars import selectors as cs -from pydantic import StringConstraints +from pydantic import ConfigDict, StringConstraints, validate_call -from rbyte.config.base import BaseModel, HydraConfig -from rbyte.io.table.base import TableBuilder +from rbyte.io.table.base import TableBuilderBase from rbyte.io.table.transforms.base import TableTransform -from rbyte.utils.dataframe import unnest_all +from rbyte.utils.dataframe.misc import unnest_all -class CarlaRecordsTableBuilderConfig(BaseModel): - index_column: Annotated[str, StringConstraints(strip_whitespace=True)] = "frame_idx" - select: str | frozenset[str] = "*" - filter: Annotated[str, StringConstraints(strip_whitespace=True)] | None = None - transforms: tuple[HydraConfig[TableTransform], ...] = () - - -class CarlaRecordsTableBuilder(TableBuilder): +class CarlaRecordsTableBuilder(TableBuilderBase): RECORD_KEY = "records" - def __init__(self, config: object) -> None: + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__( + self, + *, + index_column: Annotated[ + str, StringConstraints(strip_whitespace=True) + ] = "frame_idx", + select: str | frozenset[str] = "*", + filter: Annotated[str, StringConstraints(strip_whitespace=True)] | None = None, # noqa: A002 + transforms: Sequence[TableTransform] = (), + ) -> None: super().__init__() - self._config = CarlaRecordsTableBuilderConfig.model_validate(config) - - @property - def config(self) -> CarlaRecordsTableBuilderConfig: - return self._config - - @cached_property - def _transforms(self) -> Sequence[TableTransform]: - return tuple(transform.instantiate() for transform in self.config.transforms) + self._index_column = index_column + self._select = select + self._filter = filter + self._transforms = transforms @override def build(self, path: PathLike[str]) -> pl.DataFrame: df = pl.read_json(Path(path)).explode(self.RECORD_KEY).unnest(self.RECORD_KEY) df = ( df.select(unnest_all(df.collect_schema())) - .select(self.config.select) + .select(self._select) # 32 bits ought to be enough for anybody .cast({ cs.by_dtype(pl.Int64): pl.Int32, cs.by_dtype(pl.UInt64): pl.UInt32, cs.by_dtype(pl.Float64): pl.Float32, }) - .sql(f"select * from self where ({self.config.filter or True})") # noqa: S608 + .sql(f"select * from self where ({self._filter or True})") # noqa: S608 ) for transform in self._transforms: df = transform(df) - return df.with_row_index(self.config.index_column) + return df.with_row_index(self._index_column) diff --git a/src/rbyte/io/table/mcap/__init__.py b/src/rbyte/io/table/mcap/__init__.py new file mode 100644 index 0000000..e95590c --- /dev/null +++ b/src/rbyte/io/table/mcap/__init__.py @@ -0,0 +1,15 @@ +__all__: list[str] = [] + +try: + from .protobuf_reader import McapProtobufTableReader +except ImportError: + pass +else: + __all__ += ["McapProtobufTableReader"] + +try: + from .ros2_reader import McapRos2TableReader +except ImportError: + pass +else: + __all__ += ["McapRos2TableReader"] diff --git a/src/rbyte/io/table/mcap/config.py b/src/rbyte/io/table/mcap/config.py new file mode 100644 index 0000000..b99593f --- /dev/null +++ b/src/rbyte/io/table/mcap/config.py @@ -0,0 +1,19 @@ +from collections.abc import Mapping + +import polars as pl +from pydantic import ConfigDict, ImportString + +from rbyte.config.base import BaseModel, HydraConfig + +PolarsDataType = pl.DataType | pl.DataTypeClass + + +class Config(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + fields: Mapping[ + str, + Mapping[str, HydraConfig[PolarsDataType] | ImportString[PolarsDataType] | None], + ] + + validate_crcs: bool = False diff --git a/src/rbyte/io/table/mcap/protobuf_reader.py b/src/rbyte/io/table/mcap/protobuf_reader.py new file mode 100644 index 0000000..54ad8f9 --- /dev/null +++ b/src/rbyte/io/table/mcap/protobuf_reader.py @@ -0,0 +1,173 @@ +import json +from collections.abc import Hashable, Mapping +from functools import cached_property +from mmap import ACCESS_READ, mmap +from os import PathLike +from pathlib import Path +from typing import Literal, cast, override + +import more_itertools as mit +import polars as pl +import polars._typing as plt +import pyarrow as pa +from cachetools import cached +from google.protobuf.descriptor_pb2 import FileDescriptorProto, FileDescriptorSet +from google.protobuf.descriptor_pool import DescriptorPool +from google.protobuf.message import Message +from google.protobuf.message_factory import GetMessageClassesForFiles +from mcap.reader import SeekingReader +from mcap.records import Schema +from ptars import HandlerPool +from structlog import get_logger +from structlog.contextvars import bound_contextvars +from tqdm import tqdm +from xxhash import xxh3_64_intdigest as digest + +from rbyte.config.base import HydraConfig +from rbyte.io.table.base import TableReaderBase +from rbyte.utils.dataframe import unnest_all + +from .config import Config + +logger = get_logger(__name__) + + +class McapProtobufTableReader(TableReaderBase, Hashable): + FRAME_INDEX_COLUMN_NAME: Literal["frame_idx"] = "frame_idx" + + def __init__(self, **kwargs: object) -> None: + self._config = Config.model_validate(kwargs) + + @override + def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: + with ( + bound_contextvars(path=str(path)), + Path(path).open("rb") as _f, + mmap(fileno=_f.fileno(), length=0, access=ACCESS_READ) as f, + ): + reader = SeekingReader(f, validate_crcs=self._config.validate_crcs) # pyright: ignore[reportArgumentType] + summary = reader.get_summary() + if summary is None: + logger.error(msg := "missing summary") + raise ValueError(msg) + + topics = self.schemas.keys() + if missing_topics := topics - ( + available_topics := {ch.topic for ch in summary.channels.values()} + ): + with bound_contextvars( + missing_topics=sorted(missing_topics), + available_topics=sorted(available_topics), + ): + logger.error(msg := "missing topics") + raise ValueError(msg) + + schemas = { + channel.topic: summary.schemas[channel.schema_id] + for channel in summary.channels.values() + if channel.topic in topics + } + + message_counts = ( + { + channel.topic: stats.channel_message_counts[channel_id] + for channel_id, channel in summary.channels.items() + if channel.topic in topics + } + if (stats := summary.statistics) is not None + else {} + ) + + messages = mit.bucket( + reader.iter_messages(topics), + key=lambda x: x[1].topic, + validator=lambda k: k in topics, + ) + + dfs: Mapping[str, pl.DataFrame] = {} + handler_pool = HandlerPool() + + for topic, schema in self.schemas.items(): + log_time, publish_time, data = mit.unzip( + (msg.log_time, msg.publish_time, msg.data) + for (*_, msg) in tqdm( + messages[topic], + total=message_counts[topic], + postfix={"topic": topic}, + ) + ) + + msg_type = self._get_message_type(schemas[topic]) + handler = handler_pool.get_for_message(msg_type.DESCRIPTOR) # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + + record_batch = cast( + pa.RecordBatch, + handler.list_to_record_batch(list(data)), # pyright: ignore[reportUnknownMemberType] + ) + df = cast(pl.DataFrame, pl.from_arrow(record_batch)) # pyright: ignore[reportUnknownMemberType] + + dfs[topic] = ( + (df.select(unnest_all(df.collect_schema()))) + .hstack([ + pl.Series("log_time", log_time, pl.UInt64), + pl.Series("publish_time", publish_time, pl.UInt64), + ]) + .select(schema) + .cast({k: v for k, v in schema.items() if v is not None}) + ) + + for topic, df in dfs.items(): + match schemas[topic]: + case Schema( + encoding="protobuf", + name="foxglove.CompressedImage" + | "foxglove.CompressedVideo" + | "foxglove.RawImage", + ): + dfs[topic] = df.with_row_index(self.FRAME_INDEX_COLUMN_NAME) + + case _: + pass + + return dfs + + @override + def __hash__(self) -> int: + config = self._config.model_dump_json() + # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 + config_str = json.dumps(json.loads(config), sort_keys=True) + + return digest(config_str) + + @cached(cache={}, key=lambda _, schema: schema.name) # pyright: ignore[reportUnknownLambdaType, reportUnknownArgumentType, reportUnknownMemberType] + def _get_message_type(self, schema: Schema) -> type[Message]: # noqa: PLR6301 + # inspired by https://github.com/foxglove/mcap/blob/e591defaa95186cef27e37c49fa7e1f0c9f2e8a6/python/mcap-protobuf-support/mcap_protobuf/decoder.py#L29 + fds = FileDescriptorSet.FromString(schema.data) + descriptors = {fd.name: fd for fd in fds.file} + pool = DescriptorPool() + + def _add(fd: FileDescriptorProto) -> None: + for dependency in fd.dependency: + if dependency in descriptors: + _add(descriptors.pop(dependency)) + + pool.Add(fd) # pyright: ignore[reportUnknownMemberType] + + while descriptors: + _add(descriptors.popitem()[1]) + + messages = GetMessageClassesForFiles([fd.name for fd in fds.file], pool) + + return mit.one( + msg_type for name, msg_type in messages.items() if name == schema.name + ) + + @cached_property + def schemas(self) -> dict[str, dict[str, plt.PolarsDataType | None]]: + return { + topic: { + path: leaf.instantiate() if isinstance(leaf, HydraConfig) else leaf + for path, leaf in fields.items() + } + for topic, fields in self._config.fields.items() + } diff --git a/src/rbyte/io/table/mcap/ros2_reader.py b/src/rbyte/io/table/mcap/ros2_reader.py new file mode 100644 index 0000000..2bb5650 --- /dev/null +++ b/src/rbyte/io/table/mcap/ros2_reader.py @@ -0,0 +1,105 @@ +import json +from collections.abc import Hashable, Mapping +from functools import cached_property +from mmap import ACCESS_READ, mmap +from operator import attrgetter +from os import PathLike +from pathlib import Path +from typing import Any, Literal, override + +import polars as pl +import polars._typing as plt +from mcap.reader import SeekingReader +from mcap_ros2.decoder import DecoderFactory +from structlog import get_logger +from structlog.contextvars import bound_contextvars +from xxhash import xxh3_64_intdigest as digest + +from rbyte.config.base import HydraConfig +from rbyte.io.table.base import TableReaderBase + +from .config import Config + +logger = get_logger(__name__) + + +class McapRos2TableReader(TableReaderBase, Hashable): + FRAME_INDEX_COLUMN_NAME: Literal["frame_idx"] = "frame_idx" + + def __init__(self, **kwargs: object) -> None: + self._config = Config.model_validate(kwargs) + + @override + def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: + with ( + bound_contextvars(path=str(path)), + Path(path).open("rb") as _f, + mmap(fileno=_f.fileno(), length=0, access=ACCESS_READ) as f, + ): + reader = SeekingReader( + f, # pyright: ignore[reportArgumentType] + validate_crcs=self._config.validate_crcs, + decoder_factories=[DecoderFactory()], + ) + summary = reader.get_summary() + if summary is None: + logger.error(msg := "missing summary") + raise ValueError(msg) + + topics = self._config.fields.keys() + if missing_topics := topics - ( + available_topics := {ch.topic for ch in summary.channels.values()} + ): + with bound_contextvars( + missing_topics=sorted(missing_topics), + available_topics=sorted(available_topics), + ): + logger.error(msg := "missing topics") + raise ValueError(msg) + + getters = { + topic: [attrgetter(field) for field in fields] + for topic, fields in self.schemas.items() + } + + rows: Mapping[str, list[list[Any]]] = {topic: [] for topic in self.schemas} + + for _, channel, msg, msg_decoded in reader.iter_decoded_messages(topics): + topic = channel.topic + row: list[Any] = [] + for getter in getters[topic]: + try: + attr = getter(msg_decoded) + except AttributeError: + attr = getter(msg) + + row.append(attr) + + rows[topic].append(row) + + return { + topic: pl.DataFrame( + data=rows[topic], + schema=self.schemas[topic], # pyright: ignore[reportArgumentType] + orient="row", + ) + for topic in topics + } + + @override + def __hash__(self) -> int: + config = self._config.model_dump_json() + # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 + config_str = json.dumps(json.loads(config), sort_keys=True) + + return digest(config_str) + + @cached_property + def schemas(self) -> dict[str, dict[str, plt.PolarsDataType | None]]: + return { + topic: { + path: leaf.instantiate() if isinstance(leaf, HydraConfig) else leaf + for path, leaf in fields.items() + } + for topic, fields in self._config.fields.items() + } diff --git a/src/rbyte/io/table/merger.py b/src/rbyte/io/table/merger.py new file mode 100644 index 0000000..3a010c0 --- /dev/null +++ b/src/rbyte/io/table/merger.py @@ -0,0 +1,160 @@ +import json +from collections import OrderedDict +from collections.abc import Hashable, Mapping, Sequence +from functools import cached_property +from operator import itemgetter +from typing import Annotated, Literal, Self, override + +import more_itertools as mit +import polars as pl +from polars.type_aliases import AsofJoinStrategy +from pydantic import StringConstraints, model_validator +from structlog import get_logger +from xxhash import xxh3_64_intdigest as digest + +from rbyte.config.base import BaseModel +from rbyte.io.table.base import TableMergerBase + +logger = get_logger(__name__) + + +class RefColumnMergeConfig(BaseModel): + method: Literal["ref"] = "ref" + + +class AsofColumnMergeConfig(BaseModel): + method: Literal["asof"] = "asof" + tolerance: Annotated[ + str, StringConstraints(strip_whitespace=True, to_lower=True, pattern=r"\d+ms$") + ] = "100ms" + strategy: AsofJoinStrategy = "nearest" + + +class InterpColumnMergeConfig(BaseModel): + method: Literal["interp"] = "interp" + + +MergeConfig = RefColumnMergeConfig | AsofColumnMergeConfig | InterpColumnMergeConfig + + +class Config(BaseModel): + merge: OrderedDict[str, Mapping[str, MergeConfig]] + separator: Annotated[str, StringConstraints(strip_whitespace=True)] = "." + + @model_validator(mode="after") + def validate_refs(self) -> Self: + ref_config = RefColumnMergeConfig() + for k, v in self.columns_by_merge_config.items(): + match v.get(ref_config, None): + case [_column]: + pass + + case _: + msg = f"merge `{k}` must have exactly one column with {ref_config}" + raise ValueError(msg) + + return self + + @cached_property + def columns_by_merge_config( + self, + ) -> Mapping[str, Mapping[MergeConfig, Sequence[str]]]: + return { + k: mit.map_reduce(v.items(), keyfunc=itemgetter(1), valuefunc=itemgetter(0)) + for k, v in self.merge.items() + } + + @cached_property + def ref_columns(self) -> Mapping[str, str]: + return { + k: mit.one(v[RefColumnMergeConfig()]) + for k, v in self.columns_by_merge_config.items() + } + + +class TableMerger(TableMergerBase, Hashable): + def __init__(self, **kwargs: object) -> None: + self._config = Config.model_validate(kwargs) + + def _col_name(self, *args: str) -> str: + return self._config.separator.join(args) + + @override + def merge(self, src: Mapping[str, pl.DataFrame]) -> pl.DataFrame: + dfs = { + k: src[k] + .sort(self._config.ref_columns[k]) + .rename(lambda col, k=k: self._col_name(k, col)) + for k in self._config.merge + } + k_df_ref = mit.first(self._config.merge.keys()) + df_ref = dfs.pop(k_df_ref) + df_ref_col_ref = self._col_name(k_df_ref, self._config.ref_columns[k_df_ref]) + + logger.debug( + "merging", merge_ref=f"{k_df_ref}[{self._config.ref_columns[k_df_ref]}]" + ) + + for k_merge, df_merge in dfs.items(): + cols_by_merge_config = self._config.columns_by_merge_config[k_merge] + df_merge_col_ref = self._col_name( + k_merge, self._config.ref_columns[k_merge] + ) + + for merge_cfg, _df_merge_cols in cols_by_merge_config.items(): + if isinstance(merge_cfg, RefColumnMergeConfig): + continue + + df_merge_cols = tuple( + self._col_name(k_merge, col) for col in _df_merge_cols + ) + + df_ref_height_pre = df_ref.height + match merge_cfg: + case AsofColumnMergeConfig(strategy=strategy, tolerance=tolerance): + df_ref = df_ref.join_asof( + other=df_merge.select(df_merge_col_ref, *df_merge_cols), + left_on=df_ref_col_ref, + right_on=df_merge_col_ref, + strategy=strategy, + tolerance=tolerance, + ).drop_nulls(df_merge_cols) + + case InterpColumnMergeConfig(): + df_ref = ( + # take a union of timestamps + df_ref.join( + df_merge.select(df_merge_col_ref, *df_merge_cols), + how="full", + left_on=df_ref_col_ref, + right_on=df_merge_col_ref, + coalesce=True, + ) + # interpolate + .with_columns( + pl.col(df_merge_cols).interpolate_by(df_ref_col_ref) + ) + # narrow back to original ref col + .join( + df_ref.select(df_ref_col_ref), + on=df_ref_col_ref, + how="semi", + ) + .sort(df_ref_col_ref) + ).drop_nulls(df_merge_cols) + + logger.debug( + "merged", + merge_rows=f"{df_ref_height_pre}->{df_ref.height}", + merge_other=f"{k_merge}[{", ".join(_df_merge_cols)}]", + ) + + return df_ref + + @override + def __hash__(self) -> int: + config = self._config.model_dump_json() + # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 + config_str = json.dumps(json.loads(config), sort_keys=True) + + return digest(config_str) diff --git a/src/rbyte/io/table/transforms/fps_resampler.py b/src/rbyte/io/table/transforms/fps_resampler.py index 43154a2..5e1635b 100644 --- a/src/rbyte/io/table/transforms/fps_resampler.py +++ b/src/rbyte/io/table/transforms/fps_resampler.py @@ -14,17 +14,16 @@ class FpsResampler(TableTransform): def __init__(self, source_fps: PositiveInt, target_fps: PositiveInt) -> None: super().__init__() - self.source_fps = source_fps - self.target_fps = target_fps - - self.fps_lcm = lcm(source_fps, target_fps) + self._source_fps = source_fps + self._target_fps = target_fps + self._fps_lcm = lcm(source_fps, target_fps) @override def __call__(self, src: pl.DataFrame) -> pl.DataFrame: return ( src.with_row_index(self.IDX_COL) - .with_columns(pl.col(self.IDX_COL) * (self.fps_lcm // self.source_fps)) - .upsample(self.IDX_COL, every=f"{self.fps_lcm // self.target_fps}i") + .with_columns(pl.col(self.IDX_COL) * (self._fps_lcm // self._source_fps)) + .upsample(self.IDX_COL, every=f"{self._fps_lcm // self._target_fps}i") .interpolate() .fill_null(strategy="backward") .drop(self.IDX_COL) diff --git a/src/rbyte/io/table/yaak/__init__.py b/src/rbyte/io/table/yaak/__init__.py index 83b2424..e83fb67 100644 --- a/src/rbyte/io/table/yaak/__init__.py +++ b/src/rbyte/io/table/yaak/__init__.py @@ -1,3 +1,3 @@ -from .builder import YaakMetadataTableBuilder +from .reader import YaakMetadataTableReader -__all__ = ["YaakMetadataTableBuilder"] +__all__ = ["YaakMetadataTableReader"] diff --git a/src/rbyte/io/table/yaak/builder.py b/src/rbyte/io/table/yaak/builder.py deleted file mode 100644 index 0b330d3..0000000 --- a/src/rbyte/io/table/yaak/builder.py +++ /dev/null @@ -1,410 +0,0 @@ -import json -from collections.abc import Iterable, Iterator, Mapping, Sequence -from enum import StrEnum, auto -from functools import cached_property -from hashlib import blake2b -from mmap import ACCESS_READ, mmap -from operator import attrgetter -from os import PathLike -from pathlib import Path -from typing import Annotated, Any, Literal, Self, TypeVar, cast, override - -import polars as pl -from google._upb._message import ( - RepeatedScalarContainer, # noqa: PLC2701 # pyright: ignore[reportUnknownVariableType] -) -from google.protobuf.message import Message -from google.protobuf.timestamp_pb2 import Timestamp -from more_itertools import ( - all_unique, - always_iterable, # pyright: ignore[reportUnknownVariableType] - collapse, # pyright: ignore[reportUnknownVariableType] - map_reduce, -) -from polars.type_aliases import AsofJoinStrategy -from pydantic import ( - Field, - ImportString, - StringConstraints, - field_serializer, - field_validator, - model_validator, -) -from structlog import get_logger -from structlog.contextvars import bound_contextvars - -from rbyte.config.base import BaseModel, HydraConfig -from rbyte.io.table.base import TableBuilder -from rbyte.utils.dataframe_cache import DataframeDiskCache - -from .message_iterator import YaakMetadataMessageIterator -from .proto import can_pb2, sensor_pb2 - -logger = get_logger(__name__) - -K = TypeVar("K") -V = TypeVar("V") - - -def flatten_mapping(mapping: Mapping[K, V]) -> Iterator[tuple[tuple[K, ...], V]]: - for k, v in mapping.items(): - if isinstance(v, Mapping): - yield from ( - ((k, *always_iterable(_k)), _v) # pyright: ignore[reportUnknownArgumentType] - for _k, _v in flatten_mapping(v) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] - ) - else: - yield (k,), v - - -class CameraName(StrEnum): - CAM_FRONT_CENTER = auto() - CAM_FRONT_LEFT = auto() - CAM_FRONT_RIGHT = auto() - CAM_LEFT_FORWARD = auto() - CAM_LEFT_BACKWARD = auto() - CAM_RIGHT_FORWARD = auto() - CAM_RIGHT_BACKWARD = auto() - CAM_REAR = auto() - - -class AsofMergeConfig(BaseModel): - method: Literal["asof"] = "asof" - tolerance: Annotated[ - str, StringConstraints(strip_whitespace=True, to_lower=True, pattern=r"\d+ms$") - ] = "100ms" - strategy: AsofJoinStrategy = "nearest" - - -class InterpMergeConfig(BaseModel): - method: Literal["interp"] = "interp" - - -MergeConfig = AsofMergeConfig | InterpMergeConfig - - -class MetadataFieldSelectConfig(BaseModel): - name: str - merge: MergeConfig | None = None - partition: bool = False - - -class MetadataSelectMessageConfig(BaseModel): - type: ImportString[type[Message]] - alias: Annotated[str, StringConstraints(min_length=1, strip_whitespace=True)] - - -class MetadataSelectConfig(BaseModel): - message: MetadataSelectMessageConfig - fields: tuple[MetadataFieldSelectConfig, ...] = Field(min_length=1) - - @field_validator("fields", mode="before") - @classmethod - def validate_fields( - cls, fields: tuple[MetadataFieldSelectConfig, ...] - ) -> tuple[MetadataFieldSelectConfig, ...]: - if not all_unique(fields, key=attrgetter("name")): - msg = "field names not unique" - raise ValueError(msg) - - return fields - - @model_validator(mode="after") - def validate_message_fields(self) -> Self: - valid = set(self.message.type.DESCRIPTOR.fields_by_name.keys()) - if invalid := {f.name for f in self.fields} - valid: - msg = f"invalid fields for `{self.message.type}`: {invalid}" - raise ValueError(msg) - - return self - - @field_serializer("fields") - @staticmethod - def serialize_fields( - fields: Iterable[MetadataFieldSelectConfig], - ) -> tuple[MetadataFieldSelectConfig, ...]: - return tuple(sorted(fields, key=attrgetter("name"))) - - -class MetadataMergeReferenceConfig(BaseModel): - key: tuple[ImportString[type[Message]]] | tuple[ImportString[type[Message]], str] - column: str - - -class MetadataMergeConfig(BaseModel): - reference: MetadataMergeReferenceConfig - - -class MetadataTableBuilderConfig(BaseModel): - cameras: frozenset[CameraName] = Field(default=frozenset(CameraName)) - select: tuple[MetadataSelectConfig, ...] = Field(min_length=1) - merge: MetadataMergeConfig - filter: Annotated[str, StringConstraints(strip_whitespace=True)] | None = None - cache: HydraConfig[DataframeDiskCache] | None = None - - @field_validator("select") - @classmethod - def validate_select( - cls, select: tuple[MetadataSelectConfig, ...] - ) -> tuple[MetadataSelectConfig, ...]: - if not all_unique(select, key=attrgetter("message.type")): - msg = "select message paths not unique" - raise ValueError(msg) - - if not all_unique(select, key=attrgetter("message.alias")): - msg = "select message aliases not unique" - raise ValueError(msg) - - return select - - @field_serializer("cameras") - @staticmethod - def serialize_cameras(cameras: Iterable[CameraName]) -> tuple[CameraName, ...]: - return tuple(sorted(cameras)) - - @field_serializer("select") - @staticmethod - def serialize_select( - select: Iterable[MetadataSelectConfig], - ) -> tuple[MetadataSelectConfig, ...]: - return tuple(sorted(select, key=attrgetter("message.alias"))) - - -class YaakMetadataTableBuilder(TableBuilder): - COLUMN_NAME_SEPARATOR = "." - SCHEMA_OVERRIDES: Mapping[type[Message], pl._typing.SchemaDict] = { - sensor_pb2.ImageMetadata: { - "time_stamp": pl.Datetime("us"), - "camera_name": pl.Enum(tuple(CameraName)), - }, - sensor_pb2.Gnss: {"time_stamp": pl.Datetime("us")}, - can_pb2.VehicleMotion: {"time_stamp": pl.Datetime("us")}, - can_pb2.VehicleState: {"time_stamp": pl.Datetime("us")}, - } - - def __init__(self, config: object) -> None: - super().__init__() - - self.config = MetadataTableBuilderConfig.model_validate(config) - - @cached_property - def _dataframe_cache(self) -> DataframeDiskCache | None: - return cfg.instantiate() if (cfg := self.config.cache) is not None else None - - def _build_dataframe_cache_key(self, path: PathLike[str]) -> tuple[str, str]: - # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 - config_json = self.config.model_dump_json(exclude={"cache"}) - config_dict = json.loads(config_json) - config_json_sorted = json.dumps(config_dict, sort_keys=True) - config_hash = blake2b(config_json_sorted.encode("utf-8")).hexdigest() - - with ( - Path(path).open("rb") as f, - mmap(f.fileno(), 0, access=ACCESS_READ) as file, - ): - file_hash = blake2b(file).hexdigest() - - return (config_hash, file_hash) - - @override - def build(self, path: PathLike[str]) -> pl.DataFrame: - with bound_contextvars(path=str(path)): - match self._dataframe_cache: - case None: - return self._build_dataframe(path) - - case _: - key = self._build_dataframe_cache_key(path) - match df := self._dataframe_cache.get(key): - case None: - logger.debug("dataframe cache miss") - df = self._build_dataframe(path) - if not self._dataframe_cache.set(key, df): - logger.warning("failed to cache dataframe") - - return df - - case _: - logger.debug("dataframe cache hit") - return df - - def _build_dataframe(self, path: PathLike[str]) -> pl.DataFrame: - logger.debug("building dataframe") - - dfs = self._read_message_dataframes(path) - dfs = self._partition_message_dataframes(dfs) - df = self._merge_message_dataframes(dfs) - - return df.sql( - f"select * from self where ({self.config.filter or True})" # noqa: S608 - ) - - def _read_message_dataframes( - self, path: PathLike[str] - ) -> Mapping[type[Message], pl.DataFrame]: - message_rows: Mapping[type[Message], list[list[Message]]] = { - m: [] for m in self._select_fields - } - - with Path(path).open("rb") as _f, mmap(_f.fileno(), 0, access=ACCESS_READ) as f: - messages = YaakMetadataMessageIterator(f, message_types=message_rows.keys()) - for message in messages: - # PERF: avoid including messages from cameras we don't want - if ( - isinstance(message, sensor_pb2.ImageMetadata) - and message.camera_name not in self.config.cameras - ): - continue - - msg_type = type(message) - - row: list[Any] = [] - for field in self._select_fields[msg_type]: - match attr := getattr(message, field): - case Timestamp(): - row.append(attr.ToMicroseconds()) - - case RepeatedScalarContainer(): - row.append(tuple(attr)) # pyright: ignore[reportUnknownArgumentType] - - case _: - row.append(attr) - - message_rows[msg_type].append(row) - - message_dfs: Mapping[type[Message], pl.DataFrame] = { - msg_type: pl.DataFrame( - data=rows, - schema=self._select_fields[msg_type], - schema_overrides=self.SCHEMA_OVERRIDES.get(msg_type, None), - orient="row", - ) - for msg_type, rows in message_rows.items() - } - - return message_dfs - - def _partition_message_dataframes( - self, message_dfs: Mapping[type[Message], pl.DataFrame] - ) -> Mapping[ - type[Message], pl.DataFrame | Mapping[tuple[object, ...], pl.DataFrame] - ]: - return { - msg_type: df - if (by := self._partition_fields.get(msg_type, None)) is None - else df.partition_by(*by, include_key=False, as_dict=True) - for msg_type, df in message_dfs.items() - } - - def _merge_message_dataframes( - self, - message_dfs: Mapping[ - type[Message], pl.DataFrame | Mapping[tuple[object, ...], pl.DataFrame] - ], - ) -> pl.DataFrame: - ref_df_key, ref_col = ( - self.config.merge.reference.key, - self.config.merge.reference.column, - ) - - dfs = { - (merge_df_key := tuple(collapse(_k))): cast(pl.DataFrame, v) - .sort(ref_col) - .rename(lambda col: self._col_name(merge_df_key, col)) - for _k, v in flatten_mapping(message_dfs) - } - - try: - ref_df = dfs.pop(ref_df_key) - except KeyError as e: - msg = f"invalid reference key {ref_df_key}, valid: {list(dfs)}" - raise ValueError(msg) from e - - ref_df_ref_col = self._col_name(*ref_df_key, ref_col) - - for merge_df_key, merge_df in dfs.items(): - msg_type, *_ = merge_df_key - merge_df_ref_col = self._col_name(merge_df_key, ref_col) - - for merge_config, merge_fields in self._merge_fields[msg_type].items(): - merge_df_cols = tuple( - self._col_name(merge_df_key, field) for field in merge_fields - ) - - match merge_config: - case AsofMergeConfig(strategy=strategy, tolerance=tolerance): - ref_df = ref_df.join_asof( - merge_df.select(merge_df_ref_col, *merge_df_cols), - left_on=ref_df_ref_col, - right_on=merge_df_ref_col, - strategy=strategy, - tolerance=tolerance, - ).drop_nulls() - - case InterpMergeConfig(): - ref_df = ( - # take a union of timestamps - ref_df.join( - merge_df.select(merge_df_ref_col, *merge_df_cols), - how="full", - left_on=ref_df_ref_col, - right_on=merge_df_ref_col, - coalesce=True, - ) - # interpolate - .with_columns( - pl.col(merge_df_cols).interpolate_by(ref_df_ref_col) - ) - # narrow back to original ref col - .join( - ref_df.select(ref_df_ref_col), - on=ref_df_ref_col, - how="semi", - ) - .sort(ref_df_ref_col) - ) - - return ref_df - - @cached_property - def _message_type_aliases(self) -> Mapping[type[Message], str]: - return { - select.message.type: select.message.alias for select in self.config.select - } - - @cached_property - def _select_fields(self) -> Mapping[type[Message], Sequence[str]]: - return { - select.message.type: tuple(field.name for field in select.fields) - for select in self.config.select - } - - @cached_property - def _merge_fields( - self, - ) -> Mapping[type[Message], Mapping[MergeConfig, Sequence[str]]]: - return { - select.message.type: map_reduce( - (field for field in select.fields if field.merge is not None), - keyfunc=attrgetter("merge"), - valuefunc=attrgetter("name"), - ) - for select in self.config.select - } - - @cached_property - def _partition_fields(self) -> Mapping[type[Message], Sequence[str]]: - return { - select.message.type: field_names - for select in self.config.select - if ( - field_names := tuple( - field.name for field in select.fields if field.partition - ) - ) - } - - def _col_name(self, *args: object) -> str: - return self.COLUMN_NAME_SEPARATOR.join( - self._message_type_aliases.get(arg, arg) for arg in collapse(args) - ) diff --git a/src/rbyte/io/table/yaak/idl-repo b/src/rbyte/io/table/yaak/idl-repo index 4705b08..7247555 160000 --- a/src/rbyte/io/table/yaak/idl-repo +++ b/src/rbyte/io/table/yaak/idl-repo @@ -1 +1 @@ -Subproject commit 4705b081cfa249b104b6622abb883d3197eba007 +Subproject commit 7247555a0bbfb98dbafa91766511773cb26141ad diff --git a/src/rbyte/io/table/yaak/message_iterator.py b/src/rbyte/io/table/yaak/message_iterator.py index 2a41ecd..4c9203f 100644 --- a/src/rbyte/io/table/yaak/message_iterator.py +++ b/src/rbyte/io/table/yaak/message_iterator.py @@ -5,6 +5,7 @@ from google.protobuf.message import Message from structlog import get_logger +from structlog.contextvars import bound_contextvars from .proto import can_pb2, sensor_pb2 @@ -15,7 +16,7 @@ def to_uint32(buf: bytes) -> int: return int(struct.unpack("I", buf)[0]) -class YaakMetadataMessageIterator(Iterator[Message]): +class YaakMetadataMessageIterator(Iterator[tuple[type[Message], bytes]]): """An iterator over a metadata file(-like object) producing messages.""" MESSAGE_TYPES: Mapping[int, type[Message]] = { @@ -36,14 +37,14 @@ def __init__( ) -> None: super().__init__() - self.file = file + self._file = file for expected_val, desc in ( (self.FILE_HEADER_LEN, "file header length"), (self.FILE_HEADER_VERSION, "file header version"), (self.MESSAGE_HEADER_LEN, "message header length"), ): - if (val := to_uint32(self.file.read(4))) != expected_val: + if (val := to_uint32(self._file.read(4))) != expected_val: msg = f"invalid {desc}: {val}, expected: {expected_val}" raise ValueError(msg) @@ -53,8 +54,9 @@ def __init__( if unknown_message_types := set(message_types) - set( self.MESSAGE_TYPES.values() ): - msg = f"unknown: {unknown_message_types}" - raise ValueError(msg) + with bound_contextvars(unknown_message_types=unknown_message_types): + logger.error(msg := "unknown message types") + raise ValueError(msg) self._message_types = { k: v for k, v in self.MESSAGE_TYPES.items() if v in message_types @@ -65,7 +67,7 @@ def __iter__(self) -> Self: return self @override - def __next__(self) -> Message: + def __next__(self) -> tuple[type[Message], bytes]: msg_type = None while msg_type is None: try: @@ -80,21 +82,15 @@ def __next__(self) -> Message: msg_type_idx, msg_buf = msg_data msg_type = self._message_types.get(msg_type_idx) - try: - message = msg_type.FromString(msg_buf) # pyright: ignore[reportPossiblyUnboundVariable] - except Exception as exc: - logger.warning("failed to parse message", type=msg_type) - raise StopIteration from exc - - return message + return msg_type, msg_buf # pyright: ignore[reportPossiblyUnboundVariable] def _read_message(self) -> tuple[int, bytes] | None: - msg__type_idx_buf = self.file.read(4) + msg__type_idx_buf = self._file.read(4) if not msg__type_idx_buf: return None msg_type_idx = to_uint32(msg__type_idx_buf) - msg_len = to_uint32(self.file.read(4)) - msg_buf = self.file.read(msg_len) + msg_len = to_uint32(self._file.read(4)) + msg_buf = self._file.read(msg_len) return (msg_type_idx, msg_buf) diff --git a/src/rbyte/io/table/yaak/reader.py b/src/rbyte/io/table/yaak/reader.py new file mode 100644 index 0000000..9bfa1bb --- /dev/null +++ b/src/rbyte/io/table/yaak/reader.py @@ -0,0 +1,99 @@ +import json +from collections.abc import Mapping +from functools import cached_property +from mmap import ACCESS_READ, mmap +from operator import itemgetter +from os import PathLike +from pathlib import Path +from typing import cast, override + +import more_itertools as mit +import polars as pl +from google.protobuf.message import Message +from optree import tree_map +from ptars import HandlerPool +from pydantic import ConfigDict, ImportString +from structlog import get_logger +from tqdm import tqdm +from xxhash import xxh3_64_intdigest as digest + +from rbyte.config.base import BaseModel, HydraConfig +from rbyte.io.table.base import TableReaderBase + +from .message_iterator import YaakMetadataMessageIterator +from .proto import sensor_pb2 + +logger = get_logger(__name__) + + +PolarsDataType = pl.DataType | pl.DataTypeClass + + +class Config(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + fields: Mapping[ + ImportString[type[Message]], + Mapping[str, HydraConfig[PolarsDataType] | ImportString[PolarsDataType] | None], + ] + + +class YaakMetadataTableReader(TableReaderBase): + def __init__(self, **kwargs: object) -> None: + super().__init__() + + self._config = Config.model_validate(kwargs) + + @override + def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: + with Path(path).open("rb") as _f, mmap(_f.fileno(), 0, access=ACCESS_READ) as f: + handler_pool = HandlerPool() + + messages = mit.bucket( + YaakMetadataMessageIterator(f, message_types=self._fields), + key=itemgetter(0), + validator=self._fields.__contains__, + ) + + dfs = { + msg_type.__name__: cast( + pl.DataFrame, + pl.from_arrow( # pyright: ignore[reportUnknownMemberType] + data=handler_pool.get_for_message(msg_type.DESCRIPTOR) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] + .list_to_record_batch([ + msg_data + for (_, msg_data) in tqdm( + messages[msg_type], postfix={"msg_type": msg_type} + ) + ]) + .select(schema), + schema=schema, # pyright: ignore[reportArgumentType] + ), + ) + for msg_type, schema in self._fields.items() + } + + if (df := dfs.pop((k := sensor_pb2.ImageMetadata.__name__), None)) is not None: + dfs |= { + ".".join((k, *map(str, k_partition))): df_partition + for k_partition, df_partition in df.partition_by( + "camera_name", include_key=False, as_dict=True + ).items() + } + + return dfs + + @override + def __hash__(self) -> int: + config = self._config.model_dump_json() + # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 + config_str = json.dumps(json.loads(config), sort_keys=True) + + return digest(config_str) + + @cached_property + def _fields(self) -> Mapping[type[Message], Mapping[str, PolarsDataType | None]]: + return tree_map( # pyright: ignore[reportUnknownVariableType, reportReturnType] + lambda x: x.instantiate() if isinstance(x, HydraConfig) else x, # pyright: ignore[reportUnknownLambdaType, reportUnknownArgumentType] + self._config.fields, # pyright: ignore[reportArgumentType] + ) diff --git a/src/rbyte/sample/builder.py b/src/rbyte/sample/builder.py index b4f8664..7178479 100644 --- a/src/rbyte/sample/builder.py +++ b/src/rbyte/sample/builder.py @@ -4,35 +4,31 @@ import polars as pl from pydantic import PositiveInt, StringConstraints, validate_call -from rbyte.config.base import BaseModel - from .base import SampleTableBuilder -class GreedySampleTableBuilderConfig(BaseModel): - index_column: str - length: PositiveInt - min_step: PositiveInt - stride: PositiveInt = 1 - filter: ( - Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)] | None - ) = None - - class GreedySampleTableBuilder(SampleTableBuilder): @validate_call - def __init__(self, config: object) -> None: + def __init__( + self, + index_column: str, + length: PositiveInt = 1, + min_step: PositiveInt = 1, + stride: PositiveInt = 1, + filter: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)] # noqa: A002 + | None = None, + ) -> None: super().__init__() - self._config = GreedySampleTableBuilderConfig.model_validate(config) - - @property - def config(self) -> GreedySampleTableBuilderConfig: - return self._config + self._index_column = index_column + self._length = length + self._min_step = min_step + self._stride = stride + self._filter = filter @override def build(self, source: pl.LazyFrame) -> pl.LazyFrame: - idx_col = self.config.index_column + idx_col = self._index_column idx_dtype = source.select(idx_col).collect_schema()[idx_col] return ( @@ -40,15 +36,15 @@ def build(self, source: pl.LazyFrame) -> pl.LazyFrame: pl.int_range( pl.col(idx_col).min().fill_null(value=0), pl.col(idx_col).max().fill_null(value=0) + 1, - step=self.config.min_step, + step=self._min_step, dtype=idx_dtype, # pyright: ignore[reportArgumentType] ) ) .select( pl.int_ranges( pl.col(idx_col), - pl.col(idx_col) + self.config.length * self.config.stride, - self.config.stride, + pl.col(idx_col) + self._length * self._stride, + self._stride, dtype=idx_dtype, # pyright: ignore[reportArgumentType] ) ) @@ -57,8 +53,10 @@ def build(self, source: pl.LazyFrame) -> pl.LazyFrame: .join(source, on=idx_col, how="inner") .group_by(sample_idx_col) .all() - .filter(pl.col(idx_col).list.len() == self.config.length) - .sql(f"select * from self where ({self.config.filter or True})") # noqa: S608 + .filter(pl.col(idx_col).list.len() == self._length) + .sql(f"select * from self where ({self._filter or True})") # noqa: S608 .sort(sample_idx_col) - .select(pl.exclude(sample_idx_col).list.to_array(self.config.length)) + .select(pl.exclude(sample_idx_col)) + # TODO: https://github.com/pola-rs/polars/issues/18810 # noqa: FIX002 + # .select(pl.all().list.to_array(self.length)) ) diff --git a/src/rbyte/scripts/build_table.py b/src/rbyte/scripts/build_table.py index df20386..c33a14c 100644 --- a/src/rbyte/scripts/build_table.py +++ b/src/rbyte/scripts/build_table.py @@ -4,28 +4,20 @@ import hydra from hydra.utils import instantiate from omegaconf import DictConfig -from polars import DataFrame from structlog import get_logger -from rbyte.io.table.base import TableBuilder +from rbyte.io.table.base import Table, TableBuilderBase logger = get_logger(__name__) -def run(config: DictConfig) -> None: - builder = cast(TableBuilder, instantiate(config.builder)) - writer = cast(Callable[[DataFrame], None], instantiate(config.writer)) - df = builder.build(config.path) - - return writer(df) - - @hydra.main(version_base=None) def main(config: DictConfig) -> None: - try: - run(config) - except Exception: - logger.exception("failed") + table_builder = cast(TableBuilderBase, instantiate(config.table_builder)) + writer = cast(Callable[[Table], None], instantiate(config.writer)) + table = table_builder.build(config.path) + + return writer(table) if __name__ == "__main__": diff --git a/src/rbyte/scripts/read_frames.py b/src/rbyte/scripts/read_frames.py new file mode 100644 index 0000000..dec71c2 --- /dev/null +++ b/src/rbyte/scripts/read_frames.py @@ -0,0 +1,61 @@ +from typing import cast + +import hydra +import more_itertools as mit +import rerun as rr +import torch +from hydra.utils import instantiate +from omegaconf import DictConfig +from structlog import get_logger +from tqdm import tqdm + +from rbyte.io.frame.base import FrameReader + +logger = get_logger(__name__) + + +@hydra.main(version_base=None) +def main(config: DictConfig) -> None: + frame_reader = cast(FrameReader, instantiate(config.frame_reader)) + rr.init("rbyte", spawn=True) + + for frame_indexes in mit.chunked( + tqdm(sorted(frame_reader.get_available_indexes())), + config.batch_size, + strict=False, + ): + frames = frame_reader.read(frame_indexes) + match frames.shape, frames.dtype: + case ((_, height, width, 3), torch.uint8): + rr.log( + config.entity_path, + [ + rr.components.ImageFormat( + height=height, + width=width, + color_model="RGB", + channel_datatype="U8", + ), + rr.Image.indicator(), + ], + static=True, + strict=True, + ) + + rr.send_columns( + config.entity_path, + times=[rr.TimeSequenceColumn("frame_index", frame_indexes)], + components=[ + rr.components.ImageBufferBatch( + frames.flatten(start_dim=1, end_dim=-1).numpy() # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] + ) + ], + strict=True, + ) + + case _: + raise NotImplementedError + + +if __name__ == "__main__": + main() diff --git a/src/rbyte/scripts/visualize.py b/src/rbyte/scripts/visualize.py index 8f1a3c3..acac24f 100644 --- a/src/rbyte/scripts/visualize.py +++ b/src/rbyte/scripts/visualize.py @@ -14,7 +14,8 @@ logger = get_logger(__name__) -def run(config: DictConfig) -> None: +@hydra.main(version_base=None) +def main(config: DictConfig) -> None: logger = cast(Logger[Any], instantiate(config.logger)) dataloader = cast(DataLoader[Any], instantiate(config.dataloader)) @@ -25,13 +26,5 @@ def run(config: DictConfig) -> None: logger.log(batch_idx, batch) -@hydra.main(version_base=None) -def main(config: DictConfig) -> None: - try: - run(config) - except Exception: - logger.exception("failed") - - if __name__ == "__main__": main() diff --git a/src/rbyte/utils/__init__.py b/src/rbyte/utils/__init__.py index 50c54cc..e69de29 100644 --- a/src/rbyte/utils/__init__.py +++ b/src/rbyte/utils/__init__.py @@ -1,3 +0,0 @@ -from .dataframe_cache import DataframeDiskCache - -__all__ = ["DataframeDiskCache"] diff --git a/src/rbyte/utils/dataframe/__init__.py b/src/rbyte/utils/dataframe/__init__.py new file mode 100644 index 0000000..b12763f --- /dev/null +++ b/src/rbyte/utils/dataframe/__init__.py @@ -0,0 +1,4 @@ +from .cache import DataframeDiskCache +from .misc import unnest_all + +__all__ = ["DataframeDiskCache", "unnest_all"] diff --git a/src/rbyte/utils/dataframe_cache.py b/src/rbyte/utils/dataframe/cache.py similarity index 73% rename from src/rbyte/utils/dataframe_cache.py rename to src/rbyte/utils/dataframe/cache.py index 771946e..359981b 100644 --- a/src/rbyte/utils/dataframe_cache.py +++ b/src/rbyte/utils/dataframe/cache.py @@ -1,14 +1,16 @@ from collections.abc import Hashable from io import BufferedReader from tempfile import TemporaryFile -from typing import Literal +from typing import Literal, override import polars as pl from diskcache import Cache from pydantic import ByteSize, DirectoryPath, NewPath, validate_call +from rbyte.io.table.base import TableCacheBase -class DataframeDiskCache: + +class DataframeDiskCache(TableCacheBase): @validate_call def __init__( self, directory: DirectoryPath | NewPath, size_limit: ByteSize | None = None @@ -16,6 +18,11 @@ def __init__( super().__init__() self._cache = Cache(directory=directory, size_limit=size_limit) + @override + def __contains__(self, key: Hashable) -> bool: + return key in self._cache + + @override def get(self, key: Hashable) -> pl.DataFrame | None: match val := self._cache.get(key, default=None, read=True): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] case BufferedReader(): @@ -27,8 +34,9 @@ def get(self, key: Hashable) -> pl.DataFrame | None: case _: # pyright: ignore[reportUnknownVariableType] raise NotImplementedError - def set(self, key: Hashable, dataframe: pl.DataFrame) -> Literal[True]: + @override + def set(self, key: Hashable, value: pl.DataFrame) -> Literal[True]: with TemporaryFile() as f: - dataframe.write_ipc(f, compression="uncompressed") + value.write_ipc(f, compression="uncompressed") f.seek(0) # pyright: ignore[reportUnusedCallResult] return self._cache.set(key, f, read=True) # pyright: ignore[reportUnknownMemberType] diff --git a/src/rbyte/utils/dataframe.py b/src/rbyte/utils/dataframe/misc.py similarity index 66% rename from src/rbyte/utils/dataframe.py rename to src/rbyte/utils/dataframe/misc.py index 301cd36..d9fc8b2 100644 --- a/src/rbyte/utils/dataframe.py +++ b/src/rbyte/utils/dataframe/misc.py @@ -1,15 +1,16 @@ from collections.abc import Generator, Mapping import polars as pl -from polars._typing import PolarsDataType +import polars._typing as plt +# TODO: https://github.com/pola-rs/polars/issues/12353 # noqa: FIX002 def unnest_all( - schema: Mapping[str, PolarsDataType], separator: str = "." + schema: Mapping[str, plt.PolarsDataType], separator: str = "." ) -> Generator[pl.Expr]: def _unnest( - schema: Mapping[str, PolarsDataType], path: tuple[str, ...] = () - ) -> Generator[tuple[tuple[str, ...], PolarsDataType]]: + schema: Mapping[str, plt.PolarsDataType], path: tuple[str, ...] = () + ) -> Generator[tuple[tuple[str, ...], plt.PolarsDataType]]: for name, dtype in schema.items(): match dtype: case pl.Struct(): diff --git a/src/rbyte/viz/loggers/rerun_logger.py b/src/rbyte/viz/loggers/rerun_logger.py index 7156371..b164f6c 100644 --- a/src/rbyte/viz/loggers/rerun_logger.py +++ b/src/rbyte/viz/loggers/rerun_logger.py @@ -1,61 +1,33 @@ -from collections.abc import Callable, Mapping +from collections.abc import Mapping from functools import cache, cached_property -from typing import Any, Literal, cast, override +from typing import Any, Literal, Protocol, cast, override, runtime_checkable import rerun as rr import torch -from pydantic import Field, ImportString +from optree import tree_flatten_with_path +from pydantic import ImportString, validate_call from rerun._baseclasses import ComponentBatchMixin # noqa: PLC2701 +from rerun._send_columns import TimeColumnLike # noqa: PLC2701 from torch import Tensor -from torch.utils._pytree import tree_flatten_with_path # noqa: PLC2701 from rbyte.batch import Batch -from rbyte.config.base import BaseModel from .base import Logger -type NestedKey = str | tuple[str, ...] -TimeColumn = rr.TimeSequenceColumn | rr.TimeNanosColumn | rr.TimeSecondsColumn - -class SchemaItemConfig(BaseModel): - key: NestedKey - type: Callable[[str, int], None] | Callable[[str, float], None] - - -class TransformConfig(BaseModel): - select: tuple[NestedKey, ...] - apply: Callable[[Tensor], Tensor] - - -class RerunLoggerConfig(BaseModel): - log_schema: Mapping[Literal["frame", "table"], Mapping[str, ImportString[Any]]] - transforms: list[TransformConfig] = Field(default_factory=list) - - @cached_property - def times(self) -> tuple[tuple[tuple[str, ...], type[TimeColumn]], ...]: - return tuple( # pyright: ignore[reportUnknownVariableType] - (tuple(x.key for x in path), leaf) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportAttributeAccessIssue] - for path, leaf in tree_flatten_with_path(self.log_schema)[0] - if issubclass(leaf, TimeColumn) - ) - - @cached_property - def components( - self, - ) -> tuple[tuple[tuple[str, ...], type[ComponentBatchMixin]], ...]: - return tuple( # pyright: ignore[reportUnknownVariableType] - (tuple(x.key for x in path), leaf) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportAttributeAccessIssue] - for path, leaf in tree_flatten_with_path(self.log_schema)[0] - if issubclass(leaf, ComponentBatchMixin) - ) +@runtime_checkable +class TimeColumn(TimeColumnLike, Protocol): ... class RerunLogger(Logger[Batch]): - def __init__(self, config: object) -> None: + @validate_call + def __init__( + self, + schema: Mapping[Literal["frame", "table"], Mapping[str, ImportString[Any]]], + ) -> None: super().__init__() - self.config = RerunLoggerConfig.model_validate(config) + self._schema = schema @cache # noqa: B019 def _get_recording(self, *, application_id: str) -> rr.RecordingStream: # noqa: PLR6301 @@ -64,25 +36,20 @@ def _get_recording(self, *, application_id: str) -> rr.RecordingStream: # noqa: ) @override - def log(self, batch_idx: int, batch: Batch) -> None: # pyright: ignore[reportGeneralTypeIssues, reportUnknownParameterType] - for transform in self.config.transforms: - batch = batch.update( # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - batch.select(*transform.select).apply(transform.apply) # pyright: ignore[reportUnknownMemberType] - ) - + def log(self, batch_idx: int, batch: Batch) -> None: # NOTE: zip because batch.meta.input_id is NonTensorData and isn't indexed for input_id, sample in zip( # pyright: ignore[reportUnknownVariableType] - batch.get(k := ("meta", "input_id")), # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType] - batch.exclude(k), # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType] + batch.get(k := ("meta", "input_id")), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportAttributeAccessIssue] + batch.exclude(k), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportAttributeAccessIssue] strict=True, ): with self._get_recording(application_id=input_id): # pyright: ignore[reportUnknownArgumentType] times = [ - fn(times=sample.get(k).numpy(), timeline="/".join(k)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] - for k, fn in self.config.times + fn(timeline="/".join(k), times=sample.get(k).numpy()) # pyright: ignore[reportUnknownMemberType, reportCallIssue] + for k, fn in self.times ] - for k, fn in self.config.components: + for k, fn in self.components: path = "/".join(k) tensor = cast(Tensor, sample.get(k)) # pyright: ignore[reportUnknownMemberType] match fn: @@ -120,3 +87,25 @@ def log(self, batch_idx: int, batch: Batch) -> None: # pyright: ignore[reportGe components=[fn(tensor.numpy())], # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportCallIssue] strict=True, ) + + @cached_property + def times(self) -> tuple[tuple[tuple[str, ...], type[TimeColumn]], ...]: + paths, leaves, _ = tree_flatten_with_path(self._schema) # pyright: ignore[reportArgumentType, reportUnknownVariableType] + + return tuple( + (path, leaf) + for path, leaf in zip(paths, leaves, strict=True) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] + if issubclass(leaf, TimeColumn) + ) + + @cached_property + def components( + self, + ) -> tuple[tuple[tuple[str, ...], type[ComponentBatchMixin]], ...]: + paths, leaves, _ = tree_flatten_with_path(self._schema) # pyright: ignore[reportArgumentType, reportUnknownVariableType] + + return tuple( + (path, leaf) + for path, leaf in zip(paths, leaves, strict=True) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] + if issubclass(leaf, ComponentBatchMixin) + )