Skip to content

Commit

Permalink
feat: sample builders
Browse files Browse the repository at this point in the history
  • Loading branch information
egorchakov committed Nov 20, 2024
1 parent 15f45bf commit dc93ce6
Show file tree
Hide file tree
Showing 23 changed files with 193 additions and 182 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
---
repos:
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.22
rev: v0.23
hooks:
- id: validate-pyproject

Expand All @@ -11,14 +11,14 @@ repos:
- id: pyupgrade

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.2
rev: v0.7.4
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

- repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror
rev: 1.21.0
rev: 1.21.1
hooks:
- id: basedpyright

Expand Down
7 changes: 2 additions & 5 deletions config/_templates/dataset/carla.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ inputs:
#@ end

sample_builder:
_target_: rbyte.sample.GreedySampleBuilder
_target_: rbyte.RollingWindowSampleBuilder
index_column: _idx_
length: 1
stride: 1
min_step: 1
filter: !!null
period: 1i
7 changes: 2 additions & 5 deletions config/_templates/dataset/mimicgen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ inputs:
#@ end

sample_builder:
_target_: rbyte.sample.GreedySampleBuilder
_target_: rbyte.RollingWindowSampleBuilder
index_column: _idx_
length: 1
stride: 1
min_step: 1
filter: !!null
period: 1i
10 changes: 3 additions & 7 deletions config/_templates/dataset/nuscenes/mcap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ inputs:
fields:
#@ for topic in camera_topics.values():
(@=topic@):
_idx_:
log_time:
_target_: polars.Datetime
time_unit: ns

_idx_:
#@ end

/odom:
Expand Down Expand Up @@ -94,9 +93,6 @@ inputs:
#@ end

sample_builder:
_target_: rbyte.sample.GreedySampleBuilder
_target_: rbyte.RollingWindowSampleBuilder
index_column: mcap/(@=camera_topics.values()[0]@)/_idx_
length: 1
stride: 1
min_step: 1
filter: !!null
period: 1i
8 changes: 2 additions & 6 deletions config/_templates/dataset/nuscenes/rrd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,6 @@ inputs:
#@ end

sample_builder:
_target_: rbyte.sample.GreedySampleBuilder
_target_: rbyte.RollingWindowSampleBuilder
index_column: rrd/(@=camera_entities.values()[0]@)/_idx_
length: 1
stride: 1
min_step: 1
filter: !!null

period: 1i
12 changes: 6 additions & 6 deletions config/_templates/dataset/yaak.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ inputs:
_target_: polars.Datetime
time_unit: ns

frame_idx: polars.UInt32
frame_idx: polars.Int32
camera_name:
_target_: polars.Enum
categories:
Expand Down Expand Up @@ -130,10 +130,10 @@ inputs:
#@ end

sample_builder:
_target_: rbyte.sample.GreedySampleBuilder
_target_: rbyte.FixedWindowSampleBuilder
index_column: meta/ImageMetadata.(@=cameras[0]@)/frame_idx
length: 1
stride: 1
min_step: 1
every: 6i
period: 6i
filter: |
array_mean(`meta/VehicleMotion/speed`) > 40
array_length(`meta/ImageMetadata.(@=cameras[0]@)/time_stamp`) == 6
and array_mean(`meta/VehicleMotion/speed`) > 40
23 changes: 12 additions & 11 deletions config/_templates/dataset/zod.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ inputs:
index_column: camera_front_blur/timestamp
source:
_target_: rbyte.io.PathTensorSource
path: "${data_dir}/zod/sequences/000002_short/camera_front_blur/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.jpg"
path: "${data_dir}/sequences/000002_short/camera_front_blur/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.jpg"
decoder:
_target_: simplejpeg.decode_jpeg
_partial_: true
Expand All @@ -21,15 +21,15 @@ inputs:
index_column: lidar_velodyne/timestamp
source:
_target_: rbyte.io.NumpyTensorSource
path: "${data_dir}/zod/sequences/000002_short/lidar_velodyne/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.npy"
path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.npy"
select: ["x", "y", "z"]

table_builder:
_target_: rbyte.io.TableBuilder
_convert_: all
readers:
camera_front_blur:
path: "${data_dir}/zod/sequences/000002_short/camera_front_blur/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.jpg"
path: "${data_dir}/sequences/000002_short/camera_front_blur/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.jpg"
reader:
_target_: rbyte.io.PathTableReader
_recursive_: false
Expand All @@ -39,7 +39,7 @@ inputs:
time_unit: ns

lidar_velodyne:
path: "${data_dir}/zod/sequences/000002_short/lidar_velodyne/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.npy"
path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.npy"
reader:
_target_: rbyte.io.PathTableReader
_recursive_: false
Expand All @@ -49,7 +49,7 @@ inputs:
time_unit: ns

vehicle_data:
path: "${data_dir}/zod/sequences/000002_short/vehicle_data.hdf5"
path: "${data_dir}/sequences/000002_short/vehicle_data.hdf5"
reader:
_target_: rbyte.io.Hdf5TableReader
_recursive_: false
Expand Down Expand Up @@ -82,7 +82,7 @@ inputs:
timestamp:
method: asof
strategy: nearest
tolerance: 50ms
tolerance: 100ms

vehicle_data:
ego_vehicle_controls:
Expand All @@ -91,17 +91,17 @@ inputs:
timestamp/nanoseconds/value:
method: asof
strategy: nearest
tolerance: 50ms
tolerance: 100ms

acceleration_pedal/ratio/unitless/value:
method: asof
strategy: nearest
tolerance: 50ms
tolerance: 100ms

steering_wheel_angle/angle/radians/value:
method: asof
strategy: nearest
tolerance: 50ms
tolerance: 100ms

satellite:
key: timestamp/nanoseconds/value
Expand All @@ -110,5 +110,6 @@ inputs:
method: interp

sample_builder:
_target_: rbyte.sample.GreedySampleBuilder
length: 1
_target_: rbyte.FixedWindowSampleBuilder
index_column: camera_front_blur/timestamp
every: 300ms
4 changes: 2 additions & 2 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ generate-config:
--strict

test *ARGS: generate-config
uv run pytest --capture=no {{ ARGS }}
uv run --all-extras pytest --capture=no {{ ARGS }}

notebook FILE *ARGS: sync generate-config
uv run --with=jupyter,jupyterlab-vim,rerun-notebook jupyter lab {{ FILE }} {{ ARGS }}
uv run --all-extras --with=jupyter,jupyterlab-vim,rerun-notebook jupyter lab {{ FILE }} {{ ARGS }}

[group('scripts')]
visualize *ARGS: generate-config
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rbyte"
version = "0.7.0"
version = "0.8.0"
description = "Multimodal PyTorch dataset library"
authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
Expand Down
8 changes: 7 additions & 1 deletion src/rbyte/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from importlib.metadata import version

from .dataset import Dataset
from .sample import FixedWindowSampleBuilder, RollingWindowSampleBuilder

__version__ = version(__package__ or __name__)

__all__ = ["Dataset", "__version__"]
__all__ = [
"Dataset",
"FixedWindowSampleBuilder",
"RollingWindowSampleBuilder",
"__version__",
]
12 changes: 5 additions & 7 deletions src/rbyte/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ def __init__(
super().__init__()

_sample_builder = sample_builder.instantiate()
samples: Mapping[str, pl.LazyFrame] = {}
samples: Mapping[str, pl.DataFrame] = {}
for input_id, input_cfg in inputs.items():
with bound_contextvars(input_id=input_id):
table = input_cfg.table_builder.instantiate().build().lazy()
table = input_cfg.table_builder.instantiate().build()
samples[input_id] = _sample_builder.build(table)
logger.debug(
"built samples",
rows=table.select(pl.len()).collect().item(),
samples=samples[input_id].select(pl.len()).collect().item(),
rows=table.select(pl.len()).item(),
samples=samples[input_id].select(pl.len()).item(),
)

input_id_enum = pl.Enum(sorted(samples))
Expand All @@ -85,12 +85,11 @@ def __init__(
)
.sort(Column.input_id)
.with_row_index(Column.sample_idx)
.collect()
.rechunk()
)

self._sources: pl.DataFrame = (
pl.LazyFrame(
pl.DataFrame(
[
{
Column.input_id: input_id,
Expand All @@ -112,7 +111,6 @@ def __init__(
.explode(k)
.unnest(k)
.select(Column.input_id, pl.exclude(Column.input_id).name.prefix(f"{k}."))
.collect()
.rechunk()
)

Expand Down
5 changes: 3 additions & 2 deletions src/rbyte/sample/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .greedy_builder import GreedySampleBuilder
from .fixed_window import FixedWindowSampleBuilder
from .rolling_window import RollingWindowSampleBuilder

__all__ = ["GreedySampleBuilder"]
__all__ = ["FixedWindowSampleBuilder", "RollingWindowSampleBuilder"]
2 changes: 1 addition & 1 deletion src/rbyte/sample/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@

@runtime_checkable
class SampleBuilder(Protocol):
def build(self, source: pl.LazyFrame) -> pl.LazyFrame: ...
def build(self, source: pl.DataFrame) -> pl.DataFrame: ...
47 changes: 47 additions & 0 deletions src/rbyte/sample/fixed_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from datetime import timedelta
from typing import Literal, override
from uuid import uuid4

import polars as pl
from polars._typing import ClosedInterval
from pydantic import validate_call

from .base import SampleBuilder


class FixedWindowSampleBuilder(SampleBuilder):
@validate_call
def __init__(
self,
*,
index_column: str,
every: str | timedelta,
period: str | timedelta | None = None,
closed: ClosedInterval = "left",
filter: str | None = None, # noqa: A002
) -> None:
self._index_column: pl.Expr = pl.col(index_column)
self._every: str | timedelta = every
self._period: str | timedelta | None = period
self._closed: ClosedInterval = closed
self._filter: str | Literal[True] = filter if filter is not None else True

@override
def build(self, source: pl.DataFrame) -> pl.DataFrame:
return (
source.sort(self._index_column)
.with_columns(self._index_column.alias(_index_column := uuid4().hex))
.group_by_dynamic(
index_column=_index_column,
every=self._every,
period=self._period,
closed=self._closed,
label="datapoint",
start_by="datapoint",
)
.agg(pl.all())
.sql(f"select * from self where ({self._filter})") # noqa: S608
.filter(self._index_column.list.len() > 0)
.sort(_index_column)
.drop(_index_column)
)
Loading

0 comments on commit dc93ce6

Please sign in to comment.