sd-webui-reactor/scripts/reactor_swapper.py
2023-11-19 19:09:20 -05:00

564 lines
21 KiB
Python

import copy
import os
from dataclasses import dataclass
from typing import List, Union
import cv2
import numpy as np
from PIL import Image
import insightface
from scripts.reactor_helpers import get_image_md5hash, get_Device
from modules.face_restoration import FaceRestoration
try: # A1111
from modules import codeformer_model
except: # SD.Next
from modules.postprocess import codeformer_model
from modules.upscaler import UpscalerData
from modules.shared import state
from scripts.reactor_logger import logger
try:
from modules.paths_internal import models_path
except:
try:
from modules.paths import models_path
except:
model_path = os.path.abspath("models")
import warnings
PILImage = Image.Image
CV2ImgU8 = np.ndarray[int, np.dtype[uint8]]
Face = IFace
BoxCoords = Tuple[int, int, int, int]
class Gender(Enum):
AUTO = -1
FEMALE = 0
MALE = 1
np.warnings = warnings
np.warnings.filterwarnings('ignore')
DEVICE = get_Device()
if DEVICE == "CUDA":
PROVIDERS = ["CUDAExecutionProvider"]
else:
PROVIDERS = ["CPUExecutionProvider"]
@dataclass
class EnhancementOptions:
do_restore_first: bool = True
scale: int = 1
upscaler: UpscalerData = None
upscale_visibility: float = 0.5
face_restorer: FaceRestoration = None
restorer_visibility: float = 0.5
codeformer_weight: float = 0.5
MESSAGED_STOPPED = False
MESSAGED_SKIPPED = False
def reset_messaged():
global MESSAGED_STOPPED, MESSAGED_SKIPPED
if not state.interrupted:
MESSAGED_STOPPED = False
if not state.skipped:
MESSAGED_SKIPPED = False
def check_process_halt(msgforced: bool = False):
global MESSAGED_STOPPED, MESSAGED_SKIPPED
if state.interrupted:
if not MESSAGED_STOPPED or msgforced:
logger.status("Stopped by User")
MESSAGED_STOPPED = True
return True
if state.skipped:
if not MESSAGED_SKIPPED or msgforced:
logger.status("Skipped by User")
MESSAGED_SKIPPED = True
return True
return False
FS_MODEL = None
CURRENT_FS_MODEL_PATH = None
ANALYSIS_MODEL = None
SOURCE_FACES = None
SOURCE_IMAGE_HASH = None
TARGET_FACES = None
TARGET_IMAGE_HASH = None
def getAnalysisModel():
global ANALYSIS_MODEL
if ANALYSIS_MODEL is None:
ANALYSIS_MODEL = insightface.app.FaceAnalysis(
name="buffalo_l", providers=PROVIDERS, root=os.path.join(models_path, "insightface") # note: allowed_modules=['detection', 'genderage']
)
return ANALYSIS_MODEL
def getFaceSwapModel(model_path: str):
global FS_MODEL
global CURRENT_FS_MODEL_PATH
if CURRENT_FS_MODEL_PATH is None or CURRENT_FS_MODEL_PATH != model_path:
CURRENT_FS_MODEL_PATH = model_path
FS_MODEL = insightface.model_zoo.get_model(model_path, providers=PROVIDERS)
return FS_MODEL
def restore_face(image: Image, enhancement_options: EnhancementOptions):
result_image = image
if check_process_halt(msgforced=True):
return result_image
if enhancement_options.face_restorer is not None:
original_image = result_image.copy()
logger.status("Restoring the face with %s", enhancement_options.face_restorer.name())
numpy_image = np.array(result_image)
if enhancement_options.face_restorer.name() == "CodeFormer":
numpy_image = codeformer_model.codeformer.restore(
numpy_image, w=enhancement_options.codeformer_weight
)
else:
numpy_image = enhancement_options.face_restorer.restore(numpy_image)
restored_image = Image.fromarray(numpy_image)
result_image = Image.blend(
original_image, restored_image, enhancement_options.restorer_visibility
)
return result_image
def upscale_image(image: Image, enhancement_options: EnhancementOptions):
result_image = image
if check_process_halt(msgforced=True):
return result_image
if enhancement_options.upscaler is not None and enhancement_options.upscaler.name != "None":
original_image = result_image.copy()
logger.status(
"Upscaling with %s scale = %s",
enhancement_options.upscaler.name,
enhancement_options.scale,
)
result_image = enhancement_options.upscaler.scaler.upscale(
original_image, enhancement_options.scale, enhancement_options.upscaler.data_path
)
if enhancement_options.scale == 1:
result_image = Image.blend(
original_image, result_image, enhancement_options.upscale_visibility
)
return result_image
def enhance_image(image: Image, enhancement_options: EnhancementOptions):
result_image = image
if check_process_halt(msgforced=True):
return result_image
if enhancement_options.do_restore_first:
result_image = restore_face(result_image, enhancement_options)
result_image = upscale_image(result_image, enhancement_options)
else:
result_image = upscale_image(result_image, enhancement_options)
result_image = restore_face(result_image, enhancement_options)
return result_image
def get_gender(face, face_index):
gender = [
x.sex
for x in face
]
gender.reverse()
try:
face_gender = gender[face_index]
except:
logger.error("Gender Detection: No face with index = %s was found", face_index)
return "None"
return face_gender
def get_face_gender(
face,
face_index,
gender_condition,
operated: str,
gender_detected,
):
face_gender = gender_detected
if face_gender == "None":
return None, 0
logger.status("%s Face %s: Detected Gender -%s-", operated, face_index, face_gender)
if (gender_condition == 1 and face_gender == "F") or (gender_condition == 2 and face_gender == "M"):
logger.status("OK - Detected Gender matches Condition")
try:
return sorted(face, key=lambda x: x.bbox[0])[face_index], 0
except IndexError:
return None, 0
else:
logger.status("WRONG - Detected Gender doesn't match Condition")
return sorted(face, key=lambda x: x.bbox[0])[face_index], 1
def get_face_age(face, face_index):
age = [
x.age
for x in face
]
age.reverse()
try:
face_age = age[face_index]
except:
logger.error("Age Detection: No face with index = %s was found", face_index)
return "None"
return face_age
def half_det_size(det_size):
logger.status("Trying to halve 'det_size' parameter")
return (det_size[0] // 2, det_size[1] // 2)
def analyze_faces(img_data: np.ndarray, det_size=(640, 640)):
logger.info("Applied Execution Provider: %s", PROVIDERS[0])
face_analyser = copy.deepcopy(getAnalysisModel())
face_analyser.prepare(ctx_id=0, det_size=det_size)
return face_analyser.get(img_data)
def get_face_single(img_data: np.ndarray, face, face_index=0, det_size=(640, 640), gender_source=0, gender_target=0):
buffalo_path = os.path.join(models_path, "insightface/models/buffalo_l.zip")
if os.path.exists(buffalo_path):
os.remove(buffalo_path)
face_age = "None"
try:
face_age = get_face_age(face, face_index)
except:
logger.error("Cannot detect any Age for Face index = %s", face_index)
face_gender = "None"
try:
face_gender = get_gender(face, face_index)
gender_detected = face_gender
face_gender = "Female" if face_gender == "F" else ("Male" if face_gender == "M" else "None")
except:
logger.error("Cannot detect any Gender for Face index = %s", face_index)
if gender_source != 0:
if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
det_size_half = half_det_size(det_size)
return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target)
faces, wrong_gender = get_face_gender(face,face_index,gender_source,"Source",gender_detected)
return faces, wrong_gender, face_age, face_gender
if gender_target != 0:
if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
det_size_half = half_det_size(det_size)
return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target)
faces, wrong_gender = get_face_gender(face,face_index,gender_target,"Target",gender_detected)
return faces, wrong_gender, face_age, face_gender
if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
det_size_half = half_det_size(det_size)
return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target)
try:
return sorted(face, key=lambda x: x.bbox[0])[face_index], 0, face_age, face_gender
except IndexError:
return None, 0, face_age, face_gender
def swap_face(
source_img: Image.Image,
target_img: Image.Image,
model: Union[str, None] = None,
source_faces_index: List[int] = [0],
faces_index: List[int] = [0],
enhancement_options: Union[EnhancementOptions, None] = None,
gender_source: int = 0,
gender_target: int = 0,
source_hash_check: bool = True,
target_hash_check: bool = False,
device: str = "CPU",
):
global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH, PROVIDERS
result_image = target_img
PROVIDERS = ["CUDAExecutionProvider"] if device == "CUDA" else ["CPUExecutionProvider"]
if check_process_halt():
return result_image, [], 0
if model is not None:
if isinstance(source_img, str): # source_img is a base64 string
import base64, io
if 'base64,' in source_img: # check if the base64 string has a data URL scheme
# split the base64 string to get the actual base64 encoded image data
base64_data = source_img.split('base64,')[-1]
# decode base64 string to bytes
img_bytes = base64.b64decode(base64_data)
else:
# if no data URL scheme, just decode
img_bytes = base64.b64decode(source_img)
source_img = Image.open(io.BytesIO(img_bytes))
source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)
output: List = []
output_info: str = ""
swapped = 0
if source_hash_check:
source_image_md5hash = get_image_md5hash(source_img)
if SOURCE_IMAGE_HASH is None:
SOURCE_IMAGE_HASH = source_image_md5hash
source_image_same = False
else:
source_image_same = True if SOURCE_IMAGE_HASH == source_image_md5hash else False
if not source_image_same:
SOURCE_IMAGE_HASH = source_image_md5hash
logger.info("Source Image MD5 Hash = %s", SOURCE_IMAGE_HASH)
logger.info("Source Image the Same? %s", source_image_same)
if SOURCE_FACES is None or not source_image_same:
logger.status("Analyzing Source Image...")
source_faces = analyze_faces(source_img)
SOURCE_FACES = source_faces
elif source_image_same:
logger.status("Using Ready Source Face(s) Model...")
source_faces = SOURCE_FACES
else:
logger.status("Analyzing Source Image...")
source_faces = analyze_faces(source_img)
if source_faces is not None:
if target_hash_check:
target_image_md5hash = get_image_md5hash(target_img)
if TARGET_IMAGE_HASH is None:
TARGET_IMAGE_HASH = target_image_md5hash
target_image_same = False
else:
target_image_same = True if TARGET_IMAGE_HASH == target_image_md5hash else False
if not target_image_same:
TARGET_IMAGE_HASH = target_image_md5hash
logger.info("Target Image MD5 Hash = %s", TARGET_IMAGE_HASH)
logger.info("Target Image the Same? %s", target_image_same)
if TARGET_FACES is None or not target_image_same:
logger.status("Analyzing Target Image...")
target_faces = analyze_faces(target_img)
TARGET_FACES = target_faces
elif target_image_same:
logger.status("Using Ready Target Face(s) Model...")
target_faces = TARGET_FACES
else:
logger.status("Analyzing Target Image...")
target_faces = analyze_faces(target_img)
logger.status("Detecting Source Face, Index = %s", source_faces_index[0])
source_face, wrong_gender, source_age, source_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[0], gender_source=gender_source)
if source_age != "None" or source_gender != "None":
logger.status("Detected: -%s- y.o. %s", source_age, source_gender)
output_info = f"SourceFaceIndex={source_faces_index[0]};Age={source_age};Gender={source_gender}\n"
output.append(output_info)
if len(source_faces_index) != 0 and len(source_faces_index) != 1 and len(source_faces_index) != len(faces_index):
logger.status("Source Faces must have no entries (default=0), one entry, or same number of entries as target faces.")
elif source_face is not None:
result = target_img
face_swapper = getFaceSwapModel(model)
source_face_idx = 0
for face_num in faces_index:
if check_process_halt():
return result_image, [], 0
if len(source_faces_index) > 1 and source_face_idx > 0:
logger.status("Detecting Source Face, Index = %s", source_faces_index[source_face_idx])
source_face, wrong_gender, source_age, source_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[source_face_idx], gender_source=gender_source)
if source_age != "None" or source_gender != "None":
logger.status("Detected: -%s- y.o. %s", source_age, source_gender)
output_info = f"SourceFaceIndex={source_faces_index[source_face_idx]};Age={source_age};Gender={source_gender}\n"
output.append(output_info)
source_face_idx += 1
if source_face is not None and wrong_gender == 0:
logger.status("Detecting Target Face, Index = %s", face_num)
target_face, wrong_gender, target_age, target_gender = get_face_single(target_img, target_faces, face_index=face_num, gender_target=gender_target)
if target_age != "None" or target_gender != "None":
logger.status("Detected: -%s- y.o. %s", target_age, target_gender)
output_info = f"TargetFaceIndex={face_num};Age={target_age};Gender={target_gender}\n"
output.append(output_info)
if target_face is not None and wrong_gender == 0:
logger.status("Swapping Source into Target")
result = face_swapper.get(result, target_face, source_face)
swapped += 1
elif wrong_gender == 1:
wrong_gender = 0
if source_face_idx == len(source_faces_index):
result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
if enhancement_options is not None and len(source_faces_index) > 1:
result_image = enhance_image(result_image, enhancement_options)
return result_image, output, swapped
else:
logger.status(f"No target face found for {face_num}")
elif wrong_gender == 1:
wrong_gender = 0
if source_face_idx == len(source_faces_index):
result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
if enhancement_options is not None and len(source_faces_index) > 1:
result_image = enhance_image(result_image, enhancement_options)
return result_image, output, swapped
else:
logger.status(f"No source face found for face number {source_face_idx}.")
result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
if enhancement_options is not None and swapped > 0:
result_image = enhance_image(result_image, enhancement_options)
else:
logger.status("No source face(s) in the provided Index")
else:
logger.status("No source face(s) found")
return result_image, output, swapped
def merge_images_with_mask(
image1: CV2ImgU8, image2: CV2ImgU8, mask: CV2ImgU8
) -> CV2ImgU8:
"""
Merges two images using a given mask. The regions where the mask is set will be replaced with the corresponding
areas of the second image.
Args:
image1 (CV2Img): The base image, which must have the same shape as image2.
image2 (CV2Img): The image to be merged, which must have the same shape as image1.
mask (CV2Img): A binary mask specifying the regions to be merged. The mask shape should match image1's first two dimensions.
Returns:
CV2Img: The merged image.
Raises:
ValueError: If the shapes of the images and mask do not match.
"""
if image1.shape != image2.shape or image1.shape[:2] != mask.shape:
raise ValueError("Img should have the same shape")
mask = mask.astype(np.uint8)
masked_region = cv2.bitwise_and(image2, image2, mask=mask)
inverse_mask = cv2.bitwise_not(mask)
empty_region = cv2.bitwise_and(image1, image1, mask=inverse_mask)
merged_image = cv2.add(empty_region, masked_region)
return merged_image
def erode_mask(mask: CV2ImgU8, kernel_size: int = 3, iterations: int = 1) -> CV2ImgU8:
"""
Erodes a binary mask using a given kernel size and number of iterations.
Args:
mask (CV2Img): The binary mask to erode.
kernel_size (int, optional): The size of the kernel. Default is 3.
iterations (int, optional): The number of erosion iterations. Default is 1.
Returns:
CV2Img: The eroded mask.
"""
kernel = np.ones((kernel_size, kernel_size), np.uint8)
eroded_mask = cv2.erode(mask, kernel, iterations=iterations)
return eroded_mask
def apply_gaussian_blur(
mask: CV2ImgU8, kernel_size: Tuple[int, int] = (5, 5), sigma_x: int = 0
) -> CV2ImgU8:
"""
Applies a Gaussian blur to a mask.
Args:
mask (CV2Img): The mask to blur.
kernel_size (tuple, optional): The size of the kernel, e.g. (5, 5). Default is (5, 5).
sigma_x (int, optional): The standard deviation in the X direction. Default is 0.
Returns:
CV2Img: The blurred mask.
"""
blurred_mask = cv2.GaussianBlur(mask, kernel_size, sigma_x)
return blurred_mask
def dilate_mask(mask: CV2ImgU8, kernel_size: int = 5, iterations: int = 1) -> CV2ImgU8:
"""
Dilates a binary mask using a given kernel size and number of iterations.
Args:
mask (CV2Img): The binary mask to dilate.
kernel_size (int, optional): The size of the kernel. Default is 5.
iterations (int, optional): The number of dilation iterations. Default is 1.
Returns:
CV2Img: The dilated mask.
"""
kernel = np.ones((kernel_size, kernel_size), np.uint8)
dilated_mask = cv2.dilate(mask, kernel, iterations=iterations)
return dilated_mask
def get_face_mask(aimg: CV2ImgU8, bgr_fake: CV2ImgU8) -> CV2ImgU8:
"""
Generates a face mask by performing bitwise OR on two face masks and then dilating the result.
Args:
aimg (CV2Img): Input image for generating the first face mask.
bgr_fake (CV2Img): Input image for generating the second face mask.
Returns:
CV2Img: The combined and dilated face mask.
"""
mask1 = generate_face_mask(aimg, device=shared.device)
mask2 = generate_face_mask(bgr_fake, device=shared.device)
mask = dilate_mask(cv2.bitwise_or(mask1, mask2))
return mask