forked from google/flax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
74 lines (64 loc) · 3.11 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from torchvision/utils.py.
Modification: torch.Tensor inputs are replaced with JAX arrays (jax.numpy.ndarray).
For details about this file, please see the original code:
https://github.com/pytorch/vision/blob/master/torchvision/utils.py
"""
import math

import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
def save_image(ndarray, fp, nrow=8, padding=2, pad_value=0.0, format=None):
    """Arrange a mini-batch of images into a grid and save it as an image file.

    The grid is scaled by 255, rounded to the nearest integer and clipped to
    ``[0, 255]`` before being written, so pixel values outside ``[0, 1]``
    saturate.

    Args:
        ndarray (array_like): 4D mini-batch of images of shape (B x H x W x C),
            given as a JAX array or a list of JAX arrays. Single-channel
            images (C == 1) are replicated to three channels.
        fp: A filename (string) or file object.
        nrow (int, optional): Maximum number of images displayed in each row
            of the grid. The grid has ``ceil(B / min(nrow, B))`` rows and
            ``min(nrow, B)`` columns. Default: ``8``.
        padding (int, optional): Amount of padding, in pixels, around each
            image. Default: ``2``.
        pad_value (float, optional): Value for the padded pixels, on the same
            scale as the image values. Default: ``0``.
        format (Optional): If omitted, the format to use is determined from
            the filename extension. If a file object was used instead of a
            filename, this parameter should always be used.

    Raises:
        TypeError: If ``ndarray`` is neither a JAX array nor a list of JAX
            arrays.
        ValueError: If the input is not 4-dimensional, the batch is empty, or
            ``nrow`` is smaller than 1.
    """
    is_array = isinstance(ndarray, jnp.ndarray)
    is_array_list = isinstance(ndarray, list) and all(
        isinstance(t, jnp.ndarray) for t in ndarray)
    if not (is_array or is_array_list):
        raise TypeError('array_like of tensors expected, got {}'.format(type(ndarray)))

    ndarray = jnp.asarray(ndarray)

    # Fail early with a clear message instead of an incidental IndexError or
    # ZeroDivisionError further down.
    if ndarray.ndim != 4:
        raise ValueError(
            'expected a 4D batch of shape (B, H, W, C), got shape {}'.format(ndarray.shape))
    if ndarray.shape[0] == 0:
        raise ValueError('cannot save an empty batch of images')
    if nrow < 1:
        raise ValueError('nrow must be at least 1, got {}'.format(nrow))

    if ndarray.shape[-1] == 1:  # single-channel images
        ndarray = jnp.concatenate((ndarray, ndarray, ndarray), -1)

    # Make the mini-batch of images into a grid.
    nmaps = ndarray.shape[0]
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    # Each cell is one image plus its leading padding; the trailing padding of
    # the last row/column is added to the grid size below.
    height = int(ndarray.shape[1] + padding)
    width = int(ndarray.shape[2] + padding)
    num_channels = ndarray.shape[3]
    grid = jnp.full(
        (height * ymaps + padding, width * xmaps + padding, num_channels),
        pad_value,
        dtype=jnp.float32)

    # Fill cells in row-major order. JAX arrays are immutable, so each
    # `.at[...].set(...)` returns an updated copy of the grid.
    for k in range(nmaps):
        y, x = divmod(k, xmaps)
        grid = grid.at[y * height + padding:(y + 1) * height,
                       x * width + padding:(x + 1) * width].set(ndarray[k])

    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer.
    pixels = jnp.clip(grid * 255.0 + 0.5, 0, 255).astype(jnp.uint8)
    # Convert explicitly to a host NumPy array before handing it to PIL.
    im = Image.fromarray(np.array(pixels))
    im.save(fp, format=format)