NSFW filter OFF

This commit is contained in:
Art Gourieff 2023-06-19 02:52:22 +07:00 committed by GitHub
parent ef79abd735
commit eda6e9f624
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,57 +1,59 @@
from typing import List, Union, Dict, Set, Tuple from typing import List, Union, Dict, Set, Tuple
from diffusers.pipelines.stable_diffusion.safety_checker import ( # from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker, # StableDiffusionSafetyChecker,
) # )
from transformers import AutoFeatureExtractor from transformers import AutoFeatureExtractor
import torch import torch
from PIL import Image, ImageFilter from PIL import Image, ImageFilter
import numpy as np import numpy as np
safety_model_id: str = "CompVis/stable-diffusion-safety-checker" # safety_model_id: str = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor: AutoFeatureExtractor = None # safety_feature_extractor: AutoFeatureExtractor = None
safety_checker: StableDiffusionSafetyChecker = None # safety_checker: StableDiffusionSafetyChecker = None
def numpy_to_pil(images: np.ndarray) -> List[Image.Image]: def numpy_to_pil(images: np.ndarray) -> List[Image.Image]:
if images.ndim == 3: if images.ndim == 3:
images = images[None, ...] images = images[None, ...]
images = (images * 255).round().astype("uint8") images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images] pil_images = [Image.fromarray(image) for image in images]
return pil_images return pil_images
def check_image(x_image: np.ndarray) -> Tuple[np.ndarray, List[bool]]: def check_image(x_image: np.ndarray) -> Tuple[np.ndarray, List[bool]]:
global safety_feature_extractor, safety_checker global safety_feature_extractor, safety_checker
if safety_feature_extractor is None: # if safety_feature_extractor is None:
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) # safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) # safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
safety_checker_input = safety_feature_extractor( # safety_checker_input = safety_feature_extractor(
images=numpy_to_pil(x_image), return_tensors="pt" # images=numpy_to_pil(x_image), return_tensors="pt"
) # )
x_checked_image, hs = safety_checker( # x_checked_image, hs = safety_checker(
images=x_image, clip_input=safety_checker_input.pixel_values # images=x_image, clip_input=safety_checker_input.pixel_values
) # )
return x_checked_image, hs # return x_checked_image, hs
return x_image, False
def check_batch(x: torch.Tensor) -> torch.Tensor:
x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy() def check_batch(x: torch.Tensor) -> torch.Tensor:
x_checked_image, _ = check_image(x_samples_ddim_numpy) x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) # x_checked_image, _ = check_image(x_samples_ddim_numpy)
return x x_checked_image = x_samples_ddim_numpy
x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
return x
def convert_to_sd(img: Image) -> Image:
_, hs = check_image(np.array(img))
if any(hs): def convert_to_sd(img: Image) -> Image:
img = ( # _, hs = check_image(np.array(img))
img.resize((int(img.width * 0.1), int(img.height * 0.1))) # if any(hs):
.resize(img.size, Image.BOX) # img = (
.filter(ImageFilter.BLUR) # img.resize((int(img.width * 0.1), int(img.height * 0.1)))
) # .resize(img.size, Image.BOX)
return img # .filter(ImageFilter.BLUR)
# )
return img