Generalize logic to recreate colorspace conversion objects
scotts committed Dec 16, 2024
1 parent 9653969 commit 6e9fb33
Showing 3 changed files with 96 additions and 76 deletions.
150 changes: 83 additions & 67 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -204,15 +204,16 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
}

bool VideoDecoder::SwsContextKey::operator==(
const VideoDecoder::SwsContextKey& other) {
bool VideoDecoder::DecodedFrameContext::operator==(
const VideoDecoder::DecodedFrameContext& other) {
return decodedWidth == other.decodedWidth &&
decodedHeight == other.decodedHeight &&
decodedFormat == other.decodedFormat &&
outputWidth == other.outputWidth && outputHeight == other.outputHeight;
expectedWidth == other.expectedWidth &&
expectedHeight == other.expectedHeight;
}

bool VideoDecoder::SwsContextKey::operator!=(
const VideoDecoder::SwsContextKey& other) {
bool VideoDecoder::DecodedFrameContext::operator!=(
const VideoDecoder::DecodedFrameContext& other) {
return !(*this == other);
}

@@ -313,17 +314,14 @@ std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer(
return std::unique_ptr<VideoDecoder>(new VideoDecoder(buffer, length));
}

void VideoDecoder::initializeFilterGraph(
void VideoDecoder::createFilterGraph(
StreamInfo& streamInfo,
int expectedOutputHeight,
int expectedOutputWidth) {
FilterState& filterState = streamInfo.filterState;
if (filterState.filterGraph) {
return;
}

filterState.filterGraph.reset(avfilter_graph_alloc());
TORCH_CHECK(filterState.filterGraph.get() != nullptr);

if (streamInfo.options.ffmpegThreadCount.has_value()) {
filterState.filterGraph->nb_threads =
streamInfo.options.ffmpegThreadCount.value();
@@ -921,12 +919,32 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(

torch::Tensor outputTensor;
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
// We need to compare the current frame context with our previous frame
// context. If they are different, then we need to re-create our colorspace
// conversion objects. We create our colorspace conversion objects late so
// that we don't have to depend on the unreliable metadata in the header.
// And we sometimes re-create them because it's possible for frame
// resolution to change mid-stream. Finally, we want to reuse the colorspace
// conversion objects as much as possible for performance reasons.
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(frame->format);
auto frameContext = DecodedFrameContext{
frame->width,
frame->height,
frameFormat,
expectedOutputWidth,
expectedOutputHeight};

if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
expectedOutputHeight, expectedOutputWidth, torch::kCPU));

if (!streamInfo.swsContext || streamInfo.prevFrame != frameContext) {
createSwsContext(streamInfo, frameContext, frame->colorspace);
streamInfo.prevFrame = frameContext;
}
int resultHeight =
convertFrameToBufferUsingSwsScale(streamIndex, frame, outputTensor);
convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
// If this check failed, it would mean that the frame wasn't reshaped to
// the expected height.
// TODO: Can we do the same check for width?
@@ -941,16 +959,11 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
} else if (
streamInfo.colorConversionLibrary ==
ColorConversionLibrary::FILTERGRAPH) {
// Note that this is a lazy init; we initialize filtergraph the first time
// we have a raw decoded frame. We do this lazily because up until this
// point, we really don't know what the resolution of the frames are
// without modification. In theory, we should be able to get that from the
// stream metadata, but in practice, we have encountered videos where the
// stream metadata had a different resolution from the actual resolution
// of the raw decoded frames.
if (!streamInfo.filterState.filterGraph) {
initializeFilterGraph(
if (!streamInfo.filterState.filterGraph ||
streamInfo.prevFrame != frameContext) {
createFilterGraph(
streamInfo, expectedOutputHeight, expectedOutputWidth);
streamInfo.prevFrame = frameContext;
}
outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);

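The caching policy described in the comment above boils down to a small pattern: build a key from everything that can invalidate the conversion object, and re-create the object only when the key changes. A minimal, self-contained sketch of that pattern follows; FrameKey, Converter, and ConverterCache are hypothetical names, not part of this diff.

#include <memory>

// Key describing everything that forces a new conversion object.
struct FrameKey {
  int width = 0;
  int height = 0;
  int format = -1;

  bool operator!=(const FrameKey& other) const {
    return width != other.width || height != other.height ||
        format != other.format;
  }
};

// Stand-in for an expensive-to-create object such as an SwsContext or a
// filtergraph.
struct Converter {
  explicit Converter(const FrameKey&) {}
};

class ConverterCache {
 public:
  // Create lazily on first use; re-create only when the key changes;
  // otherwise reuse the existing object for performance.
  Converter& get(const FrameKey& key) {
    if (!converter_ || key != prevKey_) {
      converter_ = std::make_unique<Converter>(key);
      prevKey_ = key;
    }
    return *converter_;
  }

 private:
  std::unique_ptr<Converter> converter_;
  FrameKey prevKey_;
};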
Expand Down Expand Up @@ -1351,7 +1364,53 @@ double VideoDecoder::getPtsSecondsForFrame(
return ptsToSeconds(stream.allFrames[frameIndex].pts, stream.timeBase);
}

int VideoDecoder::convertFrameToBufferUsingSwsScale(
void VideoDecoder::createSwsContext(
StreamInfo& streamInfo,
const DecodedFrameContext& frameContext,
const enum AVColorSpace colorspace) {
SwsContext* swsContext = sws_getContext(
frameContext.decodedWidth,
frameContext.decodedHeight,
frameContext.decodedFormat,
frameContext.expectedWidth,
frameContext.expectedHeight,
AV_PIX_FMT_RGB24,
SWS_BILINEAR,
nullptr,
nullptr,
nullptr);
TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");

int* invTable = nullptr;
int* table = nullptr;
int srcRange, dstRange, brightness, contrast, saturation;
int ret = sws_getColorspaceDetails(
swsContext,
&invTable,
&srcRange,
&table,
&dstRange,
&brightness,
&contrast,
&saturation);
TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");

const int* colorspaceTable = sws_getCoefficients(colorspace);
ret = sws_setColorspaceDetails(
swsContext,
colorspaceTable,
srcRange,
colorspaceTable,
dstRange,
brightness,
contrast,
saturation);
TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");

streamInfo.swsContext.reset(swsContext);
}

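For context, here is a hedged sketch of how a context created by createSwsContext is typically consumed. convertToRGB24 is a hypothetical helper, but its sws_scale call mirrors the conversion function below, and the returned slice height is what the decoder compares against the expected output height.

extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}

// Convert one decoded AVFrame into a caller-provided packed RGB24 buffer
// (dstWidth * outputHeight * 3 bytes). Returns the height of the written
// slice; callers compare it to the expected height to detect a failed or
// partial conversion.
int convertToRGB24(
    SwsContext* swsContext,
    const AVFrame* frame,
    uint8_t* dst,
    int dstWidth) {
  uint8_t* dstPlanes[4] = {dst, nullptr, nullptr, nullptr};
  int dstStrides[4] = {dstWidth * 3, 0, 0, 0}; // RGB24 packs 3 bytes/pixel.
  return sws_scale(
      swsContext,
      frame->data,
      frame->linesize,
      0,
      frame->height,
      dstPlanes,
      dstStrides);
}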
int VideoDecoder::convertFrameToTensorUsingSwsScale(
int streamIndex,
const AVFrame* frame,
torch::Tensor& outputTensor) {
@@ -1361,50 +1420,6 @@ int VideoDecoder::convertFrameToBufferUsingSwsScale(

int expectedOutputHeight = outputTensor.sizes()[0];
int expectedOutputWidth = outputTensor.sizes()[1];
auto curFrameSwsContextKey = SwsContextKey{
frame->width,
frame->height,
frameFormat,
expectedOutputWidth,
expectedOutputHeight};
if (activeStream.swsContext.get() == nullptr ||
activeStream.swsContextKey != curFrameSwsContextKey) {
SwsContext* swsContext = sws_getContext(
frame->width,
frame->height,
frameFormat,
expectedOutputWidth,
expectedOutputHeight,
AV_PIX_FMT_RGB24,
SWS_BILINEAR,
nullptr,
nullptr,
nullptr);
int* invTable = nullptr;
int* table = nullptr;
int srcRange, dstRange, brightness, contrast, saturation;
sws_getColorspaceDetails(
swsContext,
&invTable,
&srcRange,
&table,
&dstRange,
&brightness,
&contrast,
&saturation);
const int* colorspaceTable = sws_getCoefficients(frame->colorspace);
sws_setColorspaceDetails(
swsContext,
colorspaceTable,
srcRange,
colorspaceTable,
dstRange,
brightness,
contrast,
saturation);
activeStream.swsContextKey = curFrameSwsContextKey;
activeStream.swsContext.reset(swsContext);
}
SwsContext* swsContext = activeStream.swsContext.get();
uint8_t* pointers[4] = {
outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
Expand All @@ -1428,10 +1443,12 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
if (ffmpegStatus < AVSUCCESS) {
throw std::runtime_error("Failed to add frame to buffer source context");
}

UniqueAVFrame filteredFrame(av_frame_alloc());
ffmpegStatus =
av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);

auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredFrame.get());
int height = frameDims.height;
int width = frameDims.width;
@@ -1441,9 +1458,8 @@
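// The deleter transfers ownership of the filtered AVFrame to the tensor:
// the frame is freed only when the tensor's storage is released.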
auto deleter = [filteredFramePtr](void*) {
UniqueAVFrame frameToDelete(filteredFramePtr);
};
torch::Tensor tensor = torch::from_blob(
return torch::from_blob(
filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
return tensor;
}

VideoDecoder::~VideoDecoder() {
20 changes: 12 additions & 8 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -313,14 +313,14 @@ class VideoDecoder {
AVFilterContext* sourceContext = nullptr;
AVFilterContext* sinkContext = nullptr;
};
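// Cache key describing a decoded frame and the requested output: colorspace
// conversion objects are re-created whenever this context changes.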
struct SwsContextKey {
struct DecodedFrameContext {
int decodedWidth;
int decodedHeight;
AVPixelFormat decodedFormat;
int outputWidth;
int outputHeight;
bool operator==(const SwsContextKey&);
bool operator!=(const SwsContextKey&);
int expectedWidth;
int expectedHeight;
bool operator==(const DecodedFrameContext&);
bool operator!=(const DecodedFrameContext&);
};
// Stores information for each stream.
struct StreamInfo {
@@ -342,7 +342,7 @@ class VideoDecoder {
ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
std::vector<FrameInfo> keyFrames;
std::vector<FrameInfo> allFrames;
SwsContextKey swsContextKey;
DecodedFrameContext prevFrame;
UniqueSwsContext swsContext;
};
// Returns the key frame index of the presentation timestamp using FFMPEG's
@@ -371,10 +371,14 @@ class VideoDecoder {
void validateFrameIndex(const StreamInfo& stream, int64_t frameIndex);
// Creates and initializes a filter graph for a stream. The filter graph can
// do rescaling and color conversion.
void initializeFilterGraph(
void createFilterGraph(
StreamInfo& streamInfo,
int expectedOutputHeight,
int expectedOutputWidth);
void createSwsContext(
StreamInfo& streamInfo,
const DecodedFrameContext& frameContext,
const enum AVColorSpace colorspace);
void maybeSeekToBeforeDesiredPts();
RawDecodedOutput getDecodedOutputWithFilter(
std::function<bool(int, AVFrame*)>);
@@ -389,7 +393,7 @@ class VideoDecoder {
torch::Tensor convertFrameToTensorUsingFilterGraph(
int streamIndex,
const AVFrame* frame);
int convertFrameToBufferUsingSwsScale(
int convertFrameToTensorUsingSwsScale(
int streamIndex,
const AVFrame* frame,
torch::Tensor& outputTensor);
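createFilterGraph is only declared here. As a point of reference, the usual FFmpeg shape of such a function is a buffer source feeding a scale/format chain into a buffersink. The sketch below is illustrative only — buildGraph and its parameters are assumptions, with error handling elided — not the implementation in this commit.

extern "C" {
#include <libavfilter/avfilter.h>
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
}
#include <cstdio>

// Build: buffer (decoded frames in) -> scale -> format=rgb24 -> buffersink.
AVFilterGraph* buildGraph(
    int inWidth,
    int inHeight,
    AVPixelFormat inFormat,
    AVRational timeBase,
    AVRational sampleAspectRatio,
    int outWidth,
    int outHeight,
    AVFilterContext** srcContext,
    AVFilterContext** sinkContext) {
  AVFilterGraph* graph = avfilter_graph_alloc();

  // The buffer source must be told the properties of incoming frames.
  char srcArgs[512];
  std::snprintf(
      srcArgs,
      sizeof(srcArgs),
      "video_size=%dx%d:pix_fmt=%d:time_base=%d/%d:pixel_aspect=%d/%d",
      inWidth,
      inHeight,
      static_cast<int>(inFormat),
      timeBase.num,
      timeBase.den,
      sampleAspectRatio.num,
      sampleAspectRatio.den);
  avfilter_graph_create_filter(
      srcContext, avfilter_get_by_name("buffer"), "in", srcArgs, nullptr, graph);
  avfilter_graph_create_filter(
      sinkContext, avfilter_get_by_name("buffersink"), "out", nullptr, nullptr, graph);

  // Middle of the graph as a filter string: rescale, then force RGB24.
  char filters[256];
  std::snprintf(filters, sizeof(filters), "scale=%d:%d,format=rgb24", outWidth, outHeight);

  AVFilterInOut* outputs = avfilter_inout_alloc();
  AVFilterInOut* inputs = avfilter_inout_alloc();
  outputs->name = av_strdup("in");
  outputs->filter_ctx = *srcContext;
  outputs->pad_idx = 0;
  outputs->next = nullptr;
  inputs->name = av_strdup("out");
  inputs->filter_ctx = *sinkContext;
  inputs->pad_idx = 0;
  inputs->next = nullptr;

  avfilter_graph_parse_ptr(graph, filters, &inputs, &outputs, nullptr);
  avfilter_graph_config(graph, nullptr);
  avfilter_inout_free(&inputs);
  avfilter_inout_free(&outputs);
  return graph;
}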
2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
0.0.4a0
0.1.2
