Skip to content

Commit

Permalink
feat: add pipefunc cache example (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
egorchakov authored Nov 28, 2024
1 parent cc5b915 commit 304b5ef
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 48 deletions.
89 changes: 49 additions & 40 deletions config/_templates/dataset/yaak.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,50 +30,23 @@ inputs:
pipeline:
_target_: pipefunc.Pipeline
validate_type_annotations: false
cache_type: disk
cache_kwargs:
cache_dir: /tmp/rbyte-cache
functions:
- _target_: pipefunc.PipeFunc
bound:
path: ${data_dir}/(@=input_id@)/metadata.log
output_name: meta_data
scope: meta
output_name: data
cache: true
func:
_target_: rbyte.io.YaakMetadataDataFrameBuilder
fields:
rbyte.io.yaak.proto.sensor_pb2.ImageMetadata:
time_stamp:
_target_: polars.Datetime
time_unit: ns

frame_idx:
_target_: polars.Int32

camera_name:
_target_: polars.Enum
categories:
- cam_front_center
- cam_front_left
- cam_front_right
- cam_left_forward
- cam_right_forward
- cam_left_backward
- cam_right_backward
- cam_rear

rbyte.io.yaak.proto.can_pb2.VehicleMotion:
time_stamp:
_target_: polars.Datetime
time_unit: ns

speed:
_target_: polars.Float32

gear:
_target_: polars.Enum
categories: ["0", "1", "2", "3"]
_target_: hydra.utils.get_method
path: rbyte.io.build_yaak_metadata_dataframe

- _target_: pipefunc.PipeFunc
scope: mcap
bound:
path: ${data_dir}/(@=input_id@)/ai.mcap
output_name: mcap_data
output_name: data
func:
_target_: rbyte.io.McapDataFrameBuilder
decoder_factories: [rbyte.utils._mcap.ProtobufDecoderFactory]
Expand All @@ -94,8 +67,8 @@ inputs:
k0: meta
k1: mcap
renames:
v0: meta_data
v1: mcap_data
v0: meta.data
v1: mcap.data
output_name: data

- _target_: pipefunc.PipeFunc
Expand Down Expand Up @@ -150,7 +123,7 @@ inputs:
func:
_target_: rbyte.io.DataFrameFilter
predicate: |
`meta/VehicleMotion/gear` == '3'
`meta/VehicleMotion/speed` > 44
- _target_: pipefunc.PipeFunc
renames:
Expand All @@ -170,4 +143,40 @@ inputs:
_target_: rbyte.io.DataFrameFilter
predicate: |
array_length(`meta/ImageMetadata.(@=cameras[0]@)/time_stamp`) == 6
kwargs:
meta:
path: ${data_dir}/(@=input_id@)/metadata.log
fields:
rbyte.io.yaak.proto.sensor_pb2.ImageMetadata:
time_stamp:
_target_: polars.Datetime
time_unit: ns

frame_idx:
_target_: polars.Int32

camera_name:
_target_: polars.Enum
categories:
- cam_front_center
- cam_front_left
- cam_front_right
- cam_left_forward
- cam_right_forward
- cam_left_backward
- cam_right_backward
- cam_rear

rbyte.io.yaak.proto.can_pb2.VehicleMotion:
time_stamp:
_target_: polars.Datetime
time_unit: ns

speed:
_target_: polars.Float32

gear:
_target_: polars.Enum
categories: ["0", "1", "2", "3"]
#@ end
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rbyte"
version = "0.9.0"
version = "0.9.1"
description = "Multimodal PyTorch dataset library"
authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
Expand All @@ -20,7 +20,7 @@ dependencies = [
"structlog>=24.4.0",
"xxhash>=3.5.0",
"tqdm>=4.66.5",
"pipefunc>=0.40.2",
"pipefunc>=0.41.0",
]
readme = "README.md"
requires-python = ">=3.12,<3.13"
Expand Down
6 changes: 4 additions & 2 deletions src/rbyte/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import polars as pl
import torch
from hydra.utils import instantiate
from pipefunc import Pipeline
from pydantic import Field, StringConstraints, validate_call
from structlog import get_logger
Expand Down Expand Up @@ -69,9 +70,10 @@ def __init__(
output_name = (
samples_cfg.output_name or pipeline.unique_leaf_node.output_name # pyright: ignore[reportUnknownMemberType]
)
samples[input_id] = pipeline.run(
output_name=output_name, kwargs=samples_cfg.kwargs
kwargs = instantiate(
samples_cfg.kwargs, _recursive_=True, _convert_="all"
)
samples[input_id] = pipeline.run(output_name=output_name, kwargs=kwargs)
logger.debug(
"built samples",
columns=samples[input_id].columns,
Expand Down
4 changes: 2 additions & 2 deletions src/rbyte/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
__all__ += ["FfmpegFrameSource"]

try:
from .yaak import YaakMetadataDataFrameBuilder
from .yaak import YaakMetadataDataFrameBuilder, build_yaak_metadata_dataframe
except ImportError:
pass
else:
__all__ += ["YaakMetadataDataFrameBuilder"]
__all__ += ["YaakMetadataDataFrameBuilder", "build_yaak_metadata_dataframe"]
7 changes: 5 additions & 2 deletions src/rbyte/io/yaak/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from .dataframe_builder import YaakMetadataDataFrameBuilder
from .dataframe_builder import (
YaakMetadataDataFrameBuilder,
build_yaak_metadata_dataframe,
)

__all__ = ["YaakMetadataDataFrameBuilder"]
__all__ = ["YaakMetadataDataFrameBuilder", "build_yaak_metadata_dataframe"]
7 changes: 7 additions & 0 deletions src/rbyte/io/yaak/dataframe_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,10 @@ def __call__(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]:
}

return dfs


# Every input is exposed as a keyword argument so that the call is cacheable
# by pipefunc.
def build_yaak_metadata_dataframe(
    *, path: PathLike[str], fields: Fields
) -> Mapping[str, pl.DataFrame]:
    """Build Yaak metadata dataframes through a single function call.

    Constructs a ``YaakMetadataDataFrameBuilder`` from ``fields`` and
    immediately invokes it on ``path``.

    Args:
        path: Location of the metadata file handed to the builder.
        fields: Field specification forwarded to the builder's constructor.

    Returns:
        The mapping of names to dataframes produced by the builder.
    """
    builder = YaakMetadataDataFrameBuilder(fields=fields)
    return builder(path)

0 comments on commit 304b5ef

Please sign in to comment.