Skip to content

Commit

Permalink
[spark] Fixes python tarslip security concern (deepjavalibrary#2995)
Browse files Browse the repository at this point in the history
* [spark] Fixes python tarslip security concern

* reformat python code
  • Loading branch information
frankfliu authored Feb 16, 2024
1 parent d825cf7 commit 1fcca33
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 70 deletions.
136 changes: 68 additions & 68 deletions basicdataset/src/main/resources/imagenet/extract_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_VAL_TAR = 'ILSVRC2012_img_val.tar'
_VAL_TAR_SHA1 = '5f3f73da3395154b60528b2b2a2caf2374f5f178'


def download(url, path=None, overwrite=False, sha1_hash=None):
"""Download an given URL
Parameters
Expand Down Expand Up @@ -42,26 +43,29 @@ def download(url, path=None, overwrite=False, sha1_hash=None):
else:
fname = path

if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
if overwrite or not os.path.exists(fname) or (
sha1_hash and not check_sha1(fname, sha1_hash)):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)

print('Downloading %s from %s...'%(fname, url))
print('Downloading %s from %s...' % (fname, url))
r = requests.get(url, stream=True)
if r.status_code != 200:
raise RuntimeError("Failed downloading url %s"%url)
raise RuntimeError("Failed downloading url %s" % url)
total_length = r.headers.get('content-length')
with open(fname, 'wb') as f:
if total_length is None: # no content length header
if total_length is None: # no content length header
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
if chunk: # filter out keep-alive new chunks
f.write(chunk)
else:
total_length = int(total_length)
for chunk in tqdm(r.iter_content(chunk_size=1024),
total=int(total_length / 1024. + 0.5),
unit='KB', unit_scale=False, dynamic_ncols=True):
unit='KB',
unit_scale=False,
dynamic_ncols=True):
f.write(chunk)

if sha1_hash and not check_sha1(fname, sha1_hash):
Expand All @@ -72,25 +76,34 @@ def download(url, path=None, overwrite=False, sha1_hash=None):

return fname


def parse_args():
parser = argparse.ArgumentParser(
description='Setup the ImageNet dataset.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--download-dir', required=True,
help="The directory that contains downloaded tar files")
parser.add_argument(
'--download-dir',
required=True,
help="The directory that contains downloaded tar files")
parser.add_argument('--target-dir',
help="The directory to store extracted images")
parser.add_argument('--checksum', action='store_true',
parser.add_argument('--checksum',
action='store_true',
help="If check integrity before extracting.")
parser.add_argument('--with-rec', action='store_true',
parser.add_argument('--with-rec',
action='store_true',
help="If build image record files.")
parser.add_argument('--num-thread', type=int, default=1,
help="Number of threads to use when building image record file.")
parser.add_argument(
'--num-thread',
type=int,
default=1,
help="Number of threads to use when building image record file.")
args = parser.parse_args()
if args.target_dir is None:
args.target_dir = args.download_dir
return args


def check_sha1(filename, sha1_hash):
"""Check whether the sha1 hash of the file content matches the expected hash.
Expand All @@ -116,11 +129,13 @@ def check_sha1(filename, sha1_hash):

return sha1.hexdigest() == sha1_hash


def check_file(filename, checksum, sha1):
if not os.path.exists(filename):
raise ValueError('File not found: '+filename)
raise ValueError('File not found: ' + filename)
if checksum and not check_sha1(filename, sha1):
raise ValueError('Corrupted file: '+filename)
raise ValueError('Corrupted file: ' + filename)


def build_rec_process(img_dir, train=False, num_thread=1):
rec_dir = os.path.abspath(os.path.join(img_dir, '../rec'))
Expand All @@ -141,102 +156,84 @@ def build_rec_process(img_dir, train=False, num_thread=1):
# execution
import sys
cmd = [
sys.executable,
script_path,
rec_dir,
img_dir,
'--recursive',
'--pass-through',
'--pack-label',
'--num-thread',
sys.executable, script_path, rec_dir, img_dir, '--recursive',
'--pass-through', '--pack-label', '--num-thread',
str(num_thread)
]
subprocess.call(cmd)
os.remove(script_path)
os.remove(lst_path)
print('ImageRecord file for ' + prefix + ' has been built!')


def is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory


def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
raise Exception("Attempted Path Traversal in Tar File")
tar.extractall(path, members, numeric_owner=numeric_owner)


def extract_train(tar_fname, target_dir, with_rec=False, num_thread=1):
os.makedirs(target_dir)
with tarfile.open(tar_fname) as tar:
print("Extracting "+tar_fname+"...")
print("Extracting " + tar_fname + "...")
# extract each class one-by-one
pbar = tqdm(total=len(tar.getnames()))
for class_tar in tar:
pbar.set_description('Extract '+class_tar.name)
tar.extract(class_tar, target_dir)
pbar.set_description('Extract ' + class_tar.name)
class_fname = os.path.join(target_dir, class_tar.name)
if not is_within_directory(target_dir, class_fname):
raise Exception("Attempted Path Traversal in Tar File")

tar.extract(class_tar, target_dir)
class_dir = os.path.splitext(class_fname)[0]
os.mkdir(class_dir)
with tarfile.open(class_fname) as f:
def is_within_directory(directory, target):

abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)

prefix = os.path.commonprefix([abs_directory, abs_target])

return prefix == abs_directory

def safe_extract(tar, path=".", members=None, *, numeric_owner=False):

for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
raise Exception("Attempted Path Traversal in Tar File")

tar.extractall(path, members, numeric_owner=numeric_owner)


safe_extract(f, class_dir)

os.remove(class_fname)
pbar.update(1)
pbar.close()
if with_rec:
build_rec_process(target_dir, True, num_thread)


def extract_val(tar_fname, target_dir, with_rec=False, num_thread=1):
os.makedirs(target_dir)
print('Extracting ' + tar_fname)
with tarfile.open(tar_fname) as tar:
def is_within_directory(directory, target):

abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)

prefix = os.path.commonprefix([abs_directory, abs_target])

return prefix == abs_directory

def safe_extract(tar, path=".", members=None, *, numeric_owner=False):

for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
raise Exception("Attempted Path Traversal in Tar File")

tar.extractall(path, members, numeric_owner=numeric_owner)


safe_extract(tar, target_dir)

# build rec file before images are moved into subfolders
if with_rec:
build_rec_process(target_dir, False, num_thread)
# move images to proper subfolders
val_maps_file = os.path.join(os.path.dirname(__file__), 'imagenet_val_maps.pklz')
val_maps_file = os.path.join(os.path.dirname(__file__),
'imagenet_val_maps.pklz')
with gzip.open(val_maps_file, 'rb') as f:
dirs, mappings = pickle.load(f)
for d in dirs:
os.makedirs(os.path.join(target_dir, d))
for m in mappings:
os.rename(os.path.join(target_dir, m[0]), os.path.join(target_dir, m[1], m[0]))
os.rename(os.path.join(target_dir, m[0]),
os.path.join(target_dir, m[1], m[0]))


def main():
args = parse_args()

target_dir = os.path.expanduser(args.target_dir)
if os.path.exists(target_dir):
raise ValueError('Target dir ['+target_dir+'] exists. Remove it first')
raise ValueError('Target dir [' + target_dir +
'] exists. Remove it first')

download_dir = os.path.expanduser(args.download_dir)
train_tar_fname = os.path.join(download_dir, _TRAIN_TAR)
Expand All @@ -247,8 +244,11 @@ def main():
build_rec = args.with_rec
if build_rec:
os.makedirs(os.path.join(target_dir, 'rec'))
extract_train(train_tar_fname, os.path.join(target_dir, 'train'), build_rec, args.num_thread)
extract_val(val_tar_fname, os.path.join(target_dir, 'val'), build_rec, args.num_thread)
extract_train(train_tar_fname, os.path.join(target_dir, 'train'),
build_rec, args.num_thread)
extract_val(val_tar_fname, os.path.join(target_dir, 'val'), build_rec,
args.num_thread)


if __name__ == '__main__':
main()
19 changes: 17 additions & 2 deletions extensions/spark/setup/djl_spark/util/files_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,21 @@ def download_and_extract(url, path):
:param url: The url of the tar file.
:param path: The path to the file to download to.
"""

def is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory

def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
raise Exception("Attempted Path Traversal in Tar File")

tar.extractall(path, members, numeric_owner=numeric_owner)

if not os.path.exists(path):
os.makedirs(path)
if not os.listdir(path):
Expand All @@ -78,9 +93,9 @@ def download_and_extract(url, path):
if url.startswith("s3://"):
s3_download(url, tmp_file)
with tarfile.open(name=tmp_file, mode="r:gz") as t:
t.extractall(path=path)
safe_extract(t, path=path)
elif url.startswith("http://") or url.startswith("https://"):
with urlopen(url) as response, open(tmp_file, 'wb') as f:
shutil.copyfileobj(response, f)
with tarfile.open(name=tmp_file, mode="r:gz") as t:
t.extractall(path=path)
safe_extract(t, path=path)

0 comments on commit 1fcca33

Please sign in to comment.