diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
index fa694276a357..da886ffe6f15 100644
--- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -90,14 +90,14 @@ def forward(self, clip_input, images):
                 else:
                     images[idx] = np.zeros(images[idx].shape)  # black image
 
-        #if any(has_nsfw_concepts):
-        #    logger.warning(
-        #        "Potential NSFW content was detected in one or more images. A black image will be returned instead."
-        #        " Try again with a different prompt and/or seed."
-        #    )
+        if any(has_nsfw_concepts):
+            logger.warning(
+                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+                " Try again with a different prompt and/or seed."
+            )
 
-        #return images, has_nsfw_concepts
-        return images, False
+        return images, has_nsfw_concepts
+        #return images, False
 
     @torch.no_grad()
     def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
@@ -118,10 +118,10 @@ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
         special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
 
         concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
-        # concept_scores = concept_scores.round(decimals=3)
-        #has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+        concept_scores = concept_scores.round(decimals=3)
+        has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
 
         #images[has_nsfw_concepts] = 0.0  # black image
 
-        #return images, has_nsfw_concepts
-        return images, False
+        return images, has_nsfw_concepts
+        #return images, False
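For context (not part of the patch): once the safety checker returns `has_nsfw_concepts` again instead of `False`, the pipeline surfaces that flag to callers as `nsfw_content_detected`. A minimal sketch of consuming it is below; the model id, device, and prompt are illustrative assumptions, not taken from this diff.

```python
# Minimal sketch, assuming the standard StableDiffusionPipeline call pattern.
# The model id, device, and prompt are illustrative, not part of the patch.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

result = pipe("an astronaut riding a horse")
for image, flagged in zip(result.images, result.nsfw_content_detected):
    if flagged:
        # The safety checker replaced this output with a black image.
        print("NSFW content detected; image was blacked out.")
    else:
        image.save("astronaut.png")
```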