diff --git a/scripts/entities/face.py b/scripts/entities/face.py
new file mode 100644
index 0000000..21a6167
--- /dev/null
+++ b/scripts/entities/face.py
@@ -0,0 +1,147 @@
+import traceback
+
+import cv2
+import numpy as np
+from modules import images
+from PIL import Image
+
+from scripts.entities.rect import Point, Rect
+
+
+class Face:
+    def __init__(self, entire_image: np.ndarray, face_area: Rect, face_margin: float, face_size: int, upscaler: str):
+        self.face_area = face_area
+        self.center = face_area.center
+        left, top, right, bottom = face_area.to_square()
+
+        self.left, self.top, self.right, self.bottom = self.__ensure_margin(
+            left, top, right, bottom, entire_image, face_margin
+        )
+
+        self.width = self.right - self.left
+        self.height = self.bottom - self.top
+
+        self.image = self.__crop_face_image(entire_image, face_size, upscaler)
+        self.face_size = face_size
+        self.scale_factor = face_size / self.width
+        self.face_area_on_image = self.__get_face_area_on_image()
+        self.landmarks_on_image = self.__get_landmarks_on_image()
+
+    def __get_face_area_on_image(self):
+        left = int((self.face_area.left - self.left) * self.scale_factor)
+        top = int((self.face_area.top - self.top) * self.scale_factor)
+        right = int((self.face_area.right - self.left) * self.scale_factor)
+        bottom = int((self.face_area.bottom - self.top) * self.scale_factor)
+        return self.__clip_values(left, top, right, bottom)
+
+    def __get_landmarks_on_image(self):
+        landmarks = []
+        if self.face_area.landmarks is not None:
+            for landmark in self.face_area.landmarks:
+                landmarks.append(
+                    Point(
+                        int((landmark.x - self.left) * self.scale_factor),
+                        int((landmark.y - self.top) * self.scale_factor),
+                    )
+                )
+        return landmarks
+
+    def __crop_face_image(self, entire_image: np.ndarray, face_size: int, upscaler: str):
+        cropped = entire_image[self.top : self.bottom, self.left : self.right, :]
+        if upscaler:
+            return images.resize_image(0, Image.fromarray(cropped), face_size, face_size, upscaler)
+        else:
+            return Image.fromarray(cv2.resize(cropped, dsize=(face_size, face_size)))
+
+    def __ensure_margin(self, left: int, top: int, right: int, bottom: int, entire_image: np.ndarray, margin: float):
+        entire_height, entire_width = entire_image.shape[:2]
+
+        side_length = right - left
+        margin = min(min(entire_height, entire_width) / side_length, margin)
+        diff = int((side_length * margin - side_length) / 2)
+
+        top = top - diff
+        bottom = bottom + diff
+        left = left - diff
+        right = right + diff
+
+        if top < 0:
+            bottom = bottom - top
+            top = 0
+        if left < 0:
+            right = right - left
+            left = 0
+
+        if bottom > entire_height:
+            top = top - (bottom - entire_height)
+            bottom = entire_height
+        if right > entire_width:
+            left = left - (right - entire_width)
+            right = entire_width
+
+        return left, top, right, bottom
+
+    def get_angle(self) -> float:
+        landmarks = getattr(self.face_area, "landmarks", None)
+        if landmarks is None:
+            return 0
+
+        eye1 = getattr(landmarks, "eye1", None)
+        eye2 = getattr(landmarks, "eye2", None)
+        if eye2 is None or eye1 is None:
+            return 0
+
+        try:
+            dx = eye2.x - eye1.x
+            dy = eye2.y - eye1.y
+            if dx == 0:
+                dx = 1  # avoid division by zero for a perfectly vertical eye line
+            angle = np.arctan(dy / dx) * 180 / np.pi
+
+            if dx < 0:
+                angle = (angle + 180) % 360
+            return angle
+        except Exception:
+            print(traceback.format_exc())
+            return 0
+
+    def rotate_face_area_on_image(self, angle: float):
+        center = [
+            (self.face_area_on_image[0] + self.face_area_on_image[2]) / 2,
+            (self.face_area_on_image[1] + self.face_area_on_image[3]) / 2,
+        ]
+
+        points = [
+            [self.face_area_on_image[0], self.face_area_on_image[1]],
+            [self.face_area_on_image[2], self.face_area_on_image[3]],
+        ]
+
+        angle = np.radians(angle)
+        rot_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
+
+        points = np.array(points) - center
+        points = np.dot(points, rot_matrix.T)
+        points += center
+        left, top, right, bottom = (int(points[0][0]), int(points[0][1]), int(points[1][0]), int(points[1][1]))
+
+        left, right = (right, left) if left > right else (left, right)
+        top, bottom = (bottom, top) if top > bottom else (top, bottom)
+
+        width, height = right - left, bottom - top
+        if width < height:
+            left, right = left - (height - width) // 2, right + (height - width) // 2
+        elif height < width:
+            top, bottom = top - (width - height) // 2, bottom + (width - height) // 2
+        return self.__clip_values(left, top, right, bottom)
+
+    def __clip_values(self, *args):
+        result = []
+        for val in args:
+            if val < 0:
+                result.append(0)
+            elif val > self.face_size:
+                result.append(self.face_size)
+            else:
+                result.append(val)
+        return tuple(result)
diff --git a/scripts/entities/rect.py b/scripts/entities/rect.py
new file mode 100644
index 0000000..424b555
--- /dev/null
+++ b/scripts/entities/rect.py
@@ -0,0 +1,78 @@
+from typing import Dict, NamedTuple, Optional, Tuple
+
+import numpy as np
+
+
+class Point(NamedTuple):
+    x: int
+    y: int
+
+
+class Landmarks(NamedTuple):
+    eye1: Point
+    eye2: Point
+    nose: Point
+    mouth1: Point
+    mouth2: Point
+
+
+class Rect:
+    def __init__(
+        self,
+        left: int,
+        top: int,
+        right: int,
+        bottom: int,
+        tag: str = "face",
+        landmarks: Optional[Landmarks] = None,
+        attributes: Optional[Dict[str, str]] = None,
+    ) -> None:
+        self.tag = tag
+        self.left = left
+        self.top = top
+        self.right = right
+        self.bottom = bottom
+        self.center = int((right + left) / 2)
+        self.middle = int((top + bottom) / 2)
+        self.width = right - left
+        self.height = bottom - top
+        self.size = self.width * self.height
+        self.landmarks = landmarks
+        self.attributes = attributes if attributes is not None else {}
+
+    @classmethod
+    def from_ndarray(
+        cls,
+        face_box: np.ndarray,
+        tag: str = "face",
+        landmarks: Optional[Landmarks] = None,
+        attributes: Optional[Dict[str, str]] = None,
+    ) -> "Rect":
+        left, top, right, bottom, *_ = list(map(int, face_box))
+        return cls(left, top, right, bottom, tag, landmarks, attributes)
+
+    def to_tuple(self) -> Tuple[int, int, int, int]:
+        return self.left, self.top, self.right, self.bottom
+
+    def to_square(self):
+        left, top, right, bottom = self.to_tuple()
+
+        width = right - left
+        height = bottom - top
+
+        if width % 2 == 1:
+            right = right + 1
+            width = width + 1
+        if height % 2 == 1:
+            bottom = bottom + 1
+            height = height + 1
+
+        diff = int(abs(width - height) / 2)
+        if width > height:
+            top = top - diff
+            bottom = bottom + diff
+        else:
+            left = left - diff
+            right = right + diff
+
+        return left, top, right, bottom
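# --- Illustrative example (not part of the patch) ---------------------------
# How Rect normalizes a detector bbox before Face crops it: to_square() first
# pads odd sides to an even length, then grows the shorter axis symmetrically
# until the box is square. The bbox values below are made up.

import numpy as np

from scripts.entities.rect import Rect

rect = Rect.from_ndarray(np.array([103, 60, 206, 220]))
print(rect.width, rect.height)  # 103 160 -> the odd width is padded to 104 first
print(rect.to_square())         # (75, 60, 235, 220): a 160x160 square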
diff --git a/scripts/inferencers/bisenet_mask_generator.py b/scripts/inferencers/bisenet_mask_generator.py
new file mode 100644
index 0000000..27eb2e7
--- /dev/null
+++ b/scripts/inferencers/bisenet_mask_generator.py
@@ -0,0 +1,88 @@
+from typing import List, Tuple
+
+import cv2
+import modules.shared as shared
+import numpy as np
+import torch
+from facexlib.parsing import init_parsing_model
+from facexlib.utils.misc import img2tensor
+from torchvision.transforms.functional import normalize
+from PIL import Image
+
+from scripts.inferencers.mask_generator import MaskGenerator
+from scripts.reactor_logger import logger
+
+
+class BiSeNetMaskGenerator(MaskGenerator):
+    def __init__(self) -> None:
+        self.mask_model = init_parsing_model(device=shared.device)
+
+    def name(self):
+        return "BiSeNet"
+
+    def generate_mask(
+        self,
+        face_image: np.ndarray,
+        face_area_on_image: Tuple[int, int, int, int],
+        affected_areas: List[str],
+        mask_size: int,
+        use_minimal_area: bool,
+        fallback_ratio: float = 0.25,
+        **kwargs,
+    ) -> np.ndarray:
+        original_face_image = face_image
+        face_image = face_image.copy()
+        face_image = face_image[:, :, ::-1]  # flip channel order for the parsing model
+
+        if use_minimal_area:
+            face_image = MaskGenerator.mask_non_face_areas(face_image, face_area_on_image)
+
+        h, w, _ = face_image.shape
+
+        if w != 512 or h != 512:
+            rw = (int(w * (512 / w)) // 8) * 8
+            rh = (int(h * (512 / h)) // 8) * 8
+            face_image = cv2.resize(face_image, dsize=(rw, rh))
+
+        face_tensor = img2tensor(face_image.astype("float32") / 255.0, float32=True)
+        normalize(face_tensor, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+        face_tensor = torch.unsqueeze(face_tensor, 0).to(shared.device)
+
+        with torch.no_grad():
+            face = self.mask_model(face_tensor)[0]
+        face = face.squeeze(0).cpu().numpy().argmax(0)
+        face = face.copy().astype(np.uint8)
+
+        mask = self.__to_mask(face, affected_areas)
+
+        if mask_size > 0:
+            mask = cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=mask_size)
+
+        if w != 512 or h != 512:
+            mask = cv2.resize(mask, dsize=(w, h))
+
+        # Fallback for when the parsed mask covers less than fallback_ratio of the
+        # crop; disabled until a fallback generator is wired in:
+        # if MaskGenerator.calculate_mask_coverage(mask) < fallback_ratio:
+        #     logger.info("Use fallback mask generator")
+        #     mask = self.fallback_mask_generator.generate_mask(
+        #         original_face_image, face_area_on_image, use_minimal_area=True
+        #     )
+
+        return mask
+
+    def __to_mask(self, face: np.ndarray, affected_areas: List[str]) -> np.ndarray:
+        keep_face = "Face" in affected_areas
+        keep_neck = "Neck" in affected_areas
+        keep_hair = "Hair" in affected_areas
+        keep_hat = "Hat" in affected_areas
+
+        mask = np.zeros((face.shape[0], face.shape[1], 3), dtype=np.uint8)
+        num_of_class = np.max(face)
+        for i in range(1, num_of_class + 1):
+            index = np.where(face == i)
+            if i < 14 and keep_face:
+                mask[index[0], index[1], :] = [255, 255, 255]
+            elif i == 14 and keep_neck:
+                mask[index[0], index[1], :] = [255, 255, 255]
+            elif i == 17 and keep_hair:
+                mask[index[0], index[1], :] = [255, 255, 255]
+            elif i == 18 and keep_hat:
+                mask[index[0], index[1], :] = [255, 255, 255]
+        return mask
diff --git a/scripts/inferencers/mask_generator.py b/scripts/inferencers/mask_generator.py
new file mode 100644
index 0000000..9359523
--- /dev/null
+++ b/scripts/inferencers/mask_generator.py
@@ -0,0 +1,37 @@
+from abc import ABC, abstractmethod
+from typing import Tuple
+
+import cv2
+import numpy as np
+
+
+class MaskGenerator(ABC):
+    @abstractmethod
+    def name(self) -> str:
+        pass
+
+    @abstractmethod
+    def generate_mask(
+        self,
+        face_image: np.ndarray,
+        face_area_on_image: Tuple[int, int, int, int],
+        **kwargs,
+    ) -> np.ndarray:
+        pass
+
+    @staticmethod
+    def mask_non_face_areas(image: np.ndarray, face_area_on_image: Tuple[int, int, int, int]) -> np.ndarray:
+        left, top, right, bottom = face_area_on_image
+        image = image.copy()
+        image[:top, :] = 0
+        image[bottom:, :] = 0
+        image[:, :left] = 0
+        image[:, right:] = 0
+        return image
+
+    @staticmethod
+    def calculate_mask_coverage(mask: np.ndarray):
+        gray_mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
+        non_black_pixels = np.count_nonzero(gray_mask)
+        total_pixels = gray_mask.size
+        return non_black_pixels / total_pixels
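# --- Illustrative example (not part of the patch) ---------------------------
# Calling the BiSeNet generator directly. A sketch only: it assumes a running
# A1111/SD.Next session, since BiSeNetMaskGenerator pulls its device and the
# facexlib parsing weights via modules.shared; the input here is a dummy crop.

import numpy as np

from scripts.inferencers.bisenet_mask_generator import BiSeNetMaskGenerator
from scripts.inferencers.mask_generator import MaskGenerator

generator = BiSeNetMaskGenerator()
face_crop = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for Face.image
mask = generator.generate_mask(
    face_crop,
    face_area_on_image=(64, 64, 448, 448),  # made-up face box inside the crop
    affected_areas=["Face"],                # add "Neck"/"Hair"/"Hat" to widen it
    mask_size=0,                            # >0 dilates with a 5x5 kernel per step
    use_minimal_area=True,                  # zero everything outside the box first
)
print(MaskGenerator.calculate_mask_coverage(mask))  # fraction of non-black pixels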
diff --git a/scripts/reactor_faceswap.py b/scripts/reactor_faceswap.py
index 4499b15..0c54371 100644
--- a/scripts/reactor_faceswap.py
+++ b/scripts/reactor_faceswap.py
@@ -66,6 +66,8 @@ class FaceSwapScript(scripts.Script):
                 img = gr.Image(type="pil")
                 enable = gr.Checkbox(False, label="Enable", info=f"The Fast and Simple FaceSwap Extension - {version_flag}")
                 save_original = gr.Checkbox(False, label="Save Original", info="Save the original image(s) made before swapping; If you use \"img2img\" - this option will affect with \"Swap in generated\" only")
+                mask_face = gr.Checkbox(False, label="Mask Faces", info="Attempt to mask only the faces and eliminate pixelation of the image around the contours.")
+                gr.Markdown("")
                 gr.Markdown("Source Image (above):")
         with gr.Row():
@@ -211,6 +213,7 @@ class FaceSwapScript(scripts.Script):
            source_hash_check,
            target_hash_check,
            device,
+           mask_face,
        ]

@@ -264,6 +267,7 @@ class FaceSwapScript(scripts.Script):
        source_hash_check,
        target_hash_check,
        device,
+       mask_face,
    ):
        self.enable = enable
        if self.enable:
@@ -291,6 +295,7 @@ class FaceSwapScript(scripts.Script):
            self.source_hash_check = source_hash_check
            self.target_hash_check = target_hash_check
            self.device = device
+           self.mask_face = mask_face
            if self.gender_source is None or self.gender_source == "No":
                self.gender_source = 0
            if self.gender_target is None or self.gender_target == "No":
@@ -334,6 +339,7 @@ class FaceSwapScript(scripts.Script):
                    source_hash_check=self.source_hash_check,
                    target_hash_check=self.target_hash_check,
                    device=self.device,
+                   mask_face=self.mask_face,
                )
                p.init_images[i] = result
                # result_path = get_image_path(p.init_images[i], p.outpath_samples, "", p.all_seeds[i], p.all_prompts[i], "txt", p=p, suffix="-swapped")
@@ -385,6 +391,7 @@ class FaceSwapScript(scripts.Script):
                    source_hash_check=self.source_hash_check,
                    target_hash_check=self.target_hash_check,
                    device=self.device,
+                   mask_face=self.mask_face,
                )
                if result is not None and swapped > 0:
                    result_images.append(result)
@@ -442,6 +449,7 @@ class FaceSwapScript(scripts.Script):
                source_hash_check=self.source_hash_check,
                target_hash_check=self.target_hash_check,
                device=self.device,
+               mask_face=self.mask_face,
            )
            try:
                pp = scripts_postprocessing.PostprocessedImage(result)
@@ -468,6 +476,8 @@ class FaceSwapScriptExtras(scripts_postprocessing.ScriptPostprocessing):
            with gr.Column():
                img = gr.Image(type="pil")
                enable = gr.Checkbox(False, label="Enable", info=f"The Fast and Simple FaceSwap Extension - {version_flag}")
+               mask_face = gr.Checkbox(False, label="Mask Faces", info="Attempt to mask only the faces and eliminate pixelation of the image around the contours.")
+
                gr.Markdown("Source Image (above):")
        with gr.Row():
            source_faces_index = gr.Textbox(
@@ -582,6 +592,7 @@ class FaceSwapScriptExtras(scripts_postprocessing.ScriptPostprocessing):
            'gender_target': gender_target,
            'codeformer_weight': codeformer_weight,
            'device': device,
+           'mask_face': mask_face,
        }
        return args
@@ -631,6 +642,7 @@ class FaceSwapScriptExtras(scripts_postprocessing.ScriptPostprocessing):
        self.gender_target = args['gender_target']
        self.codeformer_weight = args['codeformer_weight']
        self.device = args['device']
+       self.mask_face = args['mask_face']
        if self.gender_source is None or self.gender_source == "No":
            self.gender_source = 0
        if self.gender_target is None or self.gender_target == "No":
@@ -669,6 +681,7 @@ class FaceSwapScriptExtras(scripts_postprocessing.ScriptPostprocessing):
            source_hash_check=True,
            target_hash_check=True,
            device=self.device,
+           mask_face=self.mask_face,
        )
        try:
            pp.info["ReActor"] = True
diff --git a/scripts/reactor_swapper.py b/scripts/reactor_swapper.py
index c35da95..9fe93e9 100644
--- a/scripts/reactor_swapper.py
+++ b/scripts/reactor_swapper.py
@@ -2,13 +2,16 @@ import copy
 import os
 from dataclasses import dataclass
 from typing import List, Union
-from typing import Tuple
+
 import cv2
 import numpy as np
-from PIL import Image
-from insightface.app.common import Face as IFace
+from PIL import Image, ImageDraw
+from scripts.inferencers.bisenet_mask_generator import BiSeNetMaskGenerator
+from scripts.entities.face import Face
+from scripts.entities.rect import Rect
 import insightface
-
+from torchvision.transforms.functional import to_pil_image
 from scripts.reactor_helpers import get_image_md5hash, get_Device
 from modules.face_restoration import FaceRestoration
 try:  # A1111
@@ -18,6 +21,7 @@ except:  # SD.Next
     from modules.upscaler import UpscalerData
 from modules.shared import state
 from scripts.reactor_logger import logger
+
 try:
     from modules.paths_internal import models_path
 except:
@@ -28,16 +32,6 @@ except:

 import warnings

-PILImage = Image.Image
-CV2ImgU8 = np.ndarray[int, np.dtype[uint8]]
-Face = IFace
-BoxCoords = Tuple[int, int, int, int]
-
-
-class Gender(Enum):
-    AUTO = -1
-    FEMALE = 0
-    MALE = 1

 np.warnings = warnings
 np.warnings.filterwarnings('ignore')
@@ -86,8 +80,9 @@ def check_process_halt(msgforced: bool = False):


 FS_MODEL = None
+MASK_MODEL = None
 CURRENT_FS_MODEL_PATH = None
-
+CURRENT_MASK_MODEL_PATH = None

 ANALYSIS_MODEL = None
 SOURCE_FACES = None
@@ -115,6 +110,8 @@ def getFaceSwapModel(model_path: str):

     return FS_MODEL

+
+
 def restore_face(image: Image, enhancement_options: EnhancementOptions):
     result_image = image
@@ -178,7 +175,28 @@ def enhance_image(image: Image, enhancement_options: EnhancementOptions):
         result_image = restore_face(result_image, enhancement_options)

     return result_image

+
+def enhance_image_and_mask(image: Image.Image, enhancement_options: EnhancementOptions, target_img_orig: Image.Image, entire_mask_image: Image.Image) -> Image.Image:
+    result_image = image
+
+    if check_process_halt(msgforced=True):
+        return result_image
+
+    if enhancement_options.do_restore_first:
+        result_image = restore_face(result_image, enhancement_options)
+        result_image = Image.composite(result_image, target_img_orig, entire_mask_image)
+        result_image = upscale_image(result_image, enhancement_options)
+    else:
+        result_image = upscale_image(result_image, enhancement_options)
+        # Image.composite requires all inputs to share one size, so bring the
+        # original target and the mask up to the upscaled resolution first
+        target_img_orig = target_img_orig.resize(result_image.size)
+        entire_mask_image = Image.fromarray(cv2.resize(np.array(entire_mask_image), result_image.size, interpolation=cv2.INTER_AREA)).convert("L")
+        result_image = Image.composite(result_image, target_img_orig, entire_mask_image)
+        result_image = restore_face(result_image, enhancement_options)
+
+    return result_image
+
+
 def get_gender(face, face_index):
     gender = [
         x.sex
@@ -292,6 +310,7 @@ def swap_face(
     source_hash_check: bool = True,
     target_hash_check: bool = False,
     device: str = "CPU",
+    mask_face: bool = False,
 ):
     global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH, PROVIDERS
     result_image = target_img
@@ -318,7 +337,8 @@ def swap_face(

         source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
         target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)
-
+        target_img_orig = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)  # target_img is BGR here, so this copy is flipped back to RGB for PIL compositing
+        entire_mask_image = np.zeros_like(np.array(target_img))
         output: List = []
         output_info: str = ""
         swapped = 0
@@ -421,7 +441,12 @@ def swap_face(

                     if target_face is not None and wrong_gender == 0:
                         logger.status("Swapping Source into Target")
-                        result = face_swapper.get(result, target_face, source_face)
+                        swapped_image = face_swapper.get(result, target_face, source_face)
+
+                        if mask_face:
+                            result = apply_face_mask(swapped_image=swapped_image, target_image=result, target_face=target_face, entire_mask_image=entire_mask_image)
+                        else:
+                            result = swapped_image
                         swapped += 1

                     elif wrong_gender == 1:
@@ -455,8 +480,13 @@ def swap_face(
             result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))

             if enhancement_options is not None and swapped > 0:
-                result_image = enhance_image(result_image, enhancement_options)
-
+                if mask_face and entire_mask_image is not None:
+                    result_image = enhance_image_and_mask(result_image, enhancement_options, Image.fromarray(target_img_orig), Image.fromarray(entire_mask_image).convert("L"))
+                else:
+                    result_image = enhance_image(result_image, enhancement_options)
+            elif mask_face and entire_mask_image is not None and swapped > 0:
+                result_image = Image.composite(result_image, Image.fromarray(target_img_orig), Image.fromarray(entire_mask_image).convert("L"))
+
         else:
             logger.status("No source face(s) in the provided Index")
     else:
@@ -465,99 +495,158 @@ def swap_face(

     return result_image, output, swapped


-def merge_images_with_mask(
-    image1: CV2ImgU8, image2: CV2ImgU8, mask: CV2ImgU8
-) -> CV2ImgU8:
+def apply_face_mask(swapped_image: np.ndarray, target_image: np.ndarray, target_face, entire_mask_image: np.ndarray) -> np.ndarray:
+    logger.status("Masking Face")
+    mask_generator = BiSeNetMaskGenerator()
+    face = Face(target_image, Rect.from_ndarray(np.array(target_face.bbox)), 1.6, 512, "")
+    face_image = np.array(face.image)
+    process_face_image(face)  # debug overlay; the returned preview is currently unused
+    face_area_on_image = face.face_area_on_image
+    mask = mask_generator.generate_mask(
+        face_image,
+        face_area_on_image=face_area_on_image,
+        affected_areas=["Face"],
+        mask_size=0,
+        use_minimal_area=True,
+    )
+    mask = cv2.blur(mask, (12, 12))
+    larger_mask = cv2.resize(mask, dsize=(face.width, face.height))
+    entire_mask_image[
+        face.top : face.bottom,
+        face.left : face.right,
+    ] = larger_mask
+
+    result = Image.composite(Image.fromarray(swapped_image), Image.fromarray(target_image), Image.fromarray(entire_mask_image).convert("L"))
+    return np.array(result)
+
+
+def correct_face_tilt(angle: float) -> bool:
+    angle = abs(angle)
+    if angle > 180:
+        angle = 360 - angle
+    return angle > 40
+
+
+def _dilate(arr: np.ndarray, value: int) -> np.ndarray:
+    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
+    return cv2.dilate(arr, kernel, iterations=1)
+
+
+def _erode(arr: np.ndarray, value: int) -> np.ndarray:
+    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
+    return cv2.erode(arr, kernel, iterations=1)
+
+
+colors = [
+    (255, 0, 0),
+    (0, 255, 0),
+    (0, 0, 255),
+    (255, 255, 0),
+    (255, 0, 255),
+    (0, 255, 255),
+    (255, 255, 255),
+    (128, 0, 0),
+    (0, 128, 0),
+    (128, 128, 0),
+    (0, 0, 128),
+    (0, 128, 128),
+]
+
+
+def color_generator(colors):
+    while True:
+        for color in colors:
+            yield color
+
+
+color_iter = color_generator(colors)
+
+
+def process_face_image(
+    face: Face,
+    **kwargs,
+) -> Image:
+    image = np.array(face.image)
+    overlay = image.copy()
+    cv2.rectangle(overlay, (0, 0), (image.shape[1], image.shape[0]), next(color_iter), -1)
+    l, t, r, b = face.face_area_on_image
+    cv2.rectangle(overlay, (l, t), (r, b), (0, 0, 0), 10)
+    if face.landmarks_on_image is not None:
+        for landmark in face.landmarks_on_image:
+            cv2.circle(overlay, (int(landmark.x), int(landmark.y)), 6, (0, 0, 0), 10)
+    alpha = 0.3
+    output = cv2.addWeighted(image, 1 - alpha, overlay, alpha, 0)
+
+    return Image.fromarray(output)
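# --- Illustrative example (not part of the patch) ---------------------------
# dilate_erode's sign convention (defined just below) on a toy mask: a positive
# value grows the white region with a value x value rectangular kernel, a
# negative value shrinks it, and 0 returns the image unchanged.

arr = np.zeros((64, 64), dtype=np.uint8)
arr[24:40, 24:40] = 255               # 16x16 white square
mask_img = Image.fromarray(arr)
grown = dilate_erode(mask_img, 5)     # square grows to 20x20
shrunk = dilate_erode(mask_img, -5)   # square shrinks to 12x12
print(np.count_nonzero(np.array(grown)), np.count_nonzero(np.array(shrunk)))  # 400 144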
+def dilate_erode(img: Image.Image, value: int) -> Image.Image:
     """
-    Merges two images using a given mask. The regions where the mask is set will be replaced with the corresponding
-    areas of the second image.
+    The dilate_erode function takes an image and a value.
+    If the value is positive, it dilates the image by that amount.
+    If the value is negative, it erodes the image by that amount.

-    Args:
-        image1 (CV2Img): The base image, which must have the same shape as image2.
-        image2 (CV2Img): The image to be merged, which must have the same shape as image1.
-        mask (CV2Img): A binary mask specifying the regions to be merged. The mask shape should match image1's first two dimensions.
+    Parameters
+    ----------
+    img: PIL.Image.Image
+        the image to be processed
+    value: int
+        kernel size of dilation or erosion

-    Returns:
-        CV2Img: The merged image.
-
-    Raises:
-        ValueError: If the shapes of the images and mask do not match.
+    Returns
+    -------
+    PIL.Image.Image
+        The image that has been dilated or eroded
     """
+    if value == 0:
+        return img

-    if image1.shape != image2.shape or image1.shape[:2] != mask.shape:
-        raise ValueError("Img should have the same shape")
-    mask = mask.astype(np.uint8)
-    masked_region = cv2.bitwise_and(image2, image2, mask=mask)
-    inverse_mask = cv2.bitwise_not(mask)
-    empty_region = cv2.bitwise_and(image1, image1, mask=inverse_mask)
-    merged_image = cv2.add(empty_region, masked_region)
-    return merged_image
+    arr = np.array(img)
+    arr = _dilate(arr, value) if value > 0 else _erode(arr, -value)
+    return Image.fromarray(arr)


-def erode_mask(mask: CV2ImgU8, kernel_size: int = 3, iterations: int = 1) -> CV2ImgU8:
+def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]:
     """
-    Erodes a binary mask using a given kernel size and number of iterations.
+    Parameters
+    ----------
+    masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
+        The device can be CUDA, but `to_pil_image` takes care of that.
+    shape: tuple[int, int]
+        (width, height) of the original image
+    """
-
-    Args:
-        mask (CV2Img): The binary mask to erode.
-        kernel_size (int, optional): The size of the kernel. Default is 3.
-        iterations (int, optional): The number of erosion iterations. Default is 1.
-
-    Returns:
-        CV2Img: The eroded mask.
-    """
-    kernel = np.ones((kernel_size, kernel_size), np.uint8)
-    eroded_mask = cv2.erode(mask, kernel, iterations=iterations)
-    return eroded_mask
+    n = masks.shape[0]
+    return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]


-def apply_gaussian_blur(
-    mask: CV2ImgU8, kernel_size: Tuple[int, int] = (5, 5), sigma_x: int = 0
-) -> CV2ImgU8:
+def create_mask_from_bbox(
+    bboxes: list[list[float]], shape: tuple[int, int]
+) -> list[Image.Image]:
     """
-    Applies a Gaussian blur to a mask.
+    Parameters
+    ----------
+    bboxes: list[list[float]]
+        list of [x1, y1, x2, y2] bounding boxes
+    shape: tuple[int, int]
+        shape of the image (width, height)

-    Args:
-        mask (CV2Img): The mask to blur.
-        kernel_size (tuple, optional): The size of the kernel, e.g. (5, 5). Default is (5, 5).
-        sigma_x (int, optional): The standard deviation in the X direction. Default is 0.
+    Returns
+    -------
+    masks: list[Image.Image]
+        A list of masks
+    """
-
-    Returns:
-        CV2Img: The blurred mask.
-    """
-    blurred_mask = cv2.GaussianBlur(mask, kernel_size, sigma_x)
-    return blurred_mask
+    masks = []
+    for bbox in bboxes:
+        mask = Image.new("L", shape, 0)
+        mask_draw = ImageDraw.Draw(mask)
+        mask_draw.rectangle(bbox, fill=255)
+        masks.append(mask)
+    return masks


-def dilate_mask(mask: CV2ImgU8, kernel_size: int = 5, iterations: int = 1) -> CV2ImgU8:
-    """
-    Dilates a binary mask using a given kernel size and number of iterations.
+def rotate_image(image: Image, angle: float) -> Image:
+    if angle == 0:
+        return image
+    return Image.fromarray(rotate_array(np.array(image), angle))

-    Args:
-        mask (CV2Img): The binary mask to dilate.
-        kernel_size (int, optional): The size of the kernel. Default is 5.
-        iterations (int, optional): The number of dilation iterations. Default is 1.
-
-    Returns:
-        CV2Img: The dilated mask.
-    """
-    kernel = np.ones((kernel_size, kernel_size), np.uint8)
-    dilated_mask = cv2.dilate(mask, kernel, iterations=iterations)
-    return dilated_mask
-
-
-def get_face_mask(aimg: CV2ImgU8, bgr_fake: CV2ImgU8) -> CV2ImgU8:
-    """
-    Generates a face mask by performing bitwise OR on two face masks and then dilating the result.
+def rotate_array(image: np.ndarray, angle: float) -> np.ndarray:
+    if angle == 0:
+        return image

-    Args:
-        aimg (CV2Img): Input image for generating the first face mask.
-        bgr_fake (CV2Img): Input image for generating the second face mask.
-
-    Returns:
-        CV2Img: The combined and dilated face mask.
-    """
-    mask1 = generate_face_mask(aimg, device=shared.device)
-    mask2 = generate_face_mask(bgr_fake, device=shared.device)
-    mask = dilate_mask(cv2.bitwise_or(mask1, mask2))
-    return mask
+    h, w = image.shape[:2]
+    center = (w // 2, h // 2)
+
+    M = cv2.getRotationMatrix2D(center, angle, 1.0)
+    return cv2.warpAffine(image, M, (w, h))
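# --- Illustrative example (not part of the patch) ---------------------------
# rotate_array keeps the original canvas size, so content near the corners is
# clipped rather than the canvas being expanded; rotate_image is the PIL
# wrapper around the same warp. Standalone equivalent:

import cv2
import numpy as np

image = np.zeros((100, 200, 3), dtype=np.uint8)
M = cv2.getRotationMatrix2D((100, 50), 15.0, 1.0)  # center (x, y), degrees, scale
rotated = cv2.warpAffine(image, M, (200, 100))     # output size = input size
assert rotated.shape == image.shape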