From 081a5bb629e6624cd772ceea3354bbca253d0ce4 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 20 Dec 2024 10:16:06 -0800 Subject: [PATCH] Updated metadata; all tests pass. --- .../decoders/_core/VideoDecoder.cpp | 72 ++++++++++++------- src/torchcodec/decoders/_core/VideoDecoder.h | 4 +- src/torchcodec/decoders/_core/_metadata.py | 45 +++++++++--- src/torchcodec/decoders/_video_decoder.py | 55 +++++--------- test/decoders/test_metadata.py | 4 +- test/samplers/test_samplers.py | 19 +---- 6 files changed, 108 insertions(+), 91 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 86ee7286..2a3b3284 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1170,7 +1170,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( int64_t pts = getPts(streamInfo, streamMetadata, frameIndex); setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase)); - return getNextFrameOutputNoDemuxInternal(preAllocatedOutputTensor); + return getNextFrameNoDemuxInternal(preAllocatedOutputTensor); } VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( @@ -1252,19 +1252,19 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( std::vector frameIndices(timestamps.size()); for (auto i = 0; i < timestamps.size(); ++i) { - auto framePts = timestamps[i]; + auto frameSeconds = timestamps[i]; TORCH_CHECK( - framePts >= minSeconds && framePts < maxSeconds, - "frame pts is " + std::to_string(framePts) + "; must be in range [" + - std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + - ")."); + frameSeconds >= minSeconds && frameSeconds < maxSeconds, + "frame pts is " + std::to_string(frameSeconds) + + "; must be in range [" + std::to_string(minSeconds) + ", " + + std::to_string(maxSeconds) + ")."); auto it = std::lower_bound( stream.allFrames.begin(), stream.allFrames.end(), - framePts, - 
[&stream](const FrameInfo& info, double framePts) { - return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; + frameSeconds, + [&stream](const FrameInfo& info, double frameSeconds) { + return ptsToSeconds(info.nextPts, stream.timeBase) <= frameSeconds; }); int64_t frameIndex = it - stream.allFrames.begin(); frameIndices[i] = frameIndex; @@ -1284,15 +1284,15 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( BatchDecodedOutput output(timestamps.size(), options, streamMetadata); for (auto i = 0; i < timestamps.size(); ++i) { - auto framePts = timestamps[i]; + auto frameSeconds = timestamps[i]; TORCH_CHECK( - framePts >= minSeconds && framePts < maxSeconds, - "frame pts is " + std::to_string(framePts) + "; must be in range [" + - std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + - ")."); + frameSeconds >= minSeconds && frameSeconds < maxSeconds, + "frame pts is " + std::to_string(frameSeconds) + + "; must be in range [" + std::to_string(minSeconds) + ", " + + std::to_string(maxSeconds) + ")."); - DecodedOutput singleOut = - getFramePlayedAtTimestampNoDemuxInternal(framePts, output.frames[i]); + DecodedOutput singleOut = getFramePlayedAtTimestampNoDemuxInternal( + frameSeconds, output.frames[i]); output.ptsSeconds[i] = singleOut.ptsSeconds; output.durationSeconds[i] = singleOut.durationSeconds; } @@ -1462,21 +1462,43 @@ VideoDecoder::getFramesPlayedByTimestampInRange( // after the fact. We can't preallocate the final tensor because we don't // know how many frames we're going to decode up front. 
-    setCursorPtsInSeconds(startSeconds);
-    DecodedOutput singleOut = getNextFrameNoDemux();
+    DecodedOutput singleOut =
+        getFramePlayedAtTimestampNoDemuxInternal(startSeconds);
 
-    std::vector<torch::Tensor> frames = {singleOut.frame};
-    std::vector<double> ptsSeconds = {singleOut.ptsSeconds};
-    std::vector<double> durationSeconds = {singleOut.durationSeconds};
+    std::vector<torch::Tensor> frames;
+    std::vector<double> ptsSeconds;
+    std::vector<double> durationSeconds;
 
-    while (singleOut.ptsSeconds < stopSeconds) {
-      singleOut = getNextFrameNoDemux();
+    // Note that we only know we've decoded all frames in the range when we have
+    // decoded the first frame outside of the range. That is, we have to decode
+    // one frame past where we want to stop, and conclude from its pts that all
+    // of the prior frames comprise our range. That means we decode one extra
+    // frame; we don't return it, but we decode it.
+    //
+    // This algorithm works fine except when stopSeconds is the duration of the
+    // video. In that case, we're going to hit the end-of-file exception.
+    //
+    // We could avoid decoding an extra frame, and the end-of-file exception, by
+    // using the currently decoded frame's duration to know that the next frame
+    // is outside of our range. This would be more efficient. However, up until
+    // now we have avoided relying on a frame's duration to determine if a frame
+    // is played during a time window. So there is a potential TODO here where
+    // we relax that principle and just do the math to avoid the extra decode.
+ bool eof = false; + while (singleOut.ptsSeconds < stopSeconds && !eof) { frames.push_back(singleOut.frame); ptsSeconds.push_back(singleOut.ptsSeconds); durationSeconds.push_back(singleOut.durationSeconds); + + try { + singleOut = getNextFrameNoDemuxInternal(); + } catch (EndOfFileException e) { + eof = true; + } } BatchDecodedOutput output(frames, ptsSeconds, durationSeconds); + output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); return output; } else { @@ -1494,12 +1516,12 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { } VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() { - auto output = getNextFrameOutputNoDemuxInternal(); + auto output = getNextFrameNoDemuxInternal(); output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame); return output; } -VideoDecoder::DecodedOutput VideoDecoder::getNextFrameOutputNoDemuxInternal( +VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemuxInternal( std::optional preAllocatedOutputTensor) { auto rawOutput = getNextRawDecodedOutputNoDemux(); return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 2ecc28ac..5726b8c5 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -161,7 +161,7 @@ class VideoDecoder { // ---- SINGLE FRAME SEEK AND DECODING API ---- // Places the cursor at the first frame on or after the position in seconds. - // Calling getNextFrameOutputNoDemuxInternal() will return the first frame at + // Calling getNextFrameNoDemuxInternal() will return the first frame at // or after this position. 
void setCursorPtsInSeconds(double seconds); // This is an internal structure that is used to store the decoded output @@ -433,7 +433,7 @@ class VideoDecoder { DecodedOutput& output, std::optional preAllocatedOutputTensor = std::nullopt); - DecodedOutput getNextFrameOutputNoDemuxInternal( + DecodedOutput getNextFrameNoDemuxInternal( std::optional preAllocatedOutputTensor = std::nullopt); SeekMode seekMode_; diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py index e0400fdf..424d3dbc 100644 --- a/src/torchcodec/decoders/_core/_metadata.py +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -37,12 +37,12 @@ class VideoStreamMetadata: content (the scan doesn't involve decoding). This is more accurate than ``num_frames_from_header``. We recommend using the ``num_frames`` attribute instead. (int or None).""" - begin_stream_seconds: Optional[float] + begin_stream_seconds_from_content: Optional[float] """Beginning of the stream, in seconds (float or None). Conceptually, this corresponds to the first frame's :term:`pts`. It is computed as min(frame.pts) across all frames in the stream. Usually, this is equal to 0.""" - end_stream_seconds: Optional[float] + end_stream_seconds_from_content: Optional[float] """End of the stream, in seconds (float or None). Conceptually, this corresponds to last_frame.pts + last_frame.duration. It is computed as max(frame.pts + frame.duration) across all frames in the @@ -81,9 +81,15 @@ def duration_seconds(self) -> Optional[float]: from the actual frames if a :term:`scan` was performed. Otherwise we fall back to ``duration_seconds_from_header``. 
""" - if self.end_stream_seconds is None or self.begin_stream_seconds is None: + if ( + self.end_stream_seconds_from_content is None + or self.begin_stream_seconds_from_content is None + ): return self.duration_seconds_from_header - return self.end_stream_seconds - self.begin_stream_seconds + return ( + self.end_stream_seconds_from_content + - self.begin_stream_seconds_from_content + ) @property def average_fps(self) -> Optional[float]: @@ -92,12 +98,29 @@ def average_fps(self) -> Optional[float]: Otherwise we fall back to ``average_fps_from_header``. """ if ( - self.end_stream_seconds is None - or self.begin_stream_seconds is None + self.end_stream_seconds_from_content is None + or self.begin_stream_seconds_from_content is None or self.num_frames is None ): return self.average_fps_from_header - return self.num_frames / (self.end_stream_seconds - self.begin_stream_seconds) + return self.num_frames / ( + self.end_stream_seconds_from_content + - self.begin_stream_seconds_from_content + ) + + @property + def begin_stream_seconds(self) -> float: + """TODO.""" + if self.begin_stream_seconds_from_content is None: + return 0 + return self.begin_stream_seconds_from_content + + @property + def end_stream_seconds(self) -> Optional[float]: + """TODO.""" + if self.end_stream_seconds_from_content is None: + return self.duration_seconds + return self.end_stream_seconds_from_content def __repr__(self): # Overridden because properites are not printed by default. 
@@ -152,8 +175,12 @@ def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata: bit_rate=stream_dict.get("bitRate"), num_frames_from_header=stream_dict.get("numFrames"), num_frames_from_content=stream_dict.get("numFramesFromScan"), - begin_stream_seconds=stream_dict.get("minPtsSecondsFromScan"), - end_stream_seconds=stream_dict.get("maxPtsSecondsFromScan"), + begin_stream_seconds_from_content=stream_dict.get( + "minPtsSecondsFromScan" + ), + end_stream_seconds_from_content=stream_dict.get( + "maxPtsSecondsFromScan" + ), codec=stream_dict.get("codec"), width=stream_dict.get("width"), height=stream_dict.get("height"), diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 2f36d06d..a0561423 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -116,44 +116,25 @@ def __init__( self._decoder, stream_index ) - if seek_mode == "exact": - if self.metadata.num_frames_from_content is None: - raise ValueError( - "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS - ) - self._num_frames = self.metadata.num_frames_from_content - - if self.metadata.begin_stream_seconds is None: - raise ValueError( - "The minimum pts value in seconds is unknown. " - + _ERROR_REPORTING_INSTRUCTIONS - ) - self._begin_stream_seconds = self.metadata.begin_stream_seconds - - if self.metadata.end_stream_seconds is None: - raise ValueError( - "The maximum pts value in seconds is unknown. " - + _ERROR_REPORTING_INSTRUCTIONS - ) - self._end_stream_seconds = self.metadata.end_stream_seconds - elif seek_mode == "approximate": - if self.metadata.num_frames_from_header is None: - raise ValueError( - "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS - ) - self._num_frames = self.metadata.num_frames_from_header - - self._begin_stream_seconds = 0 - - if self.metadata.duration_seconds_from_header is None: - raise ValueError( - "The maximum pts value in seconds is unknown. 
" - + _ERROR_REPORTING_INSTRUCTIONS - ) - self._end_stream_seconds = self.metadata.duration_seconds_from_header + if self.metadata.num_frames is None: + raise ValueError( + "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS + ) + self._num_frames = self.metadata.num_frames - else: - raise ValueError(f"Invalid seek mode: {seek_mode}.") + if self.metadata.begin_stream_seconds is None: + raise ValueError( + "The minimum pts value in seconds is unknown. " + + _ERROR_REPORTING_INSTRUCTIONS + ) + self._begin_stream_seconds = self.metadata.begin_stream_seconds + + if self.metadata.end_stream_seconds is None: + raise ValueError( + "The maximum pts value in seconds is unknown. " + + _ERROR_REPORTING_INSTRUCTIONS + ) + self._end_stream_seconds = self.metadata.end_stream_seconds def __len__(self) -> int: return self._num_frames diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index 83505b66..8a6830ad 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -92,8 +92,8 @@ def test_num_frames_fallback( bit_rate=123, num_frames_from_header=num_frames_from_header, num_frames_from_content=num_frames_from_content, - begin_stream_seconds=0, - end_stream_seconds=4, + begin_stream_seconds_from_content=0, + end_stream_seconds_from_content=4, codec="whatever", width=123, height=321, diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index d5c7eb44..94225574 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -590,23 +590,10 @@ def restore_metadata(): decoder.metadata = original_metadata with restore_metadata(): - decoder.metadata.begin_stream_seconds = None - with pytest.raises( - ValueError, match="Could not infer stream end and start from video metadata" - ): - sampler(decoder) - - with restore_metadata(): - decoder.metadata.end_stream_seconds = None - with pytest.raises( - ValueError, match="Could not infer stream end and start from video metadata" - ): - 
sampler(decoder) - - with restore_metadata(): - decoder.metadata.begin_stream_seconds = None - decoder.metadata.end_stream_seconds = None + decoder.metadata.begin_stream_seconds_from_content = None + decoder.metadata.end_stream_seconds_from_content = None decoder.metadata.average_fps_from_header = None + decoder.metadata.duration_seconds_from_header = None with pytest.raises(ValueError, match="Could not infer average fps"): sampler(decoder)