Pythonic API
The API module exposes a thin wrapper, `SwAeController`, over the original code to make it easier to use. It lets you process your first images in about 10 lines of code.
In its simplest form, you need to specify the model to load and the paths to the two images you want to mix. Calling the `compute()` function returns a tensor containing the output image. In the example below, we also use a utility function, `tensor_to_PIL`, to convert that tensor to a PIL image:
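```python
import matplotlib.pyplot as plt
from api import SwAeController
from api.util import tensor_to_PIL

# Placeholder paths: replace them with your own images.
IMG_PATH = "path/to/structure_image.jpg"  # image that provides the structure
IMG_PATH_2 = "path/to/style_image.jpg"    # image whose style is mixed in

SAE = SwAeController("mountain_pretrained")
SAE.set_structure(IMG_PATH)       # must be called before compute()
SAE.mix_style(IMG_PATH_2, 0.75)   # mix factor 0.75 (1.0 would keep none of the original style)
output_image = SAE.compute()      # tensor of shape (1, 3, h, w)
output_image = tensor_to_PIL(output_image[0])  # drop the batch dimension first
plt.imshow(output_image)
plt.show()
```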
SwAeController(name: str) -> None:
Initializes the model and other options.
Args:
- `name` (str): Name of the pre-trained weights to load.
SwAeController.set_size(size: int) -> None:
Sets the transform used to load images to the given `size`. The output image also has width `size`. `size` must be greater than 128 and a multiple of 4.
Args:
- `size` (int): Size of the output image.
Raises:
- `ValueError`: if `size` is not a valid integer.
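The size rule above can be expressed as a small check. This is a minimal sketch of the documented constraints, not the wrapper's actual code. The function name `validate_size` is hypothetical.

```python
def validate_size(size: object) -> int:
    """Hypothetical check mirroring the documented rule: an int > 128 that is a multiple of 4."""
    # bool is a subclass of int in Python, so exclude it explicitly.
    if not isinstance(size, int) or isinstance(size, bool):
        raise ValueError(f"size must be an integer, got {size!r}")
    if size <= 128:
        raise ValueError(f"size must be greater than 128, got {size}")
    if size % 4 != 0:
        raise ValueError(f"size must be a multiple of 4, got {size}")
    return size


print(validate_size(256))  # a valid size is returned unchanged
```

Rejecting `bool` separately avoids accepting `True` as the integer 1 by accident.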
SwAeController.set_structure(structure_path: str) -> None:
Sets the structure image and must be called before `compute()`. The image itself is not cached, but this call sets the noise input for the model.
Args:
- `structure_path` (str): Path to the structure image.
SwAeController.mix_style(style_path: str, alpha: float) -> None:
Mixes the style of the given image into the current structure image by the factor `alpha`. The encoded image is cached; the actual mixing happens when `compute()` is called.
Args:
- `style_path` (str): Path to the image whose style you want to mix in.
- `alpha` (float): Mix factor. 0 removes this image from the mix; 1 uses none of the original style.
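The meaning of `alpha` and the deferred mixing can be illustrated with plain Python lists standing in for encoded style codes. This sketch assumes a simple linear interpolation between the current style and the new one. That matches the documented endpoints (0 removes the image, 1 keeps none of the original style), but it is not confirmed to be what the model does internally. `interpolate`, `LazyStyleMixer`, and the idea of applying several cached mixes in call order are all hypothetical.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

Vector = List[float]  # stand-in for an encoded style code


def interpolate(current: Vector, style: Vector, alpha: float) -> Vector:
    """Assumed linear blend: alpha=0 keeps `current`, alpha=1 keeps only `style`."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    if len(current) != len(style):
        raise ValueError("style codes must have the same length")
    return [(1.0 - alpha) * c + alpha * s for c, s in zip(current, style)]


@dataclass
class LazyStyleMixer:
    """Hypothetical stand-in: caches (code, alpha) pairs and mixes only in compute()."""
    base_style: Vector
    pending: List[Tuple[Vector, float]] = field(default_factory=list)

    def mix_style(self, style_code: Vector, alpha: float) -> None:
        # Only cache the already-encoded style; no mixing happens yet.
        self.pending.append((style_code, alpha))

    def compute(self) -> Vector:
        # Apply every cached mix in the order it was requested.
        result = list(self.base_style)
        for style_code, alpha in self.pending:
            result = interpolate(result, style_code, alpha)
        return result


mixer = LazyStyleMixer([0.0, 2.0])
mixer.mix_style([4.0, 6.0], 0.75)
print(mixer.compute())
```

Deferring the blend to `compute()` means several `mix_style()` calls can be queued cheaply before any output is produced.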
SwAeController.compute() -> torch.Tensor:
Computes the result of the operations queued by `mix_style()` and returns the output image.
Returns:
- torch.Tensor: Output tensor of shape (1, 3, h, w), where h and w are the image height and width.
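`compute()` returns a batch of one image in channels-first layout. To display it with an image library, you need to drop the batch dimension and move the channels last. The sketch below does this with NumPy. It assumes the model outputs values in [-1, 1], which is common for GAN-style generators but is not stated here. `output_to_hwc_uint8` is a hypothetical helper, not the `tensor_to_PIL` implementation.

```python
import numpy as np


def output_to_hwc_uint8(output: np.ndarray, value_range=(-1.0, 1.0)) -> np.ndarray:
    """Hypothetical: convert a (1, 3, h, w) float array to an (h, w, 3) uint8 image array."""
    if output.ndim != 4 or output.shape[0] != 1 or output.shape[1] != 3:
        raise ValueError(f"expected shape (1, 3, h, w), got {output.shape}")
    low, high = value_range                  # assumed output range of the model
    chw = output[0]                          # drop the batch dimension -> (3, h, w)
    hwc = np.transpose(chw, (1, 2, 0))       # channels last -> (h, w, 3)
    scaled = (hwc - low) / (high - low)      # map value_range onto [0, 1]
    return np.clip(np.rint(scaled * 255.0), 0, 255).astype(np.uint8)


print(output_to_hwc_uint8(np.zeros((1, 3, 2, 2))).shape)
```

With the real output, convert the tensor first, for example `output_to_hwc_uint8(SAE.compute().detach().cpu().numpy())`. You can then pass the result to `PIL.Image.fromarray`.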
Other functions are not documented here because they are intended for internal use; see their docstrings for details.
All functions in this class are wrapped by the `timing` wrapper, which helps when debugging. It prints the time each function takes to run, the arguments passed to it, any other function calls it makes, and any errors it encounters.
By default, debugging is turned off, but you can enable it as follows:
```python
from api.util import UtitlState
UtitlState.debug(True)
```
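A wrapper with the behaviour described above can be sketched as a decorator gated by a global flag. This is an illustration only. `DebugState` and `timed` are hypothetical names that mirror the roles of `UtitlState` and the `timing` wrapper. Nested calls are shown by indenting the output of any wrapped call made while another wrapped call is running.

```python
import functools
import time


class DebugState:
    """Hypothetical global switch, similar in role to UtitlState.debug()."""
    enabled = False
    depth = 0  # current nesting level of wrapped calls

    @classmethod
    def debug(cls, flag: bool) -> None:
        cls.enabled = flag


def timed(func):
    """Print arguments, nested wrapped calls, errors and run time when debugging is on."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not DebugState.enabled:
            return func(*args, **kwargs)  # no overhead beyond the flag check
        indent = "  " * DebugState.depth
        print(f"{indent}-> {func.__name__} args={args!r} kwargs={kwargs!r}")
        DebugState.depth += 1
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            print(f"{indent}!! {func.__name__} raised {exc!r}")
            raise  # report the error, then let it propagate unchanged
        finally:
            DebugState.depth -= 1
            elapsed = time.perf_counter() - start
            print(f"{indent}<- {func.__name__} took {elapsed:.4f}s")
    return wrapper
```

Keeping the flag check as the first step means the wrapper does almost nothing when debugging is off.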