
Pythonic API

Abhishek Saxena edited this page Sep 26, 2021 · 3 revisions

Overview

The api module exposes SwAeController, a thin wrapper over the original code that makes it easier to use. It lets you process your first images in about 10 lines of code.

In its simplest form, you need to specify the model to load and the paths to the two images you want to mix: a structure image and a style image. Calling compute() returns a tensor containing the output image. The example below also uses the utility function tensor_to_PIL to convert that tensor to a PIL image.

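import matplotlib.pyplot as plt
from api import SwAeController
from api.util import tensor_to_PIL

# Placeholder paths: replace these with your own images
IMG_PATH = "path/to/structure.jpg"   # image that provides the structure
IMG_PATH_2 = "path/to/style.jpg"     # image whose style is mixed in

SAE = SwAeController("mountain_pretrained")
SAE.set_structure(IMG_PATH)          # must be called before compute()
SAE.mix_style(IMG_PATH_2, 0.75)      # mix in 75% of the second image's style

output_image = SAE.compute()         # tensor of shape (1, 3, h, w)
output_image = tensor_to_PIL(output_image[0])
plt.imshow(output_image)
plt.show()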

API reference

SwAeController(name: str) -> None:

Initializes the model and its other options.

Args:

  • name (str): Name of the pre-trained weights to load.

SwAeController.set_size(size: int) -> None:

Sets the transform used to load images so that they are resized to size; the output image also has width size. size must be greater than 128 and a multiple of 4.

Args:

  • size (int): Size of the output image

Raises:

  • ValueError if size is not a valid integer, i.e. it is not greater than 128 or not a multiple of 4.

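The size constraints above can be expressed as a small check. This is a minimal sketch that assumes a hypothetical helper named validate_size; it is not the library's actual implementation, and the real error messages may differ.

```python
def validate_size(size: int) -> int:
    """Hypothetical check mirroring the documented set_size() constraints."""
    # bool is a subclass of int, so reject it explicitly
    if not isinstance(size, int) or isinstance(size, bool):
        raise ValueError(f"size must be an integer, got {size!r}")
    if size <= 128:
        raise ValueError(f"size must be greater than 128, got {size}")
    if size % 4 != 0:
        raise ValueError(f"size must be a multiple of 4, got {size}")
    return size


print(validate_size(256))  # 256
# 128 is not greater than 128, and 130 is not a multiple of 4
for bad in (128, 130):
    try:
        validate_size(bad)
    except ValueError as err:
        print(err)
```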
SwAeController.set_structure(structure_path: str) -> None:

Sets the structure image. Must be called before compute(). The image itself is not cached; instead, this call sets the noise input for the model.

Args:

  • structure_path (str): Path to the structure image

SwAeController.mix_style(style_path: str, alpha: float) -> None:

Mixes the style of the given image into the current structure image, using alpha as the mix factor. The encoded image is cached; the actual mixing happens when compute() is called.

Args:

  • style_path (str): Path to the image whose style you want to mix
  • alpha (float): Mix factor. 0 removes this image from the mix; 1 means none of the original styles are used.
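One way to read these alpha semantics is as linear interpolation between style codes: alpha = 0 keeps the current style unchanged, and alpha = 1 replaces it entirely with the new image's style. The sketch below assumes styles are plain lists of floats and uses a hypothetical mix helper; this page does not state how SwAeController actually combines the cached encodings.

```python
from typing import List


def mix(current: List[float], new: List[float], alpha: float) -> List[float]:
    """Hypothetical linear blend of two style vectors by factor alpha."""
    if len(current) != len(new):
        raise ValueError("style vectors must have the same length")
    # alpha = 0 -> current style only; alpha = 1 -> new style only
    return [(1.0 - alpha) * c + alpha * n for c, n in zip(current, new)]


structure_style = [0.0, 2.0, 4.0]
other_style = [4.0, 2.0, 0.0]
print(mix(structure_style, other_style, 0.75))  # [3.0, 2.0, 1.0]
```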

SwAeController.compute() -> torch.Tensor:

Performs the mixing queued by mix_style() on the current structure image and returns the output image.

Returns:

  • torch.Tensor: Output tensor of shape (1, 3, h, w), where h and w are the height and width of the image.
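The call order described above (set the structure, queue one or more styles with mix_style(), then call compute()) follows a deferred-evaluation pattern. The stand-in class below is a hypothetical sketch of that contract only: it records calls and does no image processing, and its names (LazyMixerSketch, encode) are invented for illustration rather than taken from the api module.

```python
from typing import List, Optional, Tuple


class LazyMixerSketch:
    """Hypothetical stand-in for the set_structure/mix_style/compute contract."""

    def __init__(self) -> None:
        self.structure: Optional[str] = None
        # Cached (encoded_style, alpha) pairs, applied only in compute()
        self.pending: List[Tuple[str, float]] = []

    def encode(self, path: str) -> str:
        # Placeholder for a real encoder; the "encoding" is just a label
        return f"enc({path})"

    def set_structure(self, structure_path: str) -> None:
        # This sketch stores only the path, not the image
        self.structure = structure_path

    def mix_style(self, style_path: str, alpha: float) -> None:
        # Encode now and cache the result; no mixing happens yet
        self.pending.append((self.encode(style_path), alpha))

    def compute(self) -> str:
        if self.structure is None:
            raise RuntimeError("set_structure() must be called before compute()")
        # The deferred mixing happens here, in the order styles were queued
        steps = ", ".join(f"{enc}@{alpha}" for enc, alpha in self.pending)
        return f"{self.structure} <- [{steps}]"


mixer = LazyMixerSketch()
mixer.set_structure("mountain.jpg")
mixer.mix_style("sunset.jpg", 0.75)
print(mixer.compute())  # mountain.jpg <- [enc(sunset.jpg)@0.75]
```

Deferring the work to compute() means several styles can be queued cheaply and combined in one pass.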

Other functions are not documented here because they are intended for internal use. You can learn more about them from their docstrings.

Misc

All methods in this class are wrapped by a timing wrapper that helps with debugging. It prints the time each function takes to run, the arguments passed to it, any other function calls it makes, and any errors encountered.

Debugging is turned off by default, but you can enable it with the following code:

from api.util import UtitlState
UtitlState.debug(True)
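A wrapper of the kind described in this section can be sketched as a decorator that checks a global debug flag. This is a minimal sketch, not the code behind UtitlState: the names DebugFlag and timed are hypothetical, and it only reports each call's arguments, run time, and any exception raised.

```python
import functools
import time
from typing import Any, Callable


class DebugFlag:
    """Hypothetical global switch, similar in spirit to UtitlState.debug()."""

    enabled = False

    @classmethod
    def debug(cls, value: bool) -> None:
        cls.enabled = value


def timed(func: Callable[..., Any]) -> Callable[..., Any]:
    """Report arguments, run time and errors of func while debugging is on."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if not DebugFlag.enabled:
            # Debugging off: call through with no extra output
            return func(*args, **kwargs)
        print(f"-> {func.__name__} args={args!r} kwargs={kwargs!r}")
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        except Exception as err:
            print(f"!! {func.__name__} raised {err!r}")
            raise  # re-raise so callers still see the error
        finally:
            elapsed = time.perf_counter() - start
            print(f"<- {func.__name__} took {elapsed:.6f}s")

    return wrapper


@timed
def scale(x: float, factor: float = 2.0) -> float:
    return x * factor


DebugFlag.debug(True)
scale(3.0, factor=0.5)  # prints a call line, then a timing line
```

Because each decorated function prints its own lines, a wrapped function that calls another wrapped function produces nested output, which is one simple way to surface the other calls it makes.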