From 58941a52ffd496746b9f890a139b756e04f94fac Mon Sep 17 00:00:00 2001 From: Art Gourieff <85128026+Gourieff@users.noreply.github.com> Date: Mon, 19 Jun 2023 02:54:54 +0700 Subject: [PATCH] Update cimage.py --- scripts/cimage.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/scripts/cimage.py b/scripts/cimage.py index fcf904e..c5891e3 100644 --- a/scripts/cimage.py +++ b/scripts/cimage.py @@ -1,18 +1,10 @@ from typing import List, Union, Dict, Set, Tuple -# from diffusers.pipelines.stable_diffusion.safety_checker import ( -# StableDiffusionSafetyChecker, -# ) from transformers import AutoFeatureExtractor import torch from PIL import Image, ImageFilter import numpy as np -# safety_model_id: str = "CompVis/stable-diffusion-safety-checker" -# safety_feature_extractor: AutoFeatureExtractor = None -# safety_checker: StableDiffusionSafetyChecker = None - - def numpy_to_pil(images: np.ndarray) -> List[Image.Image]: if images.ndim == 3: images = images[None, ...] @@ -24,36 +16,15 @@ def numpy_to_pil(images: np.ndarray) -> List[Image.Image]: def check_image(x_image: np.ndarray) -> Tuple[np.ndarray, List[bool]]: global safety_feature_extractor, safety_checker - - # if safety_feature_extractor is None: - # safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) - # safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) - - # safety_checker_input = safety_feature_extractor( - # images=numpy_to_pil(x_image), return_tensors="pt" - # ) - # x_checked_image, hs = safety_checker( - # images=x_image, clip_input=safety_checker_input.pixel_values - # ) - - # return x_checked_image, hs return x_image, False def check_batch(x: torch.Tensor) -> torch.Tensor: x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy() - # x_checked_image, _ = check_image(x_samples_ddim_numpy) x_checked_image = x_samples_ddim_numpy x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) return x def convert_to_sd(img: Image) -> Image: - # _, hs = check_image(np.array(img)) - # if any(hs): - # img = ( - # img.resize((int(img.width * 0.1), int(img.height * 0.1))) - # .resize(img.size, Image.BOX) - # .filter(ImageFilter.BLUR) - # ) return img