
Commit 081a5bb
Updated metadata; all tests pass.
scotts committed Dec 20, 2024
1 parent b349282 commit 081a5bb
Showing 6 changed files with 108 additions and 91 deletions.
72 changes: 47 additions & 25 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -1170,7 +1170,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(

int64_t pts = getPts(streamInfo, streamMetadata, frameIndex);
setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase));
return getNextFrameOutputNoDemuxInternal(preAllocatedOutputTensor);
return getNextFrameNoDemuxInternal(preAllocatedOutputTensor);
}

VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
@@ -1252,19 +1252,19 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps(

std::vector<int64_t> frameIndices(timestamps.size());
for (auto i = 0; i < timestamps.size(); ++i) {
auto framePts = timestamps[i];
auto frameSeconds = timestamps[i];
TORCH_CHECK(
framePts >= minSeconds && framePts < maxSeconds,
"frame pts is " + std::to_string(framePts) + "; must be in range [" +
std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) +
").");
frameSeconds >= minSeconds && frameSeconds < maxSeconds,
"frame pts is " + std::to_string(frameSeconds) +
"; must be in range [" + std::to_string(minSeconds) + ", " +
std::to_string(maxSeconds) + ").");

auto it = std::lower_bound(
stream.allFrames.begin(),
stream.allFrames.end(),
framePts,
[&stream](const FrameInfo& info, double framePts) {
return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts;
frameSeconds,
[&stream](const FrameInfo& info, double frameSeconds) {
return ptsToSeconds(info.nextPts, stream.timeBase) <= frameSeconds;
});
int64_t frameIndex = it - stream.allFrames.begin();
frameIndices[i] = frameIndex;
@@ -1284,15 +1284,15 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps(
BatchDecodedOutput output(timestamps.size(), options, streamMetadata);

for (auto i = 0; i < timestamps.size(); ++i) {
auto framePts = timestamps[i];
auto frameSeconds = timestamps[i];
TORCH_CHECK(
framePts >= minSeconds && framePts < maxSeconds,
"frame pts is " + std::to_string(framePts) + "; must be in range [" +
std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) +
").");
frameSeconds >= minSeconds && frameSeconds < maxSeconds,
"frame pts is " + std::to_string(frameSeconds) +
"; must be in range [" + std::to_string(minSeconds) + ", " +
std::to_string(maxSeconds) + ").");

DecodedOutput singleOut =
getFramePlayedAtTimestampNoDemuxInternal(framePts, output.frames[i]);
DecodedOutput singleOut = getFramePlayedAtTimestampNoDemuxInternal(
frameSeconds, output.frames[i]);
output.ptsSeconds[i] = singleOut.ptsSeconds;
output.durationSeconds[i] = singleOut.durationSeconds;
}
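
The hunk above maps each requested timestamp to a frame index with std::lower_bound over the scanned frame list, using the comparator "nextPts <= frameSeconds": roughly, frame i is the one playing at time t when pts[i] <= t < nextPts[i]. Below is a minimal Python sketch of that lookup, assuming a hypothetical next_pts list (sorted, in seconds, where next_pts[i] is the pts of frame i+1); it is an illustration, not the library's API.

    import bisect

    def frame_index_for_timestamp(next_pts, frame_seconds):
        # Return the first frame whose next_pts is strictly greater than
        # frame_seconds, i.e. the frame that is still playing at that time.
        # This mirrors the std::lower_bound comparator over nextPts.
        return bisect.bisect_right(next_pts, frame_seconds)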
@@ -1462,21 +1462,43 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
// after the fact. We can't preallocate the final tensor because we don't
// know how many frames we're going to decode up front.

setCursorPtsInSeconds(startSeconds);
DecodedOutput singleOut = getNextFrameNoDemux();
DecodedOutput singleOut =
getFramePlayedAtTimestampNoDemuxInternal(startSeconds);

std::vector<torch::Tensor> frames = {singleOut.frame};
std::vector<double> ptsSeconds = {singleOut.ptsSeconds};
std::vector<double> durationSeconds = {singleOut.durationSeconds};
std::vector<torch::Tensor> frames;
std::vector<double> ptsSeconds;
std::vector<double> durationSeconds;

while (singleOut.ptsSeconds < stopSeconds) {
singleOut = getNextFrameNoDemux();
// Note that we only know we've decoded all frames in the range when we have
// decoded the first frame outside of the range. That is, we have to decode
// one frame past where we want to stop, and conclude from its pts that all
// of the prior frames comprise our range. That means we decode one extra
// frame; we don't return it, but we decode it.
//
// This algorithm works fine except when stopSeconds is the duration of the
// video. In that case, we're going to hit the end-of-file exception.
//
// We could avoid decoding an extra frame, and the end-of-file exception, by
// using the currently decoded frame's duration to know that the next frame
// is outside of our range. This would be more efficient. However, up until
// now we have avoided relying on a frame's duration to determine if a frame
// is played during a time window. So there is a potential TODO here where
// we relax that principle and just do the math to avoid the extra decode.
bool eof = false;
while (singleOut.ptsSeconds < stopSeconds && !eof) {
frames.push_back(singleOut.frame);
ptsSeconds.push_back(singleOut.ptsSeconds);
durationSeconds.push_back(singleOut.durationSeconds);

try {
singleOut = getNextFrameNoDemuxInternal();
} catch (EndOfFileException e) {
eof = true;
}
}

BatchDecodedOutput output(frames, ptsSeconds, durationSeconds);
output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
return output;

} else {
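
The comment block in the hunk above explains the range-decoding loop: keep decoding until the first frame whose pts is at or past stopSeconds (that frame is decoded but not returned), and treat end-of-file as a normal stop so that a stopSeconds equal to the stream duration still works. Here is a rough Python sketch of the same control flow, assuming hypothetical decoder helpers get_frame_played_at() and decode_next_frame() (the latter raising EOFError at end of stream); these names are illustrative, not the real API.

    def frames_in_range(decoder, start_seconds, stop_seconds):
        frame = decoder.get_frame_played_at(start_seconds)  # hypothetical API
        frames, pts, durations = [], [], []
        while frame.pts_seconds < stop_seconds:
            frames.append(frame.data)
            pts.append(frame.pts_seconds)
            durations.append(frame.duration_seconds)
            try:
                frame = decoder.decode_next_frame()  # hypothetical API
            except EOFError:
                break  # stop_seconds was the stream duration; nothing left to decode
        return frames, pts, durations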
@@ -1494,12 +1516,12 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
}

VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
auto output = getNextFrameOutputNoDemuxInternal();
auto output = getNextFrameNoDemuxInternal();
output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame);
return output;
}

VideoDecoder::DecodedOutput VideoDecoder::getNextFrameOutputNoDemuxInternal(
VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemuxInternal(
std::optional<torch::Tensor> preAllocatedOutputTensor) {
auto rawOutput = getNextRawDecodedOutputNoDemux();
return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
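
The last hunk keeps the public/internal split: getNextFrameNoDemuxInternal() leaves the frame in the decoder's native HWC layout (presumably so batch callers can decode straight into a preallocated output tensor), and only the public getNextFrameNoDemux() applies maybePermuteHWC2CHW. A loose Python analogue of that wrapper pattern, with a hypothetical internal helper and assuming channels-first output was requested:

    def get_next_frame(decoder):
        # Hypothetical internal call: frame tensor in HWC layout plus timing info.
        frame_hwc, pts_seconds, duration_seconds = decoder._get_next_frame_internal()
        # The public wrapper converts to channels-first (CHW), mirroring the permute above.
        return frame_hwc.permute(2, 0, 1), pts_seconds, duration_seconds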
4 changes: 2 additions & 2 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -161,7 +161,7 @@ class VideoDecoder {

// ---- SINGLE FRAME SEEK AND DECODING API ----
// Places the cursor at the first frame on or after the position in seconds.
// Calling getNextFrameOutputNoDemuxInternal() will return the first frame at
// Calling getNextFrameNoDemuxInternal() will return the first frame at
// or after this position.
void setCursorPtsInSeconds(double seconds);
// This is an internal structure that is used to store the decoded output
@@ -433,7 +433,7 @@ class VideoDecoder {
DecodedOutput& output,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

DecodedOutput getNextFrameOutputNoDemuxInternal(
DecodedOutput getNextFrameNoDemuxInternal(
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

SeekMode seekMode_;
45 changes: 36 additions & 9 deletions src/torchcodec/decoders/_core/_metadata.py
@@ -37,12 +37,12 @@ class VideoStreamMetadata:
content (the scan doesn't involve decoding). This is more accurate
than ``num_frames_from_header``. We recommend using the
``num_frames`` attribute instead. (int or None)."""
begin_stream_seconds: Optional[float]
begin_stream_seconds_from_content: Optional[float]
"""Beginning of the stream, in seconds (float or None).
Conceptually, this corresponds to the first frame's :term:`pts`. It is
computed as min(frame.pts) across all frames in the stream. Usually, this is
equal to 0."""
end_stream_seconds: Optional[float]
end_stream_seconds_from_content: Optional[float]
"""End of the stream, in seconds (float or None).
Conceptually, this corresponds to last_frame.pts + last_frame.duration. It
is computed as max(frame.pts + frame.duration) across all frames in the
@@ -81,9 +81,15 @@ def duration_seconds(self) -> Optional[float]:
from the actual frames if a :term:`scan` was performed. Otherwise we
fall back to ``duration_seconds_from_header``.
"""
if self.end_stream_seconds is None or self.begin_stream_seconds is None:
if (
self.end_stream_seconds_from_content is None
or self.begin_stream_seconds_from_content is None
):
return self.duration_seconds_from_header
return self.end_stream_seconds - self.begin_stream_seconds
return (
self.end_stream_seconds_from_content
- self.begin_stream_seconds_from_content
)

@property
def average_fps(self) -> Optional[float]:
@@ -92,12 +98,29 @@ def average_fps(self) -> Optional[float]:
Otherwise we fall back to ``average_fps_from_header``.
"""
if (
self.end_stream_seconds is None
or self.begin_stream_seconds is None
self.end_stream_seconds_from_content is None
or self.begin_stream_seconds_from_content is None
or self.num_frames is None
):
return self.average_fps_from_header
return self.num_frames / (self.end_stream_seconds - self.begin_stream_seconds)
return self.num_frames / (
self.end_stream_seconds_from_content
- self.begin_stream_seconds_from_content
)

@property
def begin_stream_seconds(self) -> float:
"""TODO."""
if self.begin_stream_seconds_from_content is None:
return 0
return self.begin_stream_seconds_from_content

@property
def end_stream_seconds(self) -> Optional[float]:
"""TODO."""
if self.end_stream_seconds_from_content is None:
return self.duration_seconds
return self.end_stream_seconds_from_content

def __repr__(self):
# Overridden because properties are not printed by default.
@@ -152,8 +175,12 @@ def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata:
bit_rate=stream_dict.get("bitRate"),
num_frames_from_header=stream_dict.get("numFrames"),
num_frames_from_content=stream_dict.get("numFramesFromScan"),
begin_stream_seconds=stream_dict.get("minPtsSecondsFromScan"),
end_stream_seconds=stream_dict.get("maxPtsSecondsFromScan"),
begin_stream_seconds_from_content=stream_dict.get(
"minPtsSecondsFromScan"
),
end_stream_seconds_from_content=stream_dict.get(
"maxPtsSecondsFromScan"
),
codec=stream_dict.get("codec"),
width=stream_dict.get("width"),
height=stream_dict.get("height"),
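
The new begin_stream_seconds / end_stream_seconds properties, like duration_seconds and average_fps, follow one pattern: prefer the *_from_content values when a scan was performed, otherwise fall back to header-derived values (0 for the stream start, the duration for the stream end). A standalone sketch of that fallback logic, simplified from the real class (it collapses the end-of-stream fallback directly to the header duration):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _StreamMetadataSketch:
        duration_seconds_from_header: Optional[float] = None
        begin_stream_seconds_from_content: Optional[float] = None
        end_stream_seconds_from_content: Optional[float] = None

        @property
        def begin_stream_seconds(self) -> float:
            # No scan: assume the stream starts at 0.
            if self.begin_stream_seconds_from_content is None:
                return 0
            return self.begin_stream_seconds_from_content

        @property
        def end_stream_seconds(self) -> Optional[float]:
            # No scan: fall back to the header duration (which may itself be None).
            if self.end_stream_seconds_from_content is None:
                return self.duration_seconds_from_header
            return self.end_stream_seconds_from_content

    scanned = _StreamMetadataSketch(10.0, 0.04, 10.04)
    assert (scanned.begin_stream_seconds, scanned.end_stream_seconds) == (0.04, 10.04)

    unscanned = _StreamMetadataSketch(duration_seconds_from_header=10.0)
    assert (unscanned.begin_stream_seconds, unscanned.end_stream_seconds) == (0, 10.0)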
55 changes: 18 additions & 37 deletions src/torchcodec/decoders/_video_decoder.py
@@ -116,44 +116,25 @@ def __init__(
self._decoder, stream_index
)

if seek_mode == "exact":
if self.metadata.num_frames_from_content is None:
raise ValueError(
"The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS
)
self._num_frames = self.metadata.num_frames_from_content

if self.metadata.begin_stream_seconds is None:
raise ValueError(
"The minimum pts value in seconds is unknown. "
+ _ERROR_REPORTING_INSTRUCTIONS
)
self._begin_stream_seconds = self.metadata.begin_stream_seconds

if self.metadata.end_stream_seconds is None:
raise ValueError(
"The maximum pts value in seconds is unknown. "
+ _ERROR_REPORTING_INSTRUCTIONS
)
self._end_stream_seconds = self.metadata.end_stream_seconds
elif seek_mode == "approximate":
if self.metadata.num_frames_from_header is None:
raise ValueError(
"The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS
)
self._num_frames = self.metadata.num_frames_from_header

self._begin_stream_seconds = 0

if self.metadata.duration_seconds_from_header is None:
raise ValueError(
"The maximum pts value in seconds is unknown. "
+ _ERROR_REPORTING_INSTRUCTIONS
)
self._end_stream_seconds = self.metadata.duration_seconds_from_header
if self.metadata.num_frames is None:
raise ValueError(
"The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS
)
self._num_frames = self.metadata.num_frames

else:
raise ValueError(f"Invalid seek mode: {seek_mode}.")
if self.metadata.begin_stream_seconds is None:
raise ValueError(
"The minimum pts value in seconds is unknown. "
+ _ERROR_REPORTING_INSTRUCTIONS
)
self._begin_stream_seconds = self.metadata.begin_stream_seconds

if self.metadata.end_stream_seconds is None:
raise ValueError(
"The maximum pts value in seconds is unknown. "
+ _ERROR_REPORTING_INSTRUCTIONS
)
self._end_stream_seconds = self.metadata.end_stream_seconds

def __len__(self) -> int:
return self._num_frames
4 changes: 2 additions & 2 deletions test/decoders/test_metadata.py
@@ -92,8 +92,8 @@ def test_num_frames_fallback(
bit_rate=123,
num_frames_from_header=num_frames_from_header,
num_frames_from_content=num_frames_from_content,
begin_stream_seconds=0,
end_stream_seconds=4,
begin_stream_seconds_from_content=0,
end_stream_seconds_from_content=4,
codec="whatever",
width=123,
height=321,
19 changes: 3 additions & 16 deletions test/samplers/test_samplers.py
@@ -590,23 +590,10 @@ def restore_metadata():
decoder.metadata = original_metadata

with restore_metadata():
decoder.metadata.begin_stream_seconds = None
with pytest.raises(
ValueError, match="Could not infer stream end and start from video metadata"
):
sampler(decoder)

with restore_metadata():
decoder.metadata.end_stream_seconds = None
with pytest.raises(
ValueError, match="Could not infer stream end and start from video metadata"
):
sampler(decoder)

with restore_metadata():
decoder.metadata.begin_stream_seconds = None
decoder.metadata.end_stream_seconds = None
decoder.metadata.begin_stream_seconds_from_content = None
decoder.metadata.end_stream_seconds_from_content = None
decoder.metadata.average_fps_from_header = None
decoder.metadata.duration_seconds_from_header = None
with pytest.raises(ValueError, match="Could not infer average fps"):
sampler(decoder)

