Skip to content

Commit

Permalink
Allow customization of number of preview images to display
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Jun 16, 2019
1 parent d5b64ef commit 5917cc4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
8 changes: 6 additions & 2 deletions plugins/train/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def set_globals(self):
"facial parts"
"\n\t dfl_full: An improved face hull mask using a facehull of 3 "
"facial parts"
"\n\t extended: Based on components mask. Extends the eyebrow points to "
"further up the forehead. May perform badly on difficult angles."
"\n\t extended: Based on components mask. Extends the eyebrow points "
"to further up the forehead. May perform badly on difficult angles."
"\n\t facehull: Face cutout based on landmarks")
self.add_item(
section=section, title="icnr_init", datatype=bool, default=False,
Expand All @@ -74,6 +74,10 @@ def set_globals(self):
info="If using a mask, This penalizes the loss for the masked area, to give higher "
"priority to the face area. \nShould increase overall quality and speed up "
"training. This should probably be left at True")
self.add_item(
section=section, title="preview_images", datatype=int, default=14, min_max=(2, 16),
rounding=2, fixed=False,
info="Number of sample faces to display for each side in the preview when training.")
logger.debug("Set global config")

def load_module(self, filename, module_path, plugin_type):
Expand Down
5 changes: 4 additions & 1 deletion plugins/train/model/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,14 @@ def set_training_data(self):
super() this method for defaults otherwise be sure to add """
logger.debug("Setting training data")
# Force number of preview images to between 2 and 16
preview_images = self.config.get("preview_images", 14)
preview_images = min(max(preview_images, 2), 16)
self.training_opts["preview_images"] = preview_images
self.training_opts["training_size"] = self.state.training_size
self.training_opts["no_logs"] = self.state.current_session["no_logs"]
self.training_opts["mask_type"] = self.config.get("mask_type", None)
self.training_opts["coverage_ratio"] = self.calculate_coverage_ratio()
self.training_opts["preview_images"] = 14
logger.debug("Set training data: %s", self.training_opts)

def calculate_coverage_ratio(self):
Expand Down

0 comments on commit 5917cc4

Please sign in to comment.