Added option to mask only faces to eliminate blurring artifacts around faces

jiveabillion 2023-11-22 14:28:30 -05:00
parent 16b47606f9
commit 24c9502a3c
6 changed files with 548 additions and 96 deletions
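The idea behind the new option, independent of any one file below: parse the swapped face, keep only the face-class pixels as a soft mask, and composite the swapped frame over the untouched original through that mask, so pixels outside the face contour are never modified. A minimal, self-contained sketch of just that compositing step (the synthetic images and the rectangular mask are hypothetical stand-ins for the real swap result and the BiSeNet output):

import cv2
import numpy as np
from PIL import Image

original = Image.new("RGB", (128, 128), (40, 40, 40))   # stand-in for the target image
swapped = Image.new("RGB", (128, 128), (200, 120, 80))  # stand-in for the swap result

face_mask = np.zeros((128, 128), dtype=np.uint8)
face_mask[32:96, 32:96] = 255                # hypothetical face region
face_mask = cv2.blur(face_mask, (12, 12))    # soften the seam, as apply_face_mask below does

# swapped pixels show through only where the mask is set
result = Image.composite(swapped, original, Image.fromarray(face_mask, mode="L"))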

scripts/entities/face.py Normal file

@@ -0,0 +1,147 @@
import traceback

import cv2
import numpy as np
from modules import images
from PIL import Image

from scripts.entities.rect import Point, Rect


class Face:
    def __init__(self, entire_image: np.ndarray, face_area: Rect, face_margin: float, face_size: int, upscaler: str):
        self.face_area = face_area
        self.center = face_area.center
        left, top, right, bottom = face_area.to_square()

        self.left, self.top, self.right, self.bottom = self.__ensure_margin(
            left, top, right, bottom, entire_image, face_margin
        )

        self.width = self.right - self.left
        self.height = self.bottom - self.top

        self.image = self.__crop_face_image(entire_image, face_size, upscaler)
        self.face_size = face_size
        self.scale_factor = face_size / self.width
        self.face_area_on_image = self.__get_face_area_on_image()
        self.landmarks_on_image = self.__get_landmarks_on_image()

    def __get_face_area_on_image(self):
        left = int((self.face_area.left - self.left) * self.scale_factor)
        top = int((self.face_area.top - self.top) * self.scale_factor)
        right = int((self.face_area.right - self.left) * self.scale_factor)
        bottom = int((self.face_area.bottom - self.top) * self.scale_factor)
        return self.__clip_values(left, top, right, bottom)

    def __get_landmarks_on_image(self):
        landmarks = []
        if self.face_area.landmarks is not None:
            for landmark in self.face_area.landmarks:
                landmarks.append(
                    Point(
                        int((landmark.x - self.left) * self.scale_factor),
                        int((landmark.y - self.top) * self.scale_factor),
                    )
                )
        return landmarks

    def __crop_face_image(self, entire_image: np.ndarray, face_size: int, upscaler: str):
        cropped = entire_image[self.top : self.bottom, self.left : self.right, :]
        if upscaler:
            return images.resize_image(0, Image.fromarray(cropped), face_size, face_size, upscaler)
        else:
            return Image.fromarray(cv2.resize(cropped, dsize=(face_size, face_size)))

    def __ensure_margin(self, left: int, top: int, right: int, bottom: int, entire_image: np.ndarray, margin: float):
        # expand the squared box by `margin` and clamp it to the image bounds
        entire_height, entire_width = entire_image.shape[:2]

        side_length = right - left
        margin = min(min(entire_height, entire_width) / side_length, margin)
        diff = int((side_length * margin - side_length) / 2)

        top = top - diff
        bottom = bottom + diff
        left = left - diff
        right = right + diff

        if top < 0:
            bottom = bottom - top
            top = 0
        if left < 0:
            right = right - left
            left = 0
        if bottom > entire_height:
            top = top - (bottom - entire_height)
            bottom = entire_height
        if right > entire_width:
            left = left - (right - entire_width)
            right = entire_width

        return left, top, right, bottom

    def get_angle(self) -> float:
        # roll angle estimated from the vector between the two eye landmarks
        landmarks = getattr(self.face_area, "landmarks", None)
        if landmarks is None:
            return 0

        eye1 = getattr(landmarks, "eye1", None)
        eye2 = getattr(landmarks, "eye2", None)
        if eye2 is None or eye1 is None:
            return 0

        try:
            dx = eye2.x - eye1.x
            dy = eye2.y - eye1.y
            if dx == 0:
                dx = 1
            angle = np.arctan(dy / dx) * 180 / np.pi
            if dx < 0:
                angle = (angle + 180) % 360
            return angle
        except Exception:
            print(traceback.format_exc())
            return 0

    def rotate_face_area_on_image(self, angle: float):
        # rotate the face box around its center, then re-square and clip it
        center = [
            (self.face_area_on_image[0] + self.face_area_on_image[2]) / 2,
            (self.face_area_on_image[1] + self.face_area_on_image[3]) / 2,
        ]
        points = [
            [self.face_area_on_image[0], self.face_area_on_image[1]],
            [self.face_area_on_image[2], self.face_area_on_image[3]],
        ]

        angle = np.radians(angle)
        rot_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
        points = np.array(points) - center
        points = np.dot(points, rot_matrix.T)
        points += center

        left, top, right, bottom = (int(points[0][0]), int(points[0][1]), int(points[1][0]), int(points[1][1]))
        left, right = (right, left) if left > right else (left, right)
        top, bottom = (bottom, top) if top > bottom else (top, bottom)

        width, height = right - left, bottom - top
        if width < height:
            left, right = left - (height - width) // 2, right + (height - width) // 2
        elif height < width:
            top, bottom = top - (width - height) // 2, bottom + (width - height) // 2

        return self.__clip_values(left, top, right, bottom)

    def __clip_values(self, *args):
        result = []
        for val in args:
            if val < 0:
                result.append(0)
            elif val > self.face_size:
                result.append(self.face_size)
            else:
                result.append(val)
        return tuple(result)
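A hedged usage sketch for the Face entity above (it needs the WebUI environment for modules.images; the frame and bounding box are made up): the constructor squares the detector box via Rect.to_square, grows it by face_margin through __ensure_margin, and yields a face_size crop plus the face's coordinates inside that crop.

import numpy as np
from scripts.entities.face import Face
from scripts.entities.rect import Rect

frame = np.zeros((1024, 768, 3), dtype=np.uint8)   # placeholder target image
face_area = Rect(300, 200, 500, 460)               # hypothetical detector box (left, top, right, bottom)

# margin 1.6 and crop size 512 mirror what apply_face_mask passes later in this
# commit; an empty upscaler name falls back to plain cv2.resize
face = Face(frame, face_area, face_margin=1.6, face_size=512, upscaler="")
print(face.image.size)            # (512, 512)
print(face.face_area_on_image)    # the original box, mapped into crop coordinates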

scripts/entities/rect.py Normal file

@@ -0,0 +1,78 @@
from typing import Dict, NamedTuple, Tuple

import numpy as np


class Point(NamedTuple):
    x: int
    y: int


class Landmarks(NamedTuple):
    eye1: Point
    eye2: Point
    nose: Point
    mouth1: Point
    mouth2: Point


class Rect:
    def __init__(
        self,
        left: int,
        top: int,
        right: int,
        bottom: int,
        tag: str = "face",
        landmarks: Landmarks = None,
        attributes: Dict[str, str] = {},
    ) -> None:
        self.tag = tag
        self.left = left
        self.top = top
        self.right = right
        self.bottom = bottom
        self.center = int((right + left) / 2)
        self.middle = int((top + bottom) / 2)
        self.width = right - left
        self.height = bottom - top
        self.size = self.width * self.height
        self.landmarks = landmarks
        self.attributes = attributes

    @classmethod
    def from_ndarray(
        cls,
        face_box: np.ndarray,
        tag: str = "face",
        landmarks: Landmarks = None,
        attributes: Dict[str, str] = {},
    ) -> "Rect":
        left, top, right, bottom, *_ = list(map(int, face_box))
        return cls(left, top, right, bottom, tag, landmarks, attributes)

    def to_tuple(self) -> Tuple[int, int, int, int]:
        return self.left, self.top, self.right, self.bottom

    def to_square(self):
        left, top, right, bottom = self.to_tuple()

        width = right - left
        height = bottom - top

        if width % 2 == 1:
            right = right + 1
            width = width + 1
        if height % 2 == 1:
            bottom = bottom + 1
            height = height + 1

        diff = int(abs(width - height) / 2)
        if width > height:
            top = top - diff
            bottom = bottom + diff
        else:
            left = left - diff
            right = right + diff

        return left, top, right, bottom
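A quick worked example of to_square: sides are first padded to even lengths, then the shorter side grows symmetrically to match the longer one. Note the result may run past the image edge (negative top here); Face.__ensure_margin clamps that later.

from scripts.entities.rect import Rect

r = Rect(10, 20, 110, 65)   # 100 wide, 45 tall
print(r.to_square())        # (10, -7, 110, 93): height goes 45 -> 46 -> 100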

scripts/inferencers/bisenet_mask_generator.py Normal file

@@ -0,0 +1,88 @@
from typing import List, Tuple

import cv2
import modules.shared as shared
import numpy as np
import torch
from facexlib.parsing import init_parsing_model
from facexlib.utils.misc import img2tensor
from torchvision.transforms.functional import normalize
from PIL import Image

from scripts.inferencers.mask_generator import MaskGenerator
from scripts.reactor_logger import logger


class BiSeNetMaskGenerator(MaskGenerator):
    def __init__(self) -> None:
        self.mask_model = init_parsing_model(device=shared.device)

    def name(self):
        return "BiSeNet"

    def generate_mask(
        self,
        face_image: np.ndarray,
        face_area_on_image: Tuple[int, int, int, int],
        affected_areas: List[str],
        mask_size: int,
        use_minimal_area: bool,
        fallback_ratio: float = 0.25,
        **kwargs,
    ) -> np.ndarray:
        original_face_image = face_image
        face_image = face_image.copy()
        face_image = face_image[:, :, ::-1]  # BGR <-> RGB

        if use_minimal_area:
            face_image = MaskGenerator.mask_non_face_areas(face_image, face_area_on_image)

        h, w, _ = face_image.shape
        if w != 512 or h != 512:
            rw = (int(w * (512 / w)) // 8) * 8
            rh = (int(h * (512 / h)) // 8) * 8
            face_image = cv2.resize(face_image, dsize=(rw, rh))

        face_tensor = img2tensor(face_image.astype("float32") / 255.0, float32=True)
        normalize(face_tensor, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        face_tensor = torch.unsqueeze(face_tensor, 0).to(shared.device)

        # face parsing: argmax over the class scores gives a per-pixel label map
        with torch.no_grad():
            face = self.mask_model(face_tensor)[0]
        face = face.squeeze(0).cpu().numpy().argmax(0)
        face = face.copy().astype(np.uint8)

        mask = self.__to_mask(face, affected_areas)
        if mask_size > 0:
            mask = cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=mask_size)

        if w != 512 or h != 512:
            mask = cv2.resize(mask, dsize=(w, h))

        # if MaskGenerator.calculate_mask_coverage(mask) < fallback_ratio:
        #     logger.info("Use fallback mask generator")
        #     mask = self.fallback_mask_generator.generate_mask(
        #         original_face_image, face_area_on_image, use_minimal_area=True
        #     )
        return mask

    def __to_mask(self, face: np.ndarray, affected_areas: List[str]) -> np.ndarray:
        keep_face = "Face" in affected_areas
        keep_neck = "Neck" in affected_areas
        keep_hair = "Hair" in affected_areas
        keep_hat = "Hat" in affected_areas

        mask = np.zeros((face.shape[0], face.shape[1], 3), dtype=np.uint8)
        num_of_class = np.max(face)
        for i in range(1, num_of_class + 1):
            index = np.where(face == i)
            if i < 14 and keep_face:
                mask[index[0], index[1], :] = [255, 255, 255]
            elif i == 14 and keep_neck:
                mask[index[0], index[1], :] = [255, 255, 255]
            elif i == 17 and keep_hair:
                mask[index[0], index[1], :] = [255, 255, 255]
            elif i == 18 and keep_hat:
                mask[index[0], index[1], :] = [255, 255, 255]
        return mask
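The index tests in __to_mask follow the 19-class face-parsing layout that facexlib's BiSeNet model uses: labels 1-13 are facial components (skin, brows, eyes, ears, nose, lips), 14 is the neck, 17 hair, and 18 hat. A numpy-only check of that whitelist logic, with a tiny synthetic label map in place of the model output:

import numpy as np

parsing = np.array([
    [0, 1, 13, 0],
    [0, 14, 17, 0],
    [0, 18, 1, 0],
    [0, 0, 0, 0],
], dtype=np.uint8)

mask = np.zeros((*parsing.shape, 3), dtype=np.uint8)
face_classes = list(range(1, 14))          # "Face" selected, as apply_face_mask requests
mask[np.isin(parsing, face_classes)] = [255, 255, 255]
print(mask[..., 0])                        # white only at the class 1/13 pixels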

scripts/inferencers/mask_generator.py Normal file

@@ -0,0 +1,37 @@
from abc import ABC, abstractmethod
from typing import Tuple

import cv2
import numpy as np


class MaskGenerator(ABC):
    @abstractmethod
    def name(self) -> str:
        pass

    @abstractmethod
    def generate_mask(
        self,
        face_image: np.ndarray,
        face_area_on_image: Tuple[int, int, int, int],
        **kwargs,
    ) -> np.ndarray:
        pass

    @staticmethod
    def mask_non_face_areas(image: np.ndarray, face_area_on_image: Tuple[int, int, int, int]) -> np.ndarray:
        left, top, right, bottom = face_area_on_image
        image = image.copy()
        image[:top, :] = 0
        image[bottom:, :] = 0
        image[:, :left] = 0
        image[:, right:] = 0
        return image

    @staticmethod
    def calculate_mask_coverage(mask: np.ndarray):
        gray_mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
        non_black_pixels = np.count_nonzero(gray_mask)
        total_pixels = gray_mask.size
        return non_black_pixels / total_pixels
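The two static helpers compose naturally: mask_non_face_areas zeroes everything outside the face box, and calculate_mask_coverage measures the non-black fraction that the (commented-out) fallback check in BiSeNetMaskGenerator would compare against fallback_ratio. A quick sanity check on synthetic data:

import numpy as np
from scripts.inferencers.mask_generator import MaskGenerator

mask = np.full((100, 100, 3), 255, dtype=np.uint8)           # all-white mask
mask = MaskGenerator.mask_non_face_areas(mask, (25, 25, 75, 75))
print(MaskGenerator.calculate_mask_coverage(mask))           # 0.25: a 50x50 box in 100x100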

scripts/reactor_faceswap.py

@@ -66,6 +66,8 @@ class FaceSwapScript(scripts.Script):
                img = gr.Image(type="pil")
                enable = gr.Checkbox(False, label="Enable", info=f"The Fast and Simple FaceSwap Extension - {version_flag}")
                save_original = gr.Checkbox(False, label="Save Original", info="Save the original image(s) made before swapping; if you use \"img2img\", this option only applies to \"Swap in generated\"")
                mask_face = gr.Checkbox(False, label="Mask Faces", info="Attempt to mask only the faces, eliminating pixelation around the face contours")
                gr.Markdown("<br>")
                gr.Markdown("Source Image (above):")
                with gr.Row():
@@ -211,6 +213,7 @@ class FaceSwapScript(scripts.Script):
            source_hash_check,
            target_hash_check,
            device,
            mask_face,
        ]
@@ -264,6 +267,7 @@ class FaceSwapScript(scripts.Script):
        source_hash_check,
        target_hash_check,
        device,
        mask_face,
    ):
        self.enable = enable
        if self.enable:
@@ -291,6 +295,7 @@ class FaceSwapScript(scripts.Script):
            self.source_hash_check = source_hash_check
            self.target_hash_check = target_hash_check
            self.device = device
            self.mask_face = mask_face
            if self.gender_source is None or self.gender_source == "No":
                self.gender_source = 0
            if self.gender_target is None or self.gender_target == "No":
@@ -334,6 +339,7 @@ class FaceSwapScript(scripts.Script):
                        source_hash_check=self.source_hash_check,
                        target_hash_check=self.target_hash_check,
                        device=self.device,
                        mask_face=self.mask_face,
                    )
                    p.init_images[i] = result
                    # result_path = get_image_path(p.init_images[i], p.outpath_samples, "", p.all_seeds[i], p.all_prompts[i], "txt", p=p, suffix="-swapped")
@@ -385,6 +391,7 @@ class FaceSwapScript(scripts.Script):
                    source_hash_check=self.source_hash_check,
                    target_hash_check=self.target_hash_check,
                    device=self.device,
                    mask_face=self.mask_face,
                )
                if result is not None and swapped > 0:
                    result_images.append(result)
@@ -442,6 +449,7 @@ class FaceSwapScript(scripts.Script):
                source_hash_check=self.source_hash_check,
                target_hash_check=self.target_hash_check,
                device=self.device,
                mask_face=self.mask_face,
            )
            try:
                pp = scripts_postprocessing.PostprocessedImage(result)
@@ -468,6 +476,8 @@ class FaceSwapScriptExtras(scripts_postprocessing.ScriptPostprocessing):
            with gr.Column():
                img = gr.Image(type="pil")
                enable = gr.Checkbox(False, label="Enable", info=f"The Fast and Simple FaceSwap Extension - {version_flag}")
                mask_face = gr.Checkbox(False, label="Mask Faces", info="Attempt to mask only the faces, eliminating pixelation around the face contours")
                gr.Markdown("Source Image (above):")
                with gr.Row():
                    source_faces_index = gr.Textbox(
@@ -582,6 +592,7 @@ class FaceSwapScriptExtras(scripts_postprocessing.ScriptPostprocessing):
            'gender_target': gender_target,
            'codeformer_weight': codeformer_weight,
            'device': device,
            'mask_face': mask_face,
        }
        return args
@@ -631,6 +642,7 @@ class FaceSwapScriptExtras(scripts_postprocessing.ScriptPostprocessing):
        self.gender_target = args['gender_target']
        self.codeformer_weight = args['codeformer_weight']
        self.device = args['device']
        self.mask_face = args['mask_face']
        if self.gender_source is None or self.gender_source == "No":
            self.gender_source = 0
        if self.gender_target is None or self.gender_target == "No":
@@ -669,6 +681,7 @@ class FaceSwapScriptExtras(scripts_postprocessing.ScriptPostprocessing):
            source_hash_check=True,
            target_hash_check=True,
            device=self.device,
            mask_face=self.mask_face,
        )
        try:
            pp.info["ReActor"] = True

scripts/reactor_swapper.py

@@ -2,13 +2,16 @@ import copy
import os
from dataclasses import dataclass
from typing import List, Union
from typing import Tuple

import cv2
import numpy as np
from insightface.app.common import Face as IFace
from numpy import uint8
from PIL import Image, ImageDraw

from scripts.inferencers.bisenet_mask_generator import BiSeNetMaskGenerator
from scripts.entities.face import Face
from scripts.entities.rect import Rect
import insightface
from torchvision.transforms.functional import to_pil_image
from scripts.reactor_helpers import get_image_md5hash, get_Device
from modules.face_restoration import FaceRestoration

try:  # A1111
@@ -18,6 +21,7 @@ except:  # SD.Next
    from modules.upscaler import UpscalerData
from modules.shared import state
from scripts.reactor_logger import logger

try:
    from modules.paths_internal import models_path
except:
@@ -28,16 +32,6 @@ except:
import warnings

PILImage = Image.Image
CV2ImgU8 = np.ndarray[int, np.dtype[uint8]]
Face = IFace
BoxCoords = Tuple[int, int, int, int]


class Gender(Enum):
    AUTO = -1
    FEMALE = 0
    MALE = 1


np.warnings = warnings
np.warnings.filterwarnings('ignore')
@@ -86,8 +80,9 @@ def check_process_halt(msgforced: bool = False):
FS_MODEL = None
MASK_MODEL = None
CURRENT_FS_MODEL_PATH = None
CURRENT_MASK_MODEL_PATH = None

ANALYSIS_MODEL = None

SOURCE_FACES = None
@@ -115,6 +110,8 @@ def getFaceSwapModel(model_path: str):
    return FS_MODEL


def restore_face(image: Image, enhancement_options: EnhancementOptions):
    result_image = image
@@ -178,7 +175,28 @@ def enhance_image(image: Image, enhancement_options: EnhancementOptions):
        result_image = restore_face(result_image, enhancement_options)

    return result_image


def enhance_image_and_mask(image: Image.Image, enhancement_options: EnhancementOptions, target_img_orig: Image.Image, entire_mask_image: Image.Image) -> Image.Image:
    result_image = image

    if check_process_halt(msgforced=True):
        return result_image

    if enhancement_options.do_restore_first:
        result_image = restore_face(result_image, enhancement_options)
        # composite the restored face back over the untouched original before upscaling,
        # so only the masked region differs from the target image
        result_image = Image.composite(result_image, target_img_orig, entire_mask_image)
        result_image = upscale_image(result_image, enhancement_options)
    else:
        result_image = upscale_image(result_image, enhancement_options)
        entire_mask_image = Image.fromarray(cv2.resize(np.array(entire_mask_image), result_image.size, interpolation=cv2.INTER_AREA)).convert("L")
        target_img_orig = target_img_orig.resize(result_image.size)  # keep all three images the same size for composite
        result_image = Image.composite(result_image, target_img_orig, entire_mask_image)
        result_image = restore_face(result_image, enhancement_options)

    return result_image


def get_gender(face, face_index):
    gender = [
        x.sex
@@ -292,6 +310,7 @@ def swap_face(
    source_hash_check: bool = True,
    target_hash_check: bool = False,
    device: str = "CPU",
    mask_face: bool = False,
):
    global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH, PROVIDERS
    result_image = target_img
@@ -318,7 +337,8 @@ def swap_face(
        source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
        target_img_orig = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)  # pristine copy, taken before swapping
        target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)
        entire_mask_image = np.zeros_like(np.array(target_img))

        output: List = []
        output_info: str = ""
        swapped = 0
@@ -421,7 +441,12 @@ def swap_face(
                        if target_face is not None and wrong_gender == 0:
                            logger.status("Swapping Source into Target")
                            swapped_image = face_swapper.get(result, target_face, source_face)
                            if mask_face:
                                # composite only the parsed face region back onto the
                                # pre-swap frame and accumulate it into entire_mask_image
                                result = apply_face_mask(swapped_image=swapped_image, target_image=result, target_face=target_face, entire_mask_image=entire_mask_image)
                            else:
                                result = swapped_image
                            swapped += 1

                        elif wrong_gender == 1:
@@ -455,8 +480,13 @@ def swap_face(
            result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))

            if enhancement_options is not None and swapped > 0:
                if mask_face and entire_mask_image is not None:
                    # target_img_orig is BGR, so convert it before compositing with the RGB result
                    result_image = enhance_image_and_mask(result_image, enhancement_options, Image.fromarray(cv2.cvtColor(target_img_orig, cv2.COLOR_BGR2RGB)), Image.fromarray(entire_mask_image).convert("L"))
                else:
                    result_image = enhance_image(result_image, enhancement_options)
            elif mask_face and entire_mask_image is not None and swapped > 0:
                result_image = Image.composite(result_image, Image.fromarray(cv2.cvtColor(target_img_orig, cv2.COLOR_BGR2RGB)), Image.fromarray(entire_mask_image).convert("L"))

        else:
            logger.status("No source face(s) in the provided Index")
    else:
@@ -465,99 +495,158 @@
    return result_image


def merge_images_with_mask(
    image1: CV2ImgU8, image2: CV2ImgU8, mask: CV2ImgU8
) -> CV2ImgU8:
    """
    Merges two images using a given mask. The regions where the mask is set will be replaced with the corresponding
    areas of the second image.

    Args:
        image1 (CV2Img): The base image, which must have the same shape as image2.
        image2 (CV2Img): The image to be merged, which must have the same shape as image1.
        mask (CV2Img): A binary mask specifying the regions to be merged. The mask shape should match image1's first two dimensions.

    Returns:
        CV2Img: The merged image.

    Raises:
        ValueError: If the shapes of the images and mask do not match.
    """
    if image1.shape != image2.shape or image1.shape[:2] != mask.shape:
        raise ValueError("Images and mask should have the same shape")

    mask = mask.astype(np.uint8)
    masked_region = cv2.bitwise_and(image2, image2, mask=mask)
    inverse_mask = cv2.bitwise_not(mask)
    empty_region = cv2.bitwise_and(image1, image1, mask=inverse_mask)
    merged_image = cv2.add(empty_region, masked_region)
    return merged_image


def apply_face_mask(swapped_image: np.ndarray, target_image: np.ndarray, target_face, entire_mask_image: np.ndarray) -> np.ndarray:
    logger.status("Masking Face")
    mask_generator = BiSeNetMaskGenerator()
    face = Face(target_image, Rect.from_ndarray(np.array(target_face.bbox)), 1.6, 512, "")
    face_image = np.array(face.image)
    process_face_image(face)
    face_area_on_image = face.face_area_on_image
    mask = mask_generator.generate_mask(face_image, face_area_on_image=face_area_on_image, affected_areas=["Face"], mask_size=0, use_minimal_area=True)
    mask = cv2.blur(mask, (12, 12))  # soften the mask edge so the composite seam is invisible
    # entire_mask_image = np.zeros_like(target_image)
    larger_mask = cv2.resize(mask, dsize=(face.width, face.height))
    entire_mask_image[
        face.top : face.bottom,
        face.left : face.right,
    ] = larger_mask

    result = Image.composite(Image.fromarray(swapped_image), Image.fromarray(target_image), Image.fromarray(entire_mask_image).convert("L"))
    return np.array(result)


def correct_face_tilt(angle: float) -> bool:
    angle = abs(angle)
    if angle > 180:
        angle = 360 - angle
    return angle > 40


def _dilate(arr: np.ndarray, value: int) -> np.ndarray:
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
    return cv2.dilate(arr, kernel, iterations=1)


def _erode(arr: np.ndarray, value: int) -> np.ndarray:
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
    return cv2.erode(arr, kernel, iterations=1)


colors = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (255, 255, 255),
    (128, 0, 0),
    (0, 128, 0),
    (128, 128, 0),
    (0, 0, 128),
    (0, 128, 128),
]


def color_generator(colors):
    while True:
        for color in colors:
            yield color


color_iter = color_generator(colors)


def process_face_image(
    face: Face,
    **kwargs,
) -> Image:
    # debug overlay: tint the crop and outline the detected face area and landmarks
    image = np.array(face.image)
    overlay = image.copy()
    cv2.rectangle(overlay, (0, 0), (image.shape[1], image.shape[0]), next(color_iter), -1)
    l, t, r, b = face.face_area_on_image
    cv2.rectangle(overlay, (l, t), (r, b), (0, 0, 0), 10)
    if face.landmarks_on_image is not None:
        for landmark in face.landmarks_on_image:
            cv2.circle(overlay, (int(landmark.x), int(landmark.y)), 6, (0, 0, 0), 10)
    alpha = 0.3
    output = cv2.addWeighted(image, 1 - alpha, overlay, alpha, 0)
    return Image.fromarray(output)


def dilate_erode(img: Image.Image, value: int) -> Image.Image:
    """
    The dilate_erode function takes an image and a value.
    If the value is positive, it dilates the image by that amount.
    If the value is negative, it erodes the image by that amount.

    Parameters
    ----------
    img: PIL.Image.Image
        the image to be processed
    value: int
        kernel size of dilation or erosion

    Returns
    -------
    PIL.Image.Image
        The image that has been dilated or eroded
    """
    if value == 0:
        return img

    arr = np.array(img)
    arr = _dilate(arr, value) if value > 0 else _erode(arr, -value)

    return Image.fromarray(arr)


def erode_mask(mask: CV2ImgU8, kernel_size: int = 3, iterations: int = 1) -> CV2ImgU8:
    """
    Erodes a binary mask using a given kernel size and number of iterations.

    Args:
        mask (CV2Img): The binary mask to erode.
        kernel_size (int, optional): The size of the kernel. Default is 3.
        iterations (int, optional): The number of erosion iterations. Default is 1.

    Returns:
        CV2Img: The eroded mask.
    """
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    eroded_mask = cv2.erode(mask, kernel, iterations=iterations)
    return eroded_mask


def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]:
    """
    Parameters
    ----------
    masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
        The device can be CUDA, but `to_pil_image` takes care of that.
    shape: tuple[int, int]
        (width, height) of the original image
    """
    n = masks.shape[0]
    return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]


def apply_gaussian_blur(
    mask: CV2ImgU8, kernel_size: Tuple[int, int] = (5, 5), sigma_x: int = 0
) -> CV2ImgU8:
    """
    Applies a Gaussian blur to a mask.

    Args:
        mask (CV2Img): The mask to blur.
        kernel_size (tuple, optional): The size of the kernel, e.g. (5, 5). Default is (5, 5).
        sigma_x (int, optional): The standard deviation in the X direction. Default is 0.

    Returns:
        CV2Img: The blurred mask.
    """
    blurred_mask = cv2.GaussianBlur(mask, kernel_size, sigma_x)
    return blurred_mask


def create_mask_from_bbox(
    bboxes: list[list[float]], shape: tuple[int, int]
) -> list[Image.Image]:
    """
    Parameters
    ----------
    bboxes: list[list[float]]
        list of [x1, y1, x2, y2] bounding boxes
    shape: tuple[int, int]
        shape of the image (width, height)

    Returns
    -------
    masks: list[Image.Image]
        A list of masks
    """
    masks = []
    for bbox in bboxes:
        mask = Image.new("L", shape, 0)
        mask_draw = ImageDraw.Draw(mask)
        mask_draw.rectangle(bbox, fill=255)
        masks.append(mask)
    return masks


def rotate_image(image: Image, angle: float) -> Image:
    if angle == 0:
        return image
    return Image.fromarray(rotate_array(np.array(image), angle))


def dilate_mask(mask: CV2ImgU8, kernel_size: int = 5, iterations: int = 1) -> CV2ImgU8:
    """
    Dilates a binary mask using a given kernel size and number of iterations.

    Args:
        mask (CV2Img): The binary mask to dilate.
        kernel_size (int, optional): The size of the kernel. Default is 5.
        iterations (int, optional): The number of dilation iterations. Default is 1.

    Returns:
        CV2Img: The dilated mask.
    """
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    dilated_mask = cv2.dilate(mask, kernel, iterations=iterations)
    return dilated_mask


def rotate_array(image: np.ndarray, angle: float) -> np.ndarray:
    if angle == 0:
        return image

    h, w = image.shape[:2]
    center = (w // 2, h // 2)

    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    return cv2.warpAffine(image, M, (w, h))


def get_face_mask(aimg: CV2ImgU8, bgr_fake: CV2ImgU8) -> CV2ImgU8:
    """
    Generates a face mask by performing bitwise OR on two face masks and then dilating the result.

    Args:
        aimg (CV2Img): Input image for generating the first face mask.
        bgr_fake (CV2Img): Input image for generating the second face mask.

    Returns:
        CV2Img: The combined and dilated face mask.
    """
    mask1 = generate_face_mask(aimg, device=shared.device)
    mask2 = generate_face_mask(bgr_fake, device=shared.device)
    mask = dilate_mask(cv2.bitwise_or(mask1, mask2))
    return mask
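A worked number for the tilt logic above: get_angle estimates roll from the eye-to-eye vector, and correct_face_tilt only reports a correction for angles above 40 degrees. For eyes at (100, 120) and (160, 140), the estimate is arctan(20/60) ≈ 18.4°, so such a face would pass through without rotation.

import numpy as np

eye1, eye2 = (100, 120), (160, 140)
dx, dy = eye2[0] - eye1[0], eye2[1] - eye1[1]
angle = np.arctan(dy / dx) * 180 / np.pi
print(round(float(angle), 1))   # 18.4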