Skip to content

Commit

Permalink
Add support for png decoding on linux
Browse files Browse the repository at this point in the history
  • Loading branch information
r-zenine committed Feb 13, 2020
1 parent e2573a7 commit d102d39
Show file tree
Hide file tree
Showing 14 changed files with 266 additions and 31 deletions.
7 changes: 6 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ commands:
description: "checkout merge branch"
steps:
- checkout
- run:
name: initialize submodules
command: git submodule update --init --recursive
# - run:
# name: Checkout merge branch
# command: |
Expand Down Expand Up @@ -83,6 +86,8 @@ jobs:
resource_class: 2xlarge+
steps:
- checkout_merge
- run:
command: yum install -yq zlib-devel
- run: packaging/build_wheel.sh
- store_artifacts:
path: dist
Expand Down Expand Up @@ -128,7 +133,7 @@ jobs:
ca-certificates \
curl \
gnupg-agent \
software-properties-common
software-properties-common
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
Expand Down
7 changes: 6 additions & 1 deletion .circleci/config.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ commands:
description: "checkout merge branch"
steps:
- checkout
- run:
name: initialize submodules
command: git submodule update --init --recursive
# - run:
# name: Checkout merge branch
# command: |
Expand Down Expand Up @@ -83,6 +86,8 @@ jobs:
resource_class: 2xlarge+
steps:
- checkout_merge
- run:
command: yum install -yq zlib-devel
- run: packaging/build_wheel.sh
- store_artifacts:
path: dist
Expand Down Expand Up @@ -128,7 +133,7 @@ jobs:
ca-certificates \
curl \
gnupg-agent \
software-properties-common
software-properties-common

curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -

Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "third_party/libpng"]
path = third_party/libpng
url = https://github.com/glennrp/libpng
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ matrix:

before_install:
- sudo apt-get update
- sudo apt-get install -y zlib1g-dev
- wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
- bash miniconda.sh -b -p $HOME/miniconda
- export PATH="$HOME/miniconda/bin:$PATH"
Expand Down
19 changes: 16 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ if(WITH_CUDA)
add_definitions(-D__CUDA_NO_HALF_OPERATORS__)
endif()

# Build the bundled libpng only on Linux, the platform on which the PNG
# decoder is compiled in. CMake variable names are case-sensitive and `Unix`
# is not a variable CMake defines, so the platform is tested explicitly.
# NOTE(review): relies on project() having run before this point so that
# CMAKE_SYSTEM_NAME is populated — confirm against the top of this file.
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
  add_subdirectory("third_party/libpng")
endif()

find_package(Torch REQUIRED)
find_package(pybind11 REQUIRED)

Expand All @@ -21,8 +25,17 @@ endif()
file(GLOB MODELS_HEADERS torchvision/csrc/models/*.h)
file(GLOB MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp)

add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${OPERATOR_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} pybind11::pybind11)
file(GLOB IMAGE_HEADERS torchvision/csrc/image.h)
file(GLOB IMAGE_SOURCES torchvision/csrc/cpu/image/*.h torchvision/csrc/cpu/image/*.cpp)

# On Linux the PNG decoder sources are compiled into the library and linked
# against the bundled libpng; everywhere else the library is built without
# image decoding support.
# The `TARGET png` guard keeps the image branch disabled unless the libpng
# subproject was actually added.
# NOTE(review): assumes third_party/libpng defines a target named `png` and
# generates pnglibconf.h into its binary directory — confirm against the
# pinned libpng revision.
if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND TARGET png)
  add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${OPERATOR_SOURCES} ${IMAGE_SOURCES})
  target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} pybind11::pybind11 png)
  # Make <png.h> (source tree) and the generated configuration header
  # (build tree) visible to the decoder sources.
  target_include_directories(${PROJECT_NAME} PRIVATE
    "${CMAKE_CURRENT_SOURCE_DIR}/third_party/libpng"
    "${CMAKE_CURRENT_BINARY_DIR}/third_party/libpng")
else()
  add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${OPERATOR_SOURCES})
  target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} pybind11::pybind11)
endif()

set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchVision)

target_include_directories(${PROJECT_NAME} INTERFACE
Expand All @@ -49,7 +62,7 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/TorchVisionConfig.cmake
install(TARGETS ${PROJECT_NAME}
EXPORT TorchVisionTargets)

install(EXPORT TorchVisionTargets
install(EXPORT TorchVisionTargets
NAMESPACE TorchVision::
DESTINATION ${TORCHVISION_CMAKECONFIG_INSTALL_DIR})

Expand Down
1 change: 1 addition & 0 deletions packaging/torchvision/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ source:
requirements:
build:
- {{ compiler('c') }} # [win]
- zlib

host:
- python
Expand Down
79 changes: 53 additions & 26 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,21 @@ def get_extensions():

main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
source_image_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', 'image', '*.cpp'))
source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))

sources = main_file + source_cpu

libraries = []
extra_objects= []
extra_compile_args = {}
third_party_search_directories = []

if sys.platform.startswith('linux'):
sources = sources + source_image_cpu
libraries.append('png')
third_party_search_directories.append(os.path.join(cwd, "third_party/libpng"))

extension = CppExtension

compile_cpp_tests = os.getenv('WITH_CPP_MODELS_TEST', '0') == '1'
Expand All @@ -102,7 +114,6 @@ def get_extensions():

define_macros = []

extra_compile_args = {}
if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1':
extension = CUDAExtension
sources += source_cuda
Expand Down Expand Up @@ -142,9 +153,12 @@ def get_extensions():
extension(
'torchvision._C',
sources,
include_dirs=include_dirs,
libraries= libraries,
library_dirs=third_party_search_directories,
include_dirs=include_dirs + third_party_search_directories,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
extra_objects=extra_objects
)
]
if compile_cpp_tests:
Expand Down Expand Up @@ -196,29 +210,42 @@ def run(self):
# It's an old-style class in Python 2.7...
distutils.command.clean.clean.run(self)

def build_deps():
    """Configure and build the bundled libpng in place (Linux only).

    The libpng checkout is located relative to this file rather than the
    caller's working directory, and the process working directory is never
    changed, so a failure cannot leave the interpreter in another directory.

    Raises:
        subprocess.CalledProcessError: if a cmake step exits with a non-zero
            status.
        OSError: if the ``cmake`` executable or the libpng directory cannot
            be found.
    """
    import subprocess

    if not sys.platform.startswith('linux'):
        return

    this_dir = os.path.dirname(os.path.abspath(__file__))
    libpng_dir = os.path.join(this_dir, "third_party", "libpng")
    # Fail loudly instead of continuing to a confusing link error later on.
    subprocess.check_call(['cmake', '.'], cwd=libpng_dir)
    subprocess.check_call(['cmake', '--build', '.'], cwd=libpng_dir)



def build_ext_with_dependencies(self):
    """Build third-party dependencies, then create the extension builder.

    Calls ``build_deps()`` first and returns a ``BuildExtension`` command,
    configured with ``no_python_abi_suffix=True``, constructed from ``self``.
    NOTE(review): ``self`` is presumably the distribution object setuptools
    hands to a ``cmdclass`` entry — confirm.
    """
    build_deps()
    configured_build_ext = BuildExtension.with_options(no_python_abi_suffix=True)
    return configured_build_ext(self)

setup(
# Metadata
name=package_name,
version=version,
author='PyTorch Core Team',
author_email='[email protected]',
url='https://github.com/pytorch/vision',
description='image and video datasets and models for torch deep learning',
long_description=readme,
license='BSD',

# Package info
packages=find_packages(exclude=('test',)),

zip_safe=False,
install_requires=requirements,
extras_require={
"scipy": ["scipy"],
},
ext_modules=get_extensions(),
cmdclass={
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True),
'clean': clean,
}
)
# Metadata
name=package_name,
version=version,
author='PyTorch Core Team',
author_email='[email protected]',
url='https://github.com/pytorch/vision',
description='image and video datasets and models for torch deep learning',
long_description=readme,
license='BSD',
# Package info
packages=find_packages(exclude=('test',)),
zip_safe=False,
install_requires=requirements,
extras_require={
"scipy": ["scipy"],
},
ext_modules=get_extensions(),
cmdclass={
'build_ext': build_ext_with_dependencies,
'clean': clean,
}
)
40 changes: 40 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
import unittest
import sys

import torch
from PIL import Image
if sys.platform.startswith('linux'):
from torchvision.io.image import read_png, decode_png
import numpy as np

IMAGE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "imagefolder")


def get_images(directory, img_ext):
    """Yield the path of every file under ``directory`` whose extension is ``img_ext``.

    Args:
        directory: existing directory that is walked recursively.
        img_ext: extension to match, with or without the leading dot
            (``"png"`` and ``".png"`` are equivalent). An empty string
            matches files that have no extension.

    Yields:
        Paths (joined onto the walked root) of the matching files.
    """
    assert os.path.isdir(directory)
    # os.path.splitext keeps the leading dot, so normalise the wanted
    # extension to the same form before comparing.
    if img_ext and not img_ext.startswith("."):
        img_ext = "." + img_ext
    for root, _, files in os.walk(directory):
        for fl in files:
            _, ext = os.path.splitext(fl)
            if ext == img_ext:
                yield os.path.join(root, fl)


class ImageTester(unittest.TestCase):
    """Compares the libpng-backed PNG readers against PIL on the test assets."""

    @unittest.skipUnless(sys.platform.startswith("linux"), "Support only available on linux for now.")
    def test_read_png(self):
        """read_png on a file path must match PIL's decoding of the same file."""
        # The extension is passed with its leading dot because
        # os.path.splitext returns it in that form.
        for img_path in get_images(IMAGE_DIR, ".png"):
            img_pil = torch.from_numpy(np.array(Image.open(img_path)))
            img_lpng = read_png(img_path)
            # assertEqual would evaluate `tensor == tensor` as a boolean,
            # which is ambiguous for multi-element tensors; compare whole
            # tensors with torch.equal instead.
            self.assertTrue(torch.equal(img_lpng, img_pil))

    @unittest.skipUnless(sys.platform.startswith("linux"), "Support only available on linux for now.")
    def test_decode_png(self):
        """decode_png on the raw file bytes must match PIL's decoding."""
        for img_path in get_images(IMAGE_DIR, ".png"):
            img_pil = torch.from_numpy(np.array(Image.open(img_path)))
            size = os.path.getsize(img_path)
            # Load the file contents as a flat uint8 tensor of `size` bytes.
            img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size))
            self.assertTrue(torch.equal(img_lpng, img_pil))

if __name__ == '__main__':
unittest.main()
1 change: 1 addition & 0 deletions third_party/libpng
Submodule libpng added at 301f7a
75 changes: 75 additions & 0 deletions torchvision/csrc/cpu/image/readpng_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#include "readpng_cpu.h"

#include <png.h>
#include <setjmp.h>
#include <string>

// Decodes an in-memory PNG file into an image tensor.
//
// `data` is read as a contiguous 1-D uint8 tensor holding the raw bytes of
// the file. Returns a uint8 tensor of shape {height, width, 3}. Only 8-bit
// RGB images are supported. All failures are reported through TORCH_CHECK,
// after any libpng structures that were allocated have been released.
torch::Tensor decodePNG(const torch::Tensor& data) {
  constexpr int kSignatureBytes = 8;

  // Validate the input before allocating any libpng state, so a failed check
  // here has nothing to clean up.
  TORCH_CHECK(data.is_contiguous(), "Input tensor must be contiguous.");
  auto datap = data.accessor<unsigned char, 1>().data();
  const auto data_size = data.numel();
  auto is_png = data_size >= kSignatureBytes &&
      !png_sig_cmp(datap, 0, kSignatureBytes);
  TORCH_CHECK(is_png, "Content is not png!");

  auto png_ptr =
      png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
  TORCH_CHECK(png_ptr, "libpng read structure allocation failed!");
  auto info_ptr = png_create_info_struct(png_ptr);
  if (!info_ptr) {
    // Release the read struct before throwing to avoid leaking it.
    png_destroy_read_struct(&png_ptr, nullptr, nullptr);
    TORCH_CHECK(false, "libpng info structure allocation failed!");
  }

  // libpng reports errors by jumping back to this point.
  if (setjmp(png_jmpbuf(png_ptr)) != 0) {
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "Internal error.");
  }

  // Cursor over the input bytes that follow the signature; `remaining`
  // bounds every read so the callback never runs past the input buffer.
  struct Reader {
    png_const_bytep ptr;
    png_size_t remaining;
  } reader;
  reader.ptr = png_const_bytep(datap) + kSignatureBytes;
  reader.remaining = static_cast<png_size_t>(data_size - kSignatureBytes);

  auto read_callback =
      [](png_structp png_ptr, png_bytep output, png_size_t bytes) {
        auto reader = static_cast<Reader*>(png_get_io_ptr(png_ptr));
        if (bytes > reader->remaining) {
          // Truncated input: hand control to the error handler above.
          png_error(png_ptr, "Attempted to read past the end of the data.");
        }
        std::copy(reader->ptr, reader->ptr + bytes, output);
        reader->ptr += bytes;
        reader->remaining -= bytes;
      };
  png_set_sig_bytes(png_ptr, kSignatureBytes);
  png_set_read_fn(png_ptr, &reader, read_callback);
  png_read_info(png_ptr, info_ptr);

  png_uint_32 width, height;
  int bit_depth, color_type;
  auto retval = png_get_IHDR(
      png_ptr,
      info_ptr,
      &width,
      &height,
      &bit_depth,
      &color_type,
      nullptr,
      nullptr,
      nullptr);

  if (retval != 1) {
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "Could not read image metadata from content.");
  }
  if (color_type != PNG_COLOR_TYPE_RGB) {
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "Non RGB images are not supported.");
  }
  if (bit_depth != 8) {
    // Wider samples would make each decoded row larger than one row of the
    // uint8 output tensor allocated below.
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "Only 8-bit RGB images are supported.");
  }

  // NOTE(review): a libpng error raised while decoding rows jumps back past
  // `tensor` without running its destructor — consider restructuring so the
  // tensor is owned outside the setjmp region.
  auto tensor =
      torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8);
  auto ptr = tensor.accessor<uint8_t, 3>().data();
  auto bytes = png_get_rowbytes(png_ptr, info_ptr);
  for (decltype(height) i = 0; i < height; ++i) {
    png_read_row(png_ptr, ptr, nullptr);
    ptr += bytes;
  }
  png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
  return tensor;
}
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/image/readpng_cpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#pragma once

#include <torch/torch.h>
#include <string>

// Decodes the raw bytes of a PNG file, passed as a 1-D uint8 tensor, into a
// uint8 image tensor of shape {height, width, 3}. Only RGB images are
// accepted; failures are reported through TORCH_CHECK.
torch::Tensor decodePNG(const torch::Tensor& data);
4 changes: 4 additions & 0 deletions torchvision/csrc/image.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#pragma once

#include "cpu/image/readpng_cpu.h"

6 changes: 6 additions & 0 deletions torchvision/csrc/vision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
#include "ROIAlign.h"
#include "ROIPool.h"
#include "empty_tensor_op.h"
#ifdef __linux__
#include "image.h"
#endif
#include "nms.h"

// If we are in a Windows environment, we need to define
Expand Down Expand Up @@ -49,4 +52,7 @@ static auto registry =
.op("torchvision::ps_roi_align", &ps_roi_align)
.op("torchvision::ps_roi_pool", &ps_roi_pool)
.op("torchvision::deform_conv2d", &deform_conv2d)
#ifdef __linux__
.op("torchvision::decode_png", &decodePNG)
#endif
.op("torchvision::_cuda_version", &_cuda_version);
Loading

0 comments on commit d102d39

Please sign in to comment.