Skip to content

Commit

Permalink
feat: add video reader (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
egorchakov authored Sep 27, 2024
1 parent e288474 commit f2c0ff8
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 17 deletions.
21 changes: 8 additions & 13 deletions examples/config_templates/dataset/yaak.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#@yaml/text-templated-strings

#@ drives = [
#@ 'Niro102-HQ/2023-05-08--13-59-22',
#@ 'Niro098-HQ/2024-08-26--06-06-03',
#@ ]

#@ cameras = [
#@ 'cam_front_left',
#@ 'cam_left_forward',
#@ 'cam_right_forward',
#@ 'cam_left_backward',
#@ 'cam_right_backward',
#@ ]
---
_target_: rbyte.Dataset
Expand All @@ -21,14 +21,9 @@ inputs:
(@=source_id@):
index_column: "ImageMetadata.(@=source_id@).frame_idx"
reader:
_target_: rbyte.io.frame.DirectoryFrameReader
path: "${data_dir}/(@=input_id@)/frames/(@=source_id@).defish.mp4/576x324/{:09d}.jpg"
frame_decoder:
_target_: simplejpeg.decode_jpeg
_partial_: true
colorspace: rgb
fastdct: true
fastupsample: true
_target_: rbyte.io.frame.VideoFrameReader
path: "${data_dir}/(@=input_id@)/(@=source_id@).defish.mp4"
resize_shorter_side: 324
#@ end

table:
Expand Down Expand Up @@ -83,7 +78,7 @@ inputs:

frame_idx:
method: asof
tolerance: 10ms
tolerance: 20ms
strategy: nearest
#@ end

Expand Down Expand Up @@ -112,4 +107,4 @@ sample_builder:
stride: 1
min_step: 6
filter: |
array_mean(`VehicleMotion.speed`) > 47
array_lower(`VehicleMotion.speed`) > 80
6 changes: 6 additions & 0 deletions examples/config_templates/frame_reader/video.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
_target_: rbyte.io.frame.VideoFrameReader
path: ???
threads: !!null
resize_shorter_side: !!null
with_fallback: !!null
4 changes: 2 additions & 2 deletions examples/config_templates/logger/rerun/yaak.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

#@ cameras = [
#@ 'cam_front_left',
#@ 'cam_left_forward',
#@ 'cam_right_forward',
#@ 'cam_left_backward',
#@ 'cam_right_backward',
#@ ]

---
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
dependencies = [
"tensordict @ git+https://github.com/pytorch/tensordict.git@85b6b81",
"torch>=2.4.1",
"polars>=1.8.0",
"polars>=1.8.2",
"pydantic>=2.9.2",
"more-itertools>=10.5.0",
"hydra-core>=1.3.2",
Expand Down Expand Up @@ -44,6 +44,7 @@ mcap = [
]
yaak = ["protobuf", "ptars>=0.0.2rc2"]
jpeg = ["simplejpeg>=1.7.6"]
video = ["video-reader-rs>=0.1.4"]

[project.scripts]
rbyte-build-table = 'rbyte.scripts.build_table:main'
Expand Down
14 changes: 14 additions & 0 deletions src/rbyte/io/frame/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
from .directory import DirectoryFrameReader

__all__ = ["DirectoryFrameReader"]

try:
from .video import VideoFrameReader
except ImportError:
pass
else:
__all__ += ["VideoFrameReader"]

try:
from .mcap import McapFrameReader
except ImportError:
pass
else:
__all__ += ["McapFrameReader"]
2 changes: 1 addition & 1 deletion src/rbyte/io/frame/directory/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(

@cached_property
def _path_posix(self) -> str:
return self._path.as_posix()
return self._path.resolve().as_posix()

def _decode(self, path: str) -> npt.ArrayLike:
with Path(path).open("rb") as f:
Expand Down
3 changes: 3 additions & 0 deletions src/rbyte/io/frame/video/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .reader import VideoFrameReader

__all__ = ["VideoFrameReader"]
47 changes: 47 additions & 0 deletions src/rbyte/io/frame/video/reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from collections.abc import Callable, Iterable, Sequence
from functools import partial
from os import PathLike
from pathlib import Path
from typing import override

import torch
import video_reader as vr
from jaxtyping import UInt8
from pydantic import NonNegativeInt, validate_call
from torch import Tensor

from rbyte.io.frame.base import FrameReader


class VideoFrameReader(FrameReader):
@validate_call
def __init__(
self,
path: PathLike[str],
threads: NonNegativeInt | None = None,
resize_shorter_side: NonNegativeInt | None = None,
with_fallback: bool | None = None, # noqa: FBT001
) -> None:
super().__init__()
self._path = Path(path).resolve().as_posix()

self._get_batch: Callable[[str, Iterable[int]], UInt8[Tensor, "b h w c"]] = (
partial(
vr.get_batch, # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
threads=threads,
resize_shorter_side=resize_shorter_side,
with_fallback=with_fallback,
)
)

@override
def read(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]:
batch = self._get_batch(self._path, indexes)

return torch.from_numpy(batch) # pyright: ignore[reportUnknownMemberType]

@override
def get_available_indexes(self) -> Sequence[int]:
num_frames, *_ = vr.get_shape(self._path) # pyright: ignore[reportAttributeAccessIssue, reportUnknownVariableType, reportUnknownMemberType]

return range(num_frames) # pyright: ignore[reportUnknownArgumentType]

0 comments on commit f2c0ff8

Please sign in to comment.