diff --git a/.circleci/config.yml b/.circleci/config.yml index 29b5fc77aab..f3fc23b7c92 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -107,6 +107,8 @@ jobs: - checkout - run: command: | + sudo apt-get update -y + sudo apt install -y libturbojpeg-dev pip install --user --progress-bar off numpy mypy pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html pip install --user --progress-bar off --editable . diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index d9bd257eae6..f63c3f408ba 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -107,6 +107,8 @@ jobs: - checkout - run: command: | + sudo apt-get update -y + sudo apt install -y libturbojpeg-dev pip install --user --progress-bar off numpy mypy pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html pip install --user --progress-bar off --editable . diff --git a/.travis.yml b/.travis.yml index ec25bdb8677..0757ad3c159 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,6 +13,7 @@ jobs: before_install: - sudo apt-get update + - sudo apt-get install -y libpng-dev libturbojpeg-dev - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; - bash miniconda.sh -b -p $HOME/miniconda - export PATH="$HOME/miniconda/bin:$PATH" diff --git a/CMakeLists.txt b/CMakeLists.txt index fa50f155ce4..0d7e7aaa9d8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,22 +11,28 @@ if(WITH_CUDA) endif() find_package(Python3 COMPONENTS Development) + find_package(Torch REQUIRED) +find_package(PNG REQUIRED) + file(GLOB HEADERS torchvision/csrc/*.h) -file(GLOB OPERATOR_SOURCES torchvision/csrc/cpu/*.h torchvision/csrc/cpu/*.cpp torchvision/csrc/*.cpp) +file(GLOB IMAGE_HEADERS torchvision/csrc/cpu/image/*.h) +file(GLOB IMAGE_SOURCES torchvision/csrc/cpu/image/*.cpp) +file(GLOB OPERATOR_SOURCES torchvision/csrc/cpu/*.h torchvision/csrc/cpu/*.cpp ${IMAGE_HEADERS} ${IMAGE_SOURCES} ${HEADERS} torchvision/csrc/*.cpp) if(WITH_CUDA) file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} torchvision/csrc/cuda/*.h torchvision/csrc/cuda/*.cu) endif() file(GLOB MODELS_HEADERS torchvision/csrc/models/*.h) file(GLOB MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp) -add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${OPERATOR_SOURCES}) -target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} Python3::Python) +add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${OPERATOR_SOURCES} ${IMAGE_SOURCES}) +target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} ${PNG_LIBRARY} Python3::Python) +# target_link_libraries(${PROJECT_NAME} PRIVATE ${PNG_LIBRARY} Python3::Python) set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchVision) target_include_directories(${PROJECT_NAME} INTERFACE - $ + $ $) include(GNUInstallDirs) @@ -61,7 +67,7 @@ install(FILES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cpu) if(WITH_CUDA) install(FILES - torchvision/csrc/cuda/vision_cuda.h + torchvision/csrc/cuda/vision_cuda.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cuda) endif() install(FILES ${MODELS_HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/models) diff --git a/README.rst b/README.rst index 3150a3023ad..d130ad03a78 100644 --- a/README.rst +++ b/README.rst @@ -78,13 +78,22 @@ Torchvision currently supports the following image backends: * `accimage`_ - if installed can be activated by calling :code:`torchvision.set_image_backend('accimage')` +* `libpng`_ - can be installed via conda :code:`conda install libpng` or any of the package managers for debian-based and RHEL-based Linux distributions. + +* `libturbojpeg`_ - blazing speed, fast JPEG image loading. Can be installed from conda-forge :code:`conda install libjpeg-turbo -c conda-forge`. + +**Notes:** ``libpng`` and ``libturbojpeg`` must be available at compilation time in order to be available. Also, most linux distributions distinguish between +``libturbojpeg`` and ``libjpeg-turbo``, where the former should be installed instead of the latter one. + +.. _libpng : http://www.libpng.org/pub/png/libpng.html +.. _libturbojpeg: https://github.com/libjpeg-turbo/libjpeg-turbo .. _Pillow : https://python-pillow.org/ .. _Pillow-SIMD : https://github.com/uploadcare/pillow-simd .. _accimage: https://github.com/pytorch/accimage C++ API ======= -TorchVision also offers a C++ API that contains C++ equivalent of python models. +TorchVision also offers a C++ API that contains C++ equivalent of python models. Installation From source: @@ -94,7 +103,7 @@ Installation From source: cd build # Add -DWITH_CUDA=on support for the CUDA if needed cmake .. - make + make make install Once installed, the library can be accessed in cmake (after properly configuring ``CMAKE_PREFIX_PATH``) via the :code:`TorchVision::TorchVision` target: diff --git a/packaging/build_wheel.sh b/packaging/build_wheel.sh index a075b3b3a00..30ace73d231 100755 --- a/packaging/build_wheel.sh +++ b/packaging/build_wheel.sh @@ -10,6 +10,30 @@ setup_wheel_python pip_install numpy pyyaml future ninja setup_pip_pytorch_version python setup.py clean + +# Copy binaries to be included in the wheel distribution +if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then + python_exec="$(which python)" + bin_path=$(dirname $python_exec) + env_path=$(dirname $bin_path) + if [[ "$(uname)" == Darwin ]]; then + # Include LibPNG + cp "$env_path/lib/libpng16.dylib" torchvision + # Include TurboJPEG + cp "$env_path/lib/libturbojpeg.dylib" torchvision + else + # Include libPNG + cp "$bin_path/Library/lib/libpng.lib" torchvision + # Include TurboJPEG + cp "$bin_path/Library/lib/turbojpeg.lib" torchvision + fi +else + # Include LibPNG + cp "/usr/lib64/libpng.so" torchvision + # Include TurboJPEG + cp "/usr/lib64/libturbojpeg.so" torchvision +fi + if [[ "$OSTYPE" == "msys" ]]; then IS_WHEEL=1 "$script_dir/windows/internal/vc_env_helper.bat" python setup.py bdist_wheel else diff --git a/packaging/pkg_helpers.bash b/packaging/pkg_helpers.bash index b262a0f5157..88fe52c0b08 100644 --- a/packaging/pkg_helpers.bash +++ b/packaging/pkg_helpers.bash @@ -170,7 +170,13 @@ setup_wheel_python() { conda env remove -n "env$PYTHON_VERSION" || true conda create -yn "env$PYTHON_VERSION" python="$PYTHON_VERSION" conda activate "env$PYTHON_VERSION" + # Install libPNG from Anaconda (defaults) + conda install libpng -y + # Install libJPEG-turbo from conda-forge + conda install -y libjpeg-turbo -c conda-forge else + # Install native CentOS libPNG, libJPEG-turbo + yum install -y libpng-devel turbojpeg-devel case "$PYTHON_VERSION" in 2.7) if [[ -n "$UNICODE_ABI" ]]; then diff --git a/packaging/torchvision/conda_build_config.yaml b/packaging/torchvision/conda_build_config.yaml index 5188bb0ebec..adbbc732a0c 100644 --- a/packaging/torchvision/conda_build_config.yaml +++ b/packaging/torchvision/conda_build_config.yaml @@ -1,3 +1,6 @@ +channel_sources: + - defaults,conda-forge + blas_impl: - mkl # [x86_64] c_compiler: diff --git a/packaging/torchvision/meta.yaml b/packaging/torchvision/meta.yaml index 7d6f28cdf3c..62cfc401e3c 100644 --- a/packaging/torchvision/meta.yaml +++ b/packaging/torchvision/meta.yaml @@ -8,6 +8,8 @@ source: requirements: build: - {{ compiler('c') }} # [win] + - libpng + - libjpeg-turbo host: - python @@ -18,6 +20,10 @@ requirements: run: - python + - libpng + - libjpeg-turbo + # Pillow introduces unwanted conflicts with libjpeg-turbo, as it depends on jpeg + # The fix depends on https://github.com/conda-forge/conda-forge.github.io/issues/673 - pillow >=4.1.1 - numpy >=1.11 {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }} diff --git a/setup.py b/setup.py index 0620193b3a7..e454c59e047 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import re import sys from setuptools import setup, find_packages -from pkg_resources import get_distribution, DistributionNotFound +from pkg_resources import parse_version, get_distribution, DistributionNotFound import subprocess import distutils.command.clean import distutils.spawn @@ -76,7 +76,76 @@ def write_version_file(): requirements.append(pillow_req + pillow_ver) +def find_library(name, vision_include): + this_dir = os.path.dirname(os.path.abspath(__file__)) + build_prefix = os.environ.get('BUILD_PREFIX', None) + is_conda_build = build_prefix is not None + + library_found = False + conda_installed = False + lib_folder = None + include_folder = None + library_header = '{0}.h'.format(name) + + print('Running build on conda-build: {0}'.format(is_conda_build)) + if is_conda_build: + # Add conda headers/libraries + if os.name == 'nt': + build_prefix = os.path.join(build_prefix, 'Library') + include_folder = os.path.join(build_prefix, 'include') + lib_folder = os.path.join(build_prefix, 'lib') + library_header_path = os.path.join( + include_folder, library_header) + library_found = os.path.isfile(library_header_path) + conda_installed = library_found + else: + # Check if using Anaconda to produce wheels + conda = distutils.spawn.find_executable('conda') + is_conda = conda is not None + print('Running build on conda: {0}'.format(is_conda)) + if is_conda: + python_executable = sys.executable + py_folder = os.path.dirname(python_executable) + if os.name == 'nt': + env_path = os.path.join(py_folder, 'Library') + else: + env_path = os.path.dirname(py_folder) + lib_folder = os.path.join(env_path, 'lib') + include_folder = os.path.join(env_path, 'include') + library_header_path = os.path.join( + include_folder, library_header) + library_found = os.path.isfile(library_header_path) + conda_installed = library_found + + # Try to locate turbojpeg in Linux standard paths + if not library_found: + if sys.platform == 'linux': + library_found = os.path.exists('/usr/include/{0}'.format( + library_header)) + library_found = library_found or os.path.exists( + '/usr/local/include/{0}'.format(library_header)) + else: + # Lookup in TORCHVISION_INCLUDE or in the package file + package_path = [os.path.join(this_dir, 'torchvision')] + for folder in vision_include + package_path: + candidate_path = os.path.join(folder, library_header) + library_found = os.path.exists(candidate_path) + if library_found: + break + + return library_found, conda_installed, include_folder, lib_folder + + def get_extensions(): + vision_include = os.environ.get('TORCHVISION_INCLUDE', None) + vision_library = os.environ.get('TORCHVISION_LIBRARY', None) + vision_include = (vision_include.split(os.pathsep) + if vision_include is not None else []) + vision_library = (vision_library.split(os.pathsep) + if vision_library is not None else []) + include_dirs = vision_include + library_dirs = vision_library + this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc') @@ -149,13 +218,14 @@ def get_extensions(): sources = [os.path.join(extensions_dir, s) for s in sources] - include_dirs = [extensions_dir] + include_dirs += [extensions_dir] ext_modules = [ extension( 'torchvision._C', sources, include_dirs=include_dirs, + library_dirs=library_dirs, define_macros=define_macros, extra_compile_args=extra_compile_args, ) @@ -171,6 +241,65 @@ def get_extensions(): ) ) + # Image reading extension + image_macros = [] + image_include = [extensions_dir] + image_library = [] + image_link_flags = [] + + # Locating libPNG + libpng = distutils.spawn.find_executable('libpng-config') + png_found = libpng is not None + image_macros += [('PNG_FOUND', str(int(png_found)))] + print('PNG found: {0}'.format(png_found)) + if png_found: + png_version = subprocess.run([libpng, '--version'], + stdout=subprocess.PIPE) + png_version = png_version.stdout.strip().decode('utf-8') + print('libpng version: {0}'.format(png_version)) + png_version = parse_version(png_version) + if png_version >= parse_version("1.6.0"): + print('Building torchvision with PNG image support') + png_lib = subprocess.run([libpng, '--libdir'], + stdout=subprocess.PIPE) + png_include = subprocess.run([libpng, '--I_opts'], + stdout=subprocess.PIPE) + image_library += [png_lib.stdout.strip().decode('utf-8')] + image_include += [png_include.stdout.strip().decode('utf-8')] + image_link_flags.append('png' if os.name != 'nt' else 'libpng') + else: + print('libpng installed version is less than 1.6.0, ' + 'disabling PNG support') + png_found = False + + # Locating libjpegturbo + turbojpeg_info = find_library('turbojpeg', vision_include) + (turbojpeg_found, conda_installed, + turbo_include_folder, turbo_lib_folder) = turbojpeg_info + + image_macros += [('JPEG_FOUND', str(int(turbojpeg_found)))] + print('turboJPEG found: {0}'.format(turbojpeg_found)) + if turbojpeg_found: + print('Building torchvision with JPEG image support') + image_link_flags.append('turbojpeg') + if conda_installed: + image_library += [turbo_lib_folder] + image_include += [turbo_include_folder] + + image_path = os.path.join(extensions_dir, 'cpu', 'image') + image_src = glob.glob(os.path.join(image_path, '*.cpp')) + + if png_found or turbojpeg_found: + ext_modules.append(extension( + 'torchvision.image', + image_src, + include_dirs=include_dirs + [image_path] + image_include, + library_dirs=library_dirs + image_library, + define_macros=image_macros, + libraries=image_link_flags, + extra_compile_args=extra_compile_args + )) + ffmpeg_exe = distutils.spawn.find_executable('ffmpeg') has_ffmpeg = ffmpeg_exe is not None @@ -243,7 +372,9 @@ def run(self): # Package info packages=find_packages(exclude=('test',)), - + package_data={ + package_name: ['*.lib', '*.dylib', '*.so'] + }, zip_safe=False, install_requires=requirements, extras_require={ diff --git a/test/test_image.py b/test/test_image.py new file mode 100644 index 00000000000..4ac331f399f --- /dev/null +++ b/test/test_image.py @@ -0,0 +1,72 @@ +import os +import unittest +import sys + +import torch +import torchvision +from PIL import Image +from torchvision.io.image import read_png, decode_png, read_jpeg, decode_jpeg +import numpy as np + +IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") +IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder") + + +def get_images(directory, img_ext): + assert os.path.isdir(directory) + for root, _, files in os.walk(directory): + for fl in files: + _, ext = os.path.splitext(fl) + if ext == img_ext: + yield os.path.join(root, fl) + + +class ImageTester(unittest.TestCase): + def test_read_jpeg(self): + for img_path in get_images(IMAGE_ROOT, "jpg"): + img_pil = torch.from_numpy(np.array(Image.open(img_path))) + img_ljpeg = read_jpeg(img_path) + + norm = img_ljpeg.shape[0] * img_ljpeg.shape[1] * img_ljpeg.shape[2] * 255 + err = torch.abs(img_ljpeg.flatten().float() - img_pil.flatten().float()).sum().float() / (norm) + self.assertLessEqual(err, 1e-2) + + def test_decode_jpeg(self): + for img_path in get_images(IMAGE_ROOT, "jpg"): + img_pil = torch.from_numpy(np.array(Image.open(img_path))) + size = os.path.getsize(img_path) + img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size)) + + norm = img_ljpeg.shape[0] * img_ljpeg.shape[1] * img_ljpeg.shape[2] * 255 + err = torch.abs(img_ljpeg.flatten().float() - img_pil.flatten().float()).sum().float() / (norm) + + self.assertLessEqual(err, 1e-2) + + with self.assertRaisesRegex(ValueError, "Expected a non empty 1-dimensional tensor."): + decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) + + with self.assertRaisesRegex(ValueError, "Expected a torch.uint8 tensor."): + decode_jpeg(torch.empty((100, ), dtype=torch.float16)) + + with self.assertRaisesRegex(RuntimeError, "Error while reading jpeg headers"): + decode_jpeg(torch.empty((100), dtype=torch.uint8)) + + def test_read_png(self): + for img_path in get_images(IMAGE_DIR, "png"): + img_pil = torch.from_numpy(np.array(Image.open(img_path))) + img_lpng = read_png(img_path) + self.assertEqual(img_lpng, img_pil) + + def test_decode_png(self): + for img_path in get_images(IMAGE_DIR, "png"): + img_pil = torch.from_numpy(np.array(Image.open(img_path))) + size = os.path.getsize(img_path) + img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size)) + self.assertEqual(img_lpng, img_pil) + + self.assertEqual(decode_png(torch.empty()), torch.empty()) + self.assertEqual(decode_png(torch.randint(3, 5, (300,))), torch.empty()) + + +if __name__ == '__main__': + unittest.main() diff --git a/torchvision/csrc/cpu/image/image.cpp b/torchvision/csrc/cpu/image/image.cpp new file mode 100644 index 00000000000..33f492ae683 --- /dev/null +++ b/torchvision/csrc/cpu/image/image.cpp @@ -0,0 +1,24 @@ + +#include "image.h" +#include +#include + +// If we are in a Windows environment, we need to define +// initialization functions for the _custom_ops extension +#ifdef _WIN32 +#if PY_MAJOR_VERSION < 3 +PyMODINIT_FUNC init_image(void) { + // No need to do anything. + return NULL; +} +#else +PyMODINIT_FUNC PyInit_image(void) { + // No need to do anything. + return NULL; +} +#endif +#endif + +static auto registry = torch::RegisterOperators() + .op("image::decode_png", &decodePNG) + .op("image::decode_jpeg", &decodeJPEG); diff --git a/torchvision/csrc/cpu/image/image.h b/torchvision/csrc/cpu/image/image.h new file mode 100644 index 00000000000..6db97d4729a --- /dev/null +++ b/torchvision/csrc/cpu/image/image.h @@ -0,0 +1,7 @@ + +#pragma once + +#include +#include +#include "readjpeg_cpu.h" +#include "readpng_cpu.h" diff --git a/torchvision/csrc/cpu/image/readjpeg_cpu.cpp b/torchvision/csrc/cpu/image/readjpeg_cpu.cpp new file mode 100644 index 00000000000..1bf700a2317 --- /dev/null +++ b/torchvision/csrc/cpu/image/readjpeg_cpu.cpp @@ -0,0 +1,56 @@ +#include "readjpeg_cpu.h" + +#include +#include +#include + +#if !JPEG_FOUND + +torch::Tensor decodeJPEG(const torch::Tensor& data) { + AT_ERROR("decodeJPEG: torchvision not compiled with turboJPEG support"); +} + +#else +#include + +torch::Tensor decodeJPEG(const torch::Tensor& data) { + tjhandle tjInstance = tjInitDecompress(); + if (tjInstance == NULL) { + TORCH_CHECK(false, "libjpeg-turbo decompression initialization failed."); + } + + auto datap = data.accessor().data(); + + int width, height; + + if (tjDecompressHeader(tjInstance, datap, data.numel(), &width, &height) < + 0) { + tjDestroy(tjInstance); + TORCH_CHECK(false, "Error while reading jpeg headers"); + } + auto tensor = + torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8); + + auto ptr = tensor.accessor().data(); + + int pixelFormat = TJPF_RGB; + + auto ret = tjDecompress2( + tjInstance, + datap, + data.numel(), + ptr, + width, + 0, + height, + pixelFormat, + NULL); + if (ret != 0) { + tjDestroy(tjInstance); + TORCH_CHECK(false, "decompressing JPEG image"); + } + + return tensor; +} + +#endif // JPEG_FOUND diff --git a/torchvision/csrc/cpu/image/readjpeg_cpu.h b/torchvision/csrc/cpu/image/readjpeg_cpu.h new file mode 100644 index 00000000000..40404df29b5 --- /dev/null +++ b/torchvision/csrc/cpu/image/readjpeg_cpu.h @@ -0,0 +1,5 @@ +#pragma once + +#include + +torch::Tensor decodeJPEG(const torch::Tensor& data); diff --git a/torchvision/csrc/cpu/image/readpng_cpu.cpp b/torchvision/csrc/cpu/image/readpng_cpu.cpp new file mode 100644 index 00000000000..b7fb28e2575 --- /dev/null +++ b/torchvision/csrc/cpu/image/readpng_cpu.cpp @@ -0,0 +1,83 @@ +#include "readpng_cpu.h" + +#include +#include +#include + +#if !PNG_FOUND +torch::Tensor decodePNG(const torch::Tensor& data) { + AT_ERROR("decodePNG: torchvision not compiled with libPNG support"); +} +#else +#include + +torch::Tensor decodePNG(const torch::Tensor& data) { + auto png_ptr = + png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); + TORCH_CHECK(png_ptr, "libpng read structure allocation failed!") + auto info_ptr = png_create_info_struct(png_ptr); + if (!info_ptr) { + png_destroy_read_struct(&png_ptr, nullptr, nullptr); + // Seems redundant with the if statement. done here to avoid leaking memory. + TORCH_CHECK(info_ptr, "libpng info structure allocation failed!") + } + + auto datap = data.accessor().data(); + + if (setjmp(png_jmpbuf(png_ptr)) != 0) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK(false, "Internal error."); + } + auto is_png = !png_sig_cmp(datap, 0, 8); + TORCH_CHECK(is_png, "Content is not png!") + + struct Reader { + png_const_bytep ptr; + } reader; + reader.ptr = png_const_bytep(datap) + 8; + + auto read_callback = + [](png_structp png_ptr, png_bytep output, png_size_t bytes) { + auto reader = static_cast(png_get_io_ptr(png_ptr)); + std::copy(reader->ptr, reader->ptr + bytes, output); + reader->ptr += bytes; + }; + png_set_sig_bytes(png_ptr, 8); + png_set_read_fn(png_ptr, &reader, read_callback); + png_read_info(png_ptr, info_ptr); + + png_uint_32 width, height; + int bit_depth, color_type; + auto retval = png_get_IHDR( + png_ptr, + info_ptr, + &width, + &height, + &bit_depth, + &color_type, + nullptr, + nullptr, + nullptr); + + if (retval != 1) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK(retval == 1, "Could read image metadata from content.") + } + if (color_type != PNG_COLOR_TYPE_RGB) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK( + color_type == PNG_COLOR_TYPE_RGB, "Non RGB images are not supported.") + } + + auto tensor = + torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8); + auto ptr = tensor.accessor().data(); + auto bytes = png_get_rowbytes(png_ptr, info_ptr); + for (decltype(height) i = 0; i < height; ++i) { + png_read_row(png_ptr, ptr, nullptr); + ptr += bytes; + } + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + return tensor; +} +#endif // PNG_FOUND diff --git a/torchvision/csrc/cpu/image/readpng_cpu.h b/torchvision/csrc/cpu/image/readpng_cpu.h new file mode 100644 index 00000000000..d2151a43aa9 --- /dev/null +++ b/torchvision/csrc/cpu/image/readpng_cpu.h @@ -0,0 +1,6 @@ +#pragma once + +#include +#include + +torch::Tensor decodePNG(const torch::Tensor& data); diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index cbbf560412e..4c47d8a51d5 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -30,5 +30,5 @@ "_read_video_clip_from_memory", "_read_video_meta_data", "VideoMetaData", - "Timebase", + "Timebase" ] diff --git a/torchvision/io/image.py b/torchvision/io/image.py new file mode 100644 index 00000000000..60d44380356 --- /dev/null +++ b/torchvision/io/image.py @@ -0,0 +1,109 @@ +import torch +from torch import nn, Tensor + +import os +import os.path as osp +import importlib + +_HAS_IMAGE_OPT = False + +try: + lib_dir = osp.join(osp.dirname(__file__), "..") + + loader_details = ( + importlib.machinery.ExtensionFileLoader, + importlib.machinery.EXTENSION_SUFFIXES + ) + + extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) + ext_specs = extfinder.find_spec("image") + if ext_specs is not None: + torch.ops.load_library(ext_specs.origin) + _HAS_IMAGE_OPT = True +except (ImportError, OSError): + pass + + +def decode_png(input): + # type: (Tensor) -> Tensor + """ + Decodes a PNG image into a 3 dimensional RGB Tensor. + The values of the output tensor are uint8 between 0 and 255. + + Arguments: + input (Tensor[1]): a one dimensional int8 tensor containing + the raw bytes of the PNG image. + + Returns: + output (Tensor[image_width, image_height, 3]) + """ + if not isinstance(input, torch.Tensor) or len(input) == 0: + raise ValueError("Expected a non empty 1-dimensional tensor.") + + if not input.dtype == torch.uint8: + raise ValueError("Expected a torch.uint8 tensor.") + output = torch.ops.image.decode_png(input) + return output + + +def read_png(path): + # type: (str) -> Tensor + """ + Reads a PNG image into a 3 dimensional RGB Tensor. + The values of the output tensor are uint8 between 0 and 255. + + Arguments: + path (str): path of the PNG image. + + Returns: + output (Tensor[image_width, image_height, 3]) + """ + if not os.path.isfile(path): + raise ValueError("Expected a valid file path.") + + size = os.path.getsize(path) + if size == 0: + raise ValueError("Expected a non empty file.") + data = torch.from_file(path, dtype=torch.uint8, size=size) + return decode_png(data) + + +def decode_jpeg(input): + # type: (Tensor) -> Tensor + """ + Decodes a JPEG image into a 3 dimensional RGB Tensor. + The values of the output tensor are uint8 between 0 and 255. + Arguments: + input (Tensor[1]): a one dimensional int8 tensor containing + the raw bytes of the JPEG image. + Returns: + output (Tensor[image_width, image_height, 3]) + """ + if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1: + raise ValueError("Expected a non empty 1-dimensional tensor.") + + if not input.dtype == torch.uint8: + raise ValueError("Expected a torch.uint8 tensor.") + + output = torch.ops.image.decode_jpeg(input) + return output + + +def read_jpeg(path): + # type: (str) -> Tensor + """ + Reads a JPEG image into a 3 dimensional RGB Tensor. + The values of the output tensor are uint8 between 0 and 255. + Arguments: + path (str): path of the JPEG image. + Returns: + output (Tensor[image_width, image_height, 3]) + """ + if not os.path.isfile(path): + raise ValueError("Expected a valid file path.") + + size = os.path.getsize(path) + if size == 0: + raise ValueError("Expected a non empty file.") + data = torch.from_file(path, dtype=torch.uint8, size=size) + return decode_jpeg(data)