sd-webui-reactor/scripts/reactor_inferencers/bisenet_mask_generator.py
Gourieff 70479479b3 HotFIX: Compatibility with "sd-face-editor"
Issue #211
Because Reactor's Mask Correction is based on sd-face-editor's masking script, Reactor's active folders need different names to stay compatible with sd-face-editor.

+VersionUP (0.5.1 beta3)
2023-11-24 23:02:00 +07:00

87 lines
3.1 KiB
Python

from typing import List, Tuple
import cv2
import modules.shared as shared
import numpy as np
import torch
from facexlib.parsing import init_parsing_model
from facexlib.utils.misc import img2tensor
from torchvision.transforms.functional import normalize
from scripts.reactor_inferencers.mask_generator import MaskGenerator
class BiSeNetMaskGenerator(MaskGenerator):
    """Mask generator backed by facexlib's BiSeNet face-parsing model.

    The parsing model assigns a class id to every pixel. Pixels whose id
    belongs to one of the requested ``affected_areas`` become white (255)
    in a three-channel uint8 mask.
    """

    # Side length of the square image fed to the parsing model.
    _INPUT_SIZE = 512

    # Parsing-model class ids selected for each supported affected area.
    # Ids 0, 15 and 16 are never selected.
    # NOTE(review): presumably these are the 19-class CelebAMask-HQ ids used
    # by facexlib's BiSeNet (1-13 facial parts, 14 neck, 17 hair, 18 hat);
    # confirm.
    _AREA_LABELS = {
        "Face": tuple(range(1, 14)),
        "Neck": (14,),
        "Hair": (17,),
        "Hat": (18,),
    }

    def __init__(self) -> None:
        # Load the parsing model once, on the device configured by the WebUI.
        self.mask_model = init_parsing_model(device=shared.device)

    def name(self) -> str:
        """Return the display name of this mask generator."""
        return "BiSeNet"

    def generate_mask(
        self,
        face_image: np.ndarray,
        face_area_on_image: Tuple[int, int, int, int],
        affected_areas: List[str],
        mask_size: int,
        use_minimal_area: bool,
        fallback_ratio: float = 0.25,
        **kwargs,
    ) -> np.ndarray:
        """Build a mask covering the requested facial areas of ``face_image``.

        Processing steps:
        1. Reverse the channel order.
        2. Optionally restrict the image to the face region.
        3. Resize to a 512x512 square.
        4. Segment the image with the parsing model.
        5. Convert the selected labels into a mask.
        6. Optionally dilate the mask, then resize it back to the parsed
           image's size.

        Args:
            face_image: H x W x 3 image array. Its channel order is reversed
                before inference. NOTE(review): presumably RGB in, converted to
                the BGR layout facexlib's ``img2tensor`` expects; confirm with
                callers.
            face_area_on_image: Face region forwarded to
                ``MaskGenerator.mask_non_face_areas`` when ``use_minimal_area``
                is set.
            affected_areas: Area names to include. Recognised values are
                "Face", "Neck", "Hair" and "Hat".
            mask_size: Number of 5x5 dilation iterations applied to the mask.
                A value of 0 or less disables dilation.
            use_minimal_area: If true, the image is first passed through
                ``MaskGenerator.mask_non_face_areas`` (presumably blanking
                pixels outside ``face_area_on_image``).
            fallback_ratio: Unused. Kept for interface compatibility, since
                the coverage-based fallback generator is not wired in.
            **kwargs: Ignored.

        Returns:
            uint8 array of shape (h, w, 3), where h and w are the height and
            width of the image that was parsed. Selected pixels are 255.
            Because the mask is resized back with cv2.resize's default
            (bilinear) interpolation, edge pixels may hold intermediate
            values.
        """
        # Reverse channels into a private, C-contiguous copy. The caller's
        # array is never modified, and OpenCV never sees a negative-stride view.
        face_image = face_image[:, :, ::-1].copy()
        if use_minimal_area:
            face_image = MaskGenerator.mask_non_face_areas(face_image, face_area_on_image)

        h, w, _ = face_image.shape
        size = self._INPUT_SIZE
        needs_resize = w != size or h != size
        if needs_resize:
            # Use the fixed input size directly. Recomputing it as
            # int(w * (512 / w)) can round down to 511 and then snap to 504.
            face_image = cv2.resize(face_image, dsize=(size, size))

        face_tensor = img2tensor(face_image.astype("float32") / 255.0, float32=True)
        # Map [0, 1] pixel values to [-1, 1] per channel.
        normalize(face_tensor, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        face_tensor = torch.unsqueeze(face_tensor, 0).to(shared.device)

        with torch.no_grad():
            # The model returns several outputs; only the first is used.
            logits = self.mask_model(face_tensor)[0]
        # Per-pixel class id = argmax over the class axis. astype() already copies.
        labels = logits.squeeze(0).cpu().numpy().argmax(0).astype(np.uint8)

        mask = self.__to_mask(labels, affected_areas)
        if mask_size > 0:
            mask = cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=mask_size)
        if needs_resize:
            mask = cv2.resize(mask, dsize=(w, h))
        return mask

    def __to_mask(self, face: np.ndarray, affected_areas: List[str]) -> np.ndarray:
        """Paint white every pixel whose class id belongs to a selected area.

        Args:
            face: 2-D array of per-pixel class ids from the parsing model.
            affected_areas: Area names looked up in ``_AREA_LABELS``.
                Unknown names are ignored.

        Returns:
            uint8 array of shape (face.shape[0], face.shape[1], 3). Selected
            pixels are 255 in all channels; all other pixels are 0.
        """
        selected_labels = [
            label
            for area, area_labels in self._AREA_LABELS.items()
            if area in affected_areas
            for label in area_labels
        ]
        mask = np.zeros((face.shape[0], face.shape[1], 3), dtype=np.uint8)
        # A single vectorised membership test replaces one np.where pass per
        # class id. The 2-D boolean index selects all three channels at once.
        mask[np.isin(face, selected_labels)] = 255
        return mask