diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index af3802c..040c385 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: - id: validate-pyproject - repo: https://github.com/crate-ci/typos - rev: v1.24.6 + rev: v1.25.0 hooks: - id: typos @@ -31,7 +31,7 @@ repos: - id: ruff-format - repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror - rev: 1.18.0 + rev: 1.18.2 hooks: - id: basedpyright diff --git a/examples/config_templates/frame_reader/video.yaml b/examples/config_templates/frame_reader/video/ffmpeg.yaml similarity index 65% rename from examples/config_templates/frame_reader/video.yaml rename to examples/config_templates/frame_reader/video/ffmpeg.yaml index 3cc68ca..5107b03 100644 --- a/examples/config_templates/frame_reader/video.yaml +++ b/examples/config_templates/frame_reader/video/ffmpeg.yaml @@ -1,5 +1,5 @@ --- -_target_: rbyte.io.frame.VideoFrameReader +_target_: rbyte.io.frame.FfmpegFrameReader path: ??? threads: !!null resize_shorter_side: !!null diff --git a/examples/config_templates/frame_reader/video/vali.yaml b/examples/config_templates/frame_reader/video/vali.yaml new file mode 100644 index 0000000..2441ab0 --- /dev/null +++ b/examples/config_templates/frame_reader/video/vali.yaml @@ -0,0 +1,4 @@ +--- +_target_: rbyte.io.frame.ValiGpuFrameReader +path: ??? +pixel_format_chain: [NV12] diff --git a/examples/config_templates/read_frames.yaml b/examples/config_templates/read_frames.yaml index 2c1ac51..423dc01 100644 --- a/examples/config_templates/read_frames.yaml +++ b/examples/config_templates/read_frames.yaml @@ -3,9 +3,6 @@ defaults: - frame_reader: !!null - _self_ -batch_size: 1 -entity_path: ??? - hydra: output_subdir: !!null run: diff --git a/justfile b/justfile index b1022ed..454061e 100644 --- a/justfile +++ b/justfile @@ -62,7 +62,7 @@ read-frames *ARGS: generate-example-config # rerun server and viewer rerun bind="0.0.0.0" port="9876" ws-server-port="9877" web-viewer-port="9090": - uv run rerun \ + RUST_LOG=debug uv run rerun \ --bind {{ bind }} \ --port {{ port }} \ --ws-server-port {{ ws-server-port }} \ diff --git a/pyproject.toml b/pyproject.toml index 602f285..3c3da7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ mcap = [ ] yaak = ["protobuf", "ptars>=0.0.2rc2"] jpeg = ["simplejpeg>=1.7.6"] -video = ["video-reader-rs>=0.1.4"] +video = ["python-vali>=4.1.4.post0", "video-reader-rs>=0.1.4"] [project.scripts] rbyte-build-table = 'rbyte.scripts.build_table:main' diff --git a/src/rbyte/config/__init__.py b/src/rbyte/config/__init__.py index e69de29..df2d681 100644 --- a/src/rbyte/config/__init__.py +++ b/src/rbyte/config/__init__.py @@ -0,0 +1,3 @@ +from .base import BaseModel, HydraConfig + +__all__ = ["BaseModel", "HydraConfig"] diff --git a/src/rbyte/io/frame/__init__.py b/src/rbyte/io/frame/__init__.py index 14cb1ec..c42e7c5 100644 --- a/src/rbyte/io/frame/__init__.py +++ b/src/rbyte/io/frame/__init__.py @@ -2,16 +2,24 @@ __all__ = ["DirectoryFrameReader"] + try: - from .video import VideoFrameReader + from .mcap import McapFrameReader except ImportError: pass else: - __all__ += ["VideoFrameReader"] + __all__ += ["McapFrameReader"] try: - from .mcap import McapFrameReader + from .video.ffmpeg_reader import FfmpegFrameReader except ImportError: pass else: - __all__ += ["McapFrameReader"] + __all__ += ["FfmpegFrameReader"] + +try: + from .video.vali_reader import ValiGpuFrameReader +except ImportError: + pass +else: + __all__ += ["ValiGpuFrameReader"] diff --git a/src/rbyte/io/frame/base.py b/src/rbyte/io/frame/base.py index 11cc562..2b2d1e3 100644 --- a/src/rbyte/io/frame/base.py +++ b/src/rbyte/io/frame/base.py @@ -7,5 +7,7 @@ @runtime_checkable class FrameReader(Protocol): - def read(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]: ... + def read( + self, indexes: Iterable[int] + ) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]: ... def get_available_indexes(self) -> Sequence[int]: ... diff --git a/src/rbyte/io/frame/video/__init__.py b/src/rbyte/io/frame/video/__init__.py index 405f40f..5e2a442 100644 --- a/src/rbyte/io/frame/video/__init__.py +++ b/src/rbyte/io/frame/video/__init__.py @@ -1,3 +1,17 @@ -from .reader import VideoFrameReader +__all__: list[str] = [] -__all__ = ["VideoFrameReader"] +try: + from .ffmpeg_reader import FfmpegFrameReader +except ImportError: + pass + +else: + __all__ += ["FfmpegFrameReader"] + +try: + from .vali_reader import ValiGpuFrameReader +except ImportError: + pass + +else: + __all__ += ["ValiGpuFrameReader"] diff --git a/src/rbyte/io/frame/video/reader.py b/src/rbyte/io/frame/video/ffmpeg_reader.py similarity index 91% rename from src/rbyte/io/frame/video/reader.py rename to src/rbyte/io/frame/video/ffmpeg_reader.py index a3ad0c1..c3bb668 100644 --- a/src/rbyte/io/frame/video/reader.py +++ b/src/rbyte/io/frame/video/ffmpeg_reader.py @@ -1,23 +1,22 @@ from collections.abc import Callable, Iterable, Sequence from functools import partial -from os import PathLike from pathlib import Path from typing import override import torch import video_reader as vr from jaxtyping import UInt8 -from pydantic import NonNegativeInt, validate_call +from pydantic import FilePath, NonNegativeInt, validate_call from torch import Tensor from rbyte.io.frame.base import FrameReader -class VideoFrameReader(FrameReader): +class FfmpegFrameReader(FrameReader): @validate_call def __init__( self, - path: PathLike[str], + path: FilePath, threads: NonNegativeInt | None = None, resize_shorter_side: NonNegativeInt | None = None, with_fallback: bool | None = None, # noqa: FBT001 diff --git a/src/rbyte/io/frame/video/vali_reader.py b/src/rbyte/io/frame/video/vali_reader.py new file mode 100644 index 0000000..40a7689 --- /dev/null +++ b/src/rbyte/io/frame/video/vali_reader.py @@ -0,0 +1,109 @@ +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from itertools import pairwise +from typing import Annotated, override + +import more_itertools as mit +import python_vali as vali +import torch +from jaxtyping import Shaped +from pydantic import ( + BeforeValidator, + ConfigDict, + FilePath, + NonNegativeInt, + validate_call, +) +from structlog import get_logger +from torch import Tensor + +from rbyte.io.frame.base import FrameReader + +logger = get_logger(__name__) + + +class ValiGpuFrameReader(FrameReader): + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__( + self, + path: FilePath, + gpu_id: NonNegativeInt = 0, + pixel_format_chain: Annotated[ + Sequence[vali.PixelFormat], + BeforeValidator(lambda v: [getattr(vali.PixelFormat, x) for x in v]), + ] = (vali.PixelFormat.RGB, vali.PixelFormat.RGB_PLANAR), + ) -> None: + super().__init__() + + self._gpu_id = gpu_id + + self._decoder = vali.PyDecoder( + input=path.resolve().as_posix(), opts={}, gpu_id=self._gpu_id + ) + + self._pixel_format_chain = ( + (self._decoder.Format, *pixel_format_chain) + if mit.first(pixel_format_chain, default=None) != self._decoder.Format + else pixel_format_chain + ) + + @cached_property + def _surface_converters( + self, + ) -> Mapping[tuple[vali.PixelFormat, vali.PixelFormat], vali.PySurfaceConverter]: + return { + (src_format, dst_format): vali.PySurfaceConverter( + src_format=src_format, dst_format=dst_format, gpu_id=self._gpu_id + ) + for src_format, dst_format in pairwise(self._pixel_format_chain) + } + + @cached_property + def _surfaces(self) -> Mapping[vali.PixelFormat, vali.Surface]: + return { + pixel_format: vali.Surface.Make( + format=pixel_format, + width=self._decoder.Width, + height=self._decoder.Height, + gpu_id=self._gpu_id, + ) + for pixel_format in self._pixel_format_chain + } + + def _read(self, index: int) -> Shaped[Tensor, "c h w"] | Shaped[Tensor, "h w c"]: + seek_ctx = vali.SeekContext(index, vali.SeekMode.EXACT_FRAME) + success, details = self._decoder.DecodeSingleSurface( # pyright: ignore[reportUnknownMemberType] + self._surfaces[self._decoder.Format], seek_ctx + ) + if not success: + logger.error(msg := "failed to decode surface", details=details) + + raise RuntimeError(msg) + + for (src_format, dst_format), converter in self._surface_converters.items(): + success, details = converter.Run( # pyright: ignore[reportUnknownMemberType] + (src := self._surfaces[src_format]), (dst := self._surfaces[dst_format]) + ) + if not success: + logger.error( + msg := "failed to convert surface", + src=src, + dst=dst, + details=details, + ) + + raise RuntimeError(msg) + + surface = self._surfaces[self._pixel_format_chain[-1]] + + return torch.from_dlpack(surface).clone().detach() # pyright: ignore[reportPrivateImportUsage] + + @override + def read( + self, indexes: Iterable[int] + ) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]: + return torch.stack([self._read(index) for index in indexes]) + + @override + def get_available_indexes(self) -> Sequence[int]: + return range(self._decoder.NumFrames) diff --git a/src/rbyte/io/table/yaak/idl-repo b/src/rbyte/io/table/yaak/idl-repo index 7247555..ec4132c 160000 --- a/src/rbyte/io/table/yaak/idl-repo +++ b/src/rbyte/io/table/yaak/idl-repo @@ -1 +1 @@ -Subproject commit 7247555a0bbfb98dbafa91766511773cb26141ad +Subproject commit ec4132c834d22b790c8160f4aa4ce6a7295f87e3 diff --git a/src/rbyte/scripts/read_frames.py b/src/rbyte/scripts/read_frames.py index dec71c2..504398c 100644 --- a/src/rbyte/scripts/read_frames.py +++ b/src/rbyte/scripts/read_frames.py @@ -1,23 +1,53 @@ -from typing import cast +from collections.abc import Mapping +from typing import Annotated import hydra import more_itertools as mit import rerun as rr import torch -from hydra.utils import instantiate -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf +from pydantic import BeforeValidator, NonNegativeInt from structlog import get_logger from tqdm import tqdm +from rbyte.config import BaseModel, HydraConfig from rbyte.io.frame.base import FrameReader logger = get_logger(__name__) +TORCH_TO_RERUN_DTYPE: Mapping[torch.dtype, rr.ChannelDatatype] = { + torch.uint8: rr.ChannelDatatype.U8, + torch.uint16: rr.ChannelDatatype.U16, + torch.uint32: rr.ChannelDatatype.U32, + torch.uint64: rr.ChannelDatatype.U64, + torch.float16: rr.ChannelDatatype.F16, + torch.float32: rr.ChannelDatatype.F32, + torch.float64: rr.ChannelDatatype.F64, +} + + +class Config(BaseModel): + frame_reader: HydraConfig[FrameReader] + batch_size: NonNegativeInt = 1 + application_id: str = "rbyte-read-frames" + entity_path: str = "frames" + pixel_format: ( + Annotated[rr.PixelFormat, BeforeValidator(rr.PixelFormat.auto)] | None + ) = None + color_model: ( + Annotated[rr.ColorModel, BeforeValidator(rr.ColorModel.auto)] | None + ) = None + + @hydra.main(version_base=None) -def main(config: DictConfig) -> None: - frame_reader = cast(FrameReader, instantiate(config.frame_reader)) - rr.init("rbyte", spawn=True) +def main(_config: DictConfig) -> None: + config = Config.model_validate(OmegaConf.to_object(_config)) + + frame_reader = config.frame_reader.instantiate() + + rr.init(config.application_id, spawn=True) + rr.log(config.entity_path, [rr.Image.indicator()], static=True, strict=True) for frame_indexes in mit.chunked( tqdm(sorted(frame_reader.get_available_indexes())), @@ -25,37 +55,35 @@ def main(config: DictConfig) -> None: strict=False, ): frames = frame_reader.read(frame_indexes) - match frames.shape, frames.dtype: - case ((_, height, width, 3), torch.uint8): - rr.log( - config.entity_path, - [ - rr.components.ImageFormat( - height=height, - width=width, - color_model="RGB", - channel_datatype="U8", - ), - rr.Image.indicator(), - ], - static=True, - strict=True, + + match (config.pixel_format, config.color_model, frames.shape): + case None, rr.ColorModel.RGB, (_, height, width, 3) | (_, 3, height, width): + image_format = rr.components.ImageFormat( + width=width, + height=height, + color_model=rr.ColorModel.RGB, + channel_datatype=TORCH_TO_RERUN_DTYPE[frames.dtype], ) - rr.send_columns( - config.entity_path, - times=[rr.TimeSequenceColumn("frame_index", frame_indexes)], - components=[ - rr.components.ImageBufferBatch( - frames.flatten(start_dim=1, end_dim=-1).numpy() # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] - ) - ], - strict=True, + case rr.PixelFormat.NV12, None, (_, dim, width): + image_format = rr.components.ImageFormat( + width=width, height=int(dim / 1.5), pixel_format=rr.PixelFormat.NV12 ) case _: raise NotImplementedError + rr.send_columns( + config.entity_path, + times=[rr.TimeSequenceColumn("frame_index", frame_indexes)], + components=[ + rr.components.ImageFormatBatch([image_format] * len(frame_indexes)), + rr.components.ImageBufferBatch( + frames.flatten(1, -1).cpu().numpy() # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] + ), + ], + ) + if __name__ == "__main__": main()