Skip to content

Commit

Permalink
feat: add VALI-based (GPU) video reader
Browse files Browse the repository at this point in the history
  • Loading branch information
egorchakov committed Oct 1, 2024
1 parent f2c0ff8 commit 11d2cae
Show file tree
Hide file tree
Showing 14 changed files with 214 additions and 50 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ repos:
- id: validate-pyproject

- repo: https://github.com/crate-ci/typos
rev: v1.24.6
rev: v1.25.0
hooks:
- id: typos

Expand All @@ -31,7 +31,7 @@ repos:
- id: ruff-format

- repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror
rev: 1.18.0
rev: 1.18.2
hooks:
- id: basedpyright

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
_target_: rbyte.io.frame.VideoFrameReader
_target_: rbyte.io.frame.FfmpegFrameReader
path: ???
threads: !!null
resize_shorter_side: !!null
Expand Down
4 changes: 4 additions & 0 deletions examples/config_templates/frame_reader/video/vali.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
_target_: rbyte.io.frame.ValiGpuFrameReader
path: ???
pixel_format_chain: [NV12]
3 changes: 0 additions & 3 deletions examples/config_templates/read_frames.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ defaults:
- frame_reader: !!null
- _self_

batch_size: 1
entity_path: ???

hydra:
output_subdir: !!null
run:
Expand Down
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ read-frames *ARGS: generate-example-config

# rerun server and viewer
rerun bind="0.0.0.0" port="9876" ws-server-port="9877" web-viewer-port="9090":
uv run rerun \
RUST_LOG=debug uv run rerun \
--bind {{ bind }} \
--port {{ port }} \
--ws-server-port {{ ws-server-port }} \
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ mcap = [
]
yaak = ["protobuf", "ptars>=0.0.2rc2"]
jpeg = ["simplejpeg>=1.7.6"]
video = ["video-reader-rs>=0.1.4"]
video = ["python-vali>=4.1.4.post0", "video-reader-rs>=0.1.4"]

[project.scripts]
rbyte-build-table = 'rbyte.scripts.build_table:main'
Expand Down
3 changes: 3 additions & 0 deletions src/rbyte/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base import BaseModel, HydraConfig

__all__ = ["BaseModel", "HydraConfig"]
16 changes: 12 additions & 4 deletions src/rbyte/io/frame/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,24 @@

__all__ = ["DirectoryFrameReader"]


try:
from .video import VideoFrameReader
from .mcap import McapFrameReader
except ImportError:
pass
else:
__all__ += ["VideoFrameReader"]
__all__ += ["McapFrameReader"]

try:
from .mcap import McapFrameReader
from .video.ffmpeg_reader import FfmpegFrameReader
except ImportError:
pass
else:
__all__ += ["McapFrameReader"]
__all__ += ["FfmpegFrameReader"]

try:
from .video.vali_reader import ValiGpuFrameReader
except ImportError:
pass
else:
__all__ += ["ValiGpuFrameReader"]
4 changes: 3 additions & 1 deletion src/rbyte/io/frame/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@

@runtime_checkable
class FrameReader(Protocol):
def read(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]: ...
def read(
self, indexes: Iterable[int]
) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]: ...
def get_available_indexes(self) -> Sequence[int]: ...
18 changes: 16 additions & 2 deletions src/rbyte/io/frame/video/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
from .reader import VideoFrameReader
__all__: list[str] = []

__all__ = ["VideoFrameReader"]
try:
from .ffmpeg_reader import FfmpegFrameReader
except ImportError:
pass

else:
__all__ += ["FfmpegFrameReader"]

try:
from .vali_reader import ValiGpuFrameReader
except ImportError:
pass

else:
__all__ += ["ValiGpuFrameReader"]
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
from collections.abc import Callable, Iterable, Sequence
from functools import partial
from os import PathLike
from pathlib import Path
from typing import override

import torch
import video_reader as vr
from jaxtyping import UInt8
from pydantic import NonNegativeInt, validate_call
from pydantic import FilePath, NonNegativeInt, validate_call
from torch import Tensor

from rbyte.io.frame.base import FrameReader


class VideoFrameReader(FrameReader):
class FfmpegFrameReader(FrameReader):
@validate_call
def __init__(
self,
path: PathLike[str],
path: FilePath,
threads: NonNegativeInt | None = None,
resize_shorter_side: NonNegativeInt | None = None,
with_fallback: bool | None = None, # noqa: FBT001
Expand Down
109 changes: 109 additions & 0 deletions src/rbyte/io/frame/video/vali_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from itertools import pairwise
from typing import Annotated, override

import more_itertools as mit
import python_vali as vali
import torch
from jaxtyping import Shaped
from pydantic import (
BeforeValidator,
ConfigDict,
FilePath,
NonNegativeInt,
validate_call,
)
from structlog import get_logger
from torch import Tensor

from rbyte.io.frame.base import FrameReader

logger = get_logger(__name__)


class ValiGpuFrameReader(FrameReader):
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def __init__(
self,
path: FilePath,
gpu_id: NonNegativeInt = 0,
pixel_format_chain: Annotated[
Sequence[vali.PixelFormat],
BeforeValidator(lambda v: [getattr(vali.PixelFormat, x) for x in v]),
] = (vali.PixelFormat.RGB, vali.PixelFormat.RGB_PLANAR),
) -> None:
super().__init__()

self._gpu_id = gpu_id

self._decoder = vali.PyDecoder(
input=path.resolve().as_posix(), opts={}, gpu_id=self._gpu_id
)

self._pixel_format_chain = (
(self._decoder.Format, *pixel_format_chain)
if mit.first(pixel_format_chain, default=None) != self._decoder.Format
else pixel_format_chain
)

@cached_property
def _surface_converters(
self,
) -> Mapping[tuple[vali.PixelFormat, vali.PixelFormat], vali.PySurfaceConverter]:
return {
(src_format, dst_format): vali.PySurfaceConverter(
src_format=src_format, dst_format=dst_format, gpu_id=self._gpu_id
)
for src_format, dst_format in pairwise(self._pixel_format_chain)
}

@cached_property
def _surfaces(self) -> Mapping[vali.PixelFormat, vali.Surface]:
return {
pixel_format: vali.Surface.Make(
format=pixel_format,
width=self._decoder.Width,
height=self._decoder.Height,
gpu_id=self._gpu_id,
)
for pixel_format in self._pixel_format_chain
}

def _read(self, index: int) -> Shaped[Tensor, "c h w"] | Shaped[Tensor, "h w c"]:
seek_ctx = vali.SeekContext(index, vali.SeekMode.EXACT_FRAME)
success, details = self._decoder.DecodeSingleSurface( # pyright: ignore[reportUnknownMemberType]
self._surfaces[self._decoder.Format], seek_ctx
)
if not success:
logger.error(msg := "failed to decode surface", details=details)

raise RuntimeError(msg)

for (src_format, dst_format), converter in self._surface_converters.items():
success, details = converter.Run( # pyright: ignore[reportUnknownMemberType]
(src := self._surfaces[src_format]), (dst := self._surfaces[dst_format])
)
if not success:
logger.error(
msg := "failed to convert surface",
src=src,
dst=dst,
details=details,
)

raise RuntimeError(msg)

surface = self._surfaces[self._pixel_format_chain[-1]]

return torch.from_dlpack(surface).clone().detach() # pyright: ignore[reportPrivateImportUsage]

@override
def read(
self, indexes: Iterable[int]
) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]:
return torch.stack([self._read(index) for index in indexes])

@override
def get_available_indexes(self) -> Sequence[int]:
return range(self._decoder.NumFrames)
2 changes: 1 addition & 1 deletion src/rbyte/io/table/yaak/idl-repo
88 changes: 58 additions & 30 deletions src/rbyte/scripts/read_frames.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,89 @@
from typing import cast
from collections.abc import Mapping
from typing import Annotated

import hydra
import more_itertools as mit
import rerun as rr
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf
from pydantic import BeforeValidator, NonNegativeInt
from structlog import get_logger
from tqdm import tqdm

from rbyte.config import BaseModel, HydraConfig
from rbyte.io.frame.base import FrameReader

logger = get_logger(__name__)


TORCH_TO_RERUN_DTYPE: Mapping[torch.dtype, rr.ChannelDatatype] = {
torch.uint8: rr.ChannelDatatype.U8,
torch.uint16: rr.ChannelDatatype.U16,
torch.uint32: rr.ChannelDatatype.U32,
torch.uint64: rr.ChannelDatatype.U64,
torch.float16: rr.ChannelDatatype.F16,
torch.float32: rr.ChannelDatatype.F32,
torch.float64: rr.ChannelDatatype.F64,
}


class Config(BaseModel):
frame_reader: HydraConfig[FrameReader]
batch_size: NonNegativeInt = 1
application_id: str = "rbyte-read-frames"
entity_path: str = "frames"
pixel_format: (
Annotated[rr.PixelFormat, BeforeValidator(rr.PixelFormat.auto)] | None
) = None
color_model: (
Annotated[rr.ColorModel, BeforeValidator(rr.ColorModel.auto)] | None
) = None


@hydra.main(version_base=None)
def main(config: DictConfig) -> None:
frame_reader = cast(FrameReader, instantiate(config.frame_reader))
rr.init("rbyte", spawn=True)
def main(_config: DictConfig) -> None:
config = Config.model_validate(OmegaConf.to_object(_config))

frame_reader = config.frame_reader.instantiate()

rr.init(config.application_id, spawn=True)
rr.log(config.entity_path, [rr.Image.indicator()], static=True, strict=True)

for frame_indexes in mit.chunked(
tqdm(sorted(frame_reader.get_available_indexes())),
config.batch_size,
strict=False,
):
frames = frame_reader.read(frame_indexes)
match frames.shape, frames.dtype:
case ((_, height, width, 3), torch.uint8):
rr.log(
config.entity_path,
[
rr.components.ImageFormat(
height=height,
width=width,
color_model="RGB",
channel_datatype="U8",
),
rr.Image.indicator(),
],
static=True,
strict=True,

match (config.pixel_format, config.color_model, frames.shape):
case None, rr.ColorModel.RGB, (_, height, width, 3) | (_, 3, height, width):
image_format = rr.components.ImageFormat(
width=width,
height=height,
color_model=rr.ColorModel.RGB,
channel_datatype=TORCH_TO_RERUN_DTYPE[frames.dtype],
)

rr.send_columns(
config.entity_path,
times=[rr.TimeSequenceColumn("frame_index", frame_indexes)],
components=[
rr.components.ImageBufferBatch(
frames.flatten(start_dim=1, end_dim=-1).numpy() # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType]
)
],
strict=True,
case rr.PixelFormat.NV12, None, (_, dim, width):
image_format = rr.components.ImageFormat(
width=width, height=int(dim / 1.5), pixel_format=rr.PixelFormat.NV12
)

case _:
raise NotImplementedError

rr.send_columns(
config.entity_path,
times=[rr.TimeSequenceColumn("frame_index", frame_indexes)],
components=[
rr.components.ImageFormatBatch([image_format] * len(frame_indexes)),
rr.components.ImageBufferBatch(
frames.flatten(1, -1).cpu().numpy() # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType]
),
],
)


if __name__ == "__main__":
main()

0 comments on commit 11d2cae

Please sign in to comment.