Skip to content

Commit

Permalink
Include png functions as part of the main library
Browse files Browse the repository at this point in the history
  • Loading branch information
andfoy committed Jun 11, 2020
1 parent f97a9f0 commit 10559e8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 22 deletions.
46 changes: 24 additions & 22 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,22 @@ def write_version_file():


def get_extensions():
try:
include_dirs = [os.environ['LIBRARY_INC']]
except KeyError:
include_dirs = []

try:
library_dirs = [os.environ['LIBRARY_LIB']]
except KeyError:
library_dirs = []

this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')

main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
image_src = glob.glob(os.path(extensions_dir, 'cpu', 'image', '*.cpp'))

is_rocm_pytorch = False
if torch.__version__ >= '1.5':
Expand All @@ -104,7 +115,7 @@ def get_extensions():
else:
source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))

sources = main_file + source_cpu
sources = main_file + source_cpu + image_src
extension = CppExtension

compile_cpp_tests = os.getenv('WITH_CPP_MODELS_TEST', '0') == '1'
Expand Down Expand Up @@ -149,13 +160,14 @@ def get_extensions():

sources = [os.path.join(extensions_dir, s) for s in sources]

include_dirs = [extensions_dir]
include_dirs += [extensions_dir]

ext_modules = [
extension(
'torchvision._C',
sources,
include_dirs=include_dirs,
library_dirs=library_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
Expand All @@ -171,27 +183,17 @@ def get_extensions():
)
)

try:
include_dirs += [os.environ['LIBRARY_INC']]
except KeyError:
pass

try:
library_dirs = [os.environ['LIBRARY_LIB']]
except KeyError:
library_dirs = []

# Image reading extension
image_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'cpu', 'image')
image_src = glob.glob(os.path.join(image_src_dir, "*.cpp"))
ext_modules.append(extension(
'torchvision.image',
image_src,
include_dirs=include_dirs + image_src_dir,
library_dirs=library_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args
))
# image_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'cpu', 'image')
# image_src = glob.glob(os.path.join(image_src_dir, "*.cpp"))
# ext_modules.append(extension(
# 'torchvision.image',
# image_src,
# include_dirs=include_dirs + [image_src_dir],
# library_dirs=library_dirs,
# define_macros=define_macros,
# extra_compile_args=extra_compile_args
# ))

ffmpeg_exe = distutils.spawn.find_executable('ffmpeg')
has_ffmpeg = ffmpeg_exe is not None
Expand Down
3 changes: 3 additions & 0 deletions torchvision/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
write_video,
)

from .image import (read_png)


__all__ = [
"write_video",
Expand All @@ -31,4 +33,5 @@
"_read_video_meta_data",
"VideoMetaData",
"Timebase",
"read_png"
]

0 comments on commit 10559e8

Please sign in to comment.