UPDATE: More flexible face index selection

Merge with small fixes of PR https://github.com/s0md3v/sd-webui-roop/pull/72
Thanks @justinh24 for the idea and implementation
This commit is contained in:
Gourieff 2023-06-30 13:37:40 +07:00
parent cd7f935347
commit d27a203a56
3 changed files with 40 additions and 14 deletions

View File

@ -36,10 +36,15 @@ class FaceSwapScript(scripts.Script):
with gr.Column(): with gr.Column():
img = gr.inputs.Image(type="pil") img = gr.inputs.Image(type="pil")
enable = gr.Checkbox(False, placeholder="enable", label="Enable") enable = gr.Checkbox(False, placeholder="enable", label="Enable")
source_faces_index = gr.Textbox(
value="0",
placeholder="Which face(s) to use as source (comma separated)",
label="Comma separated face number(s) from swap-source image (above)",
)
faces_index = gr.Textbox( faces_index = gr.Textbox(
value="0", value="0",
placeholder="Which face to swap (comma separated), start from 0", placeholder="Which face to swap (comma separated)",
label="Comma separated face number(s)", label="Comma separated face number(s) for target image (result)",
) )
with gr.Row(): with gr.Row():
face_restorer_name = gr.Radio( face_restorer_name = gr.Radio(
@ -90,6 +95,7 @@ class FaceSwapScript(scripts.Script):
return [ return [
img, img,
enable, enable,
source_faces_index,
faces_index, faces_index,
model, model,
face_restorer_name, face_restorer_name,
@ -130,6 +136,7 @@ class FaceSwapScript(scripts.Script):
p: StableDiffusionProcessing, p: StableDiffusionProcessing,
img, img,
enable, enable,
source_faces_index,
faces_index, faces_index,
model, model,
face_restorer_name, face_restorer_name,
@ -149,11 +156,16 @@ class FaceSwapScript(scripts.Script):
self.upscaler_name = upscaler_name self.upscaler_name = upscaler_name
self.swap_in_generated = swap_in_generated self.swap_in_generated = swap_in_generated
self.model = model self.model = model
self.faces_index = { self.source_faces_index = [
int(x) for x in source_faces_index.strip(",").split(",") if x.isnumeric()
]
self.faces_index = [
int(x) for x in faces_index.strip(",").split(",") if x.isnumeric() int(x) for x in faces_index.strip(",").split(",") if x.isnumeric()
} ]
if len(self.source_faces_index) == 0:
self.source_faces_index = [0]
if len(self.faces_index) == 0: if len(self.faces_index) == 0:
self.faces_index = {0} self.faces_index = [0]
if self.enable: if self.enable:
if self.source is not None: if self.source is not None:
if isinstance(p, StableDiffusionProcessingImg2Img) and swap_in_source: if isinstance(p, StableDiffusionProcessingImg2Img) and swap_in_source:
@ -164,6 +176,7 @@ class FaceSwapScript(scripts.Script):
result = swap_face( result = swap_face(
self.source, self.source,
p.init_images[i], p.init_images[i],
source_faces_index=self.source_faces_index,
faces_index=self.faces_index, faces_index=self.faces_index,
model=self.model, model=self.model,
upscale_options=self.upscale_options, upscale_options=self.upscale_options,
@ -183,6 +196,7 @@ class FaceSwapScript(scripts.Script):
result: ImageResult = swap_face( result: ImageResult = swap_face(
self.source, self.source,
image, image,
source_faces_index=self.source_faces_index,
faces_index=self.faces_index, faces_index=self.faces_index,
model=self.model, model=self.model,
upscale_options=self.upscale_options, upscale_options=self.upscale_options,

View File

@ -1,4 +1,4 @@
version_flag = "v0.0.3" version_flag = "v0.0.4"
from scripts.roop_logging import logger from scripts.roop_logging import logger

View File

@ -120,7 +120,8 @@ def swap_face(
source_img: Image.Image, source_img: Image.Image,
target_img: Image.Image, target_img: Image.Image,
model: Union[str, None] = None, model: Union[str, None] = None,
faces_index: Set[int] = {0}, source_faces_index: List[int] = [0],
faces_index: List[int] = [0],
upscale_options: Union[UpscaleOptions, None] = None, upscale_options: Union[UpscaleOptions, None] = None,
) -> ImageResult: ) -> ImageResult:
result_image = target_img result_image = target_img
@ -128,24 +129,35 @@ def swap_face(
if model is not None: if model is not None:
source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR) source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR) target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)
source_face = get_face_single(source_img, face_index=0) source_face = get_face_single(source_img, face_index=source_faces_index[0])
if source_face is not None: if len(source_faces_index) != 0 and len(source_faces_index) != 1 and len(source_faces_index) != len(faces_index):
logger.info(f'Source Faces must have no entries (default=0), one entry, or same number of entries as target faces.')
elif source_face is not None:
result = target_img result = target_img
model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model) model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model)
face_swapper = getFaceSwapModel(model_path) face_swapper = getFaceSwapModel(model_path)
source_face_idx = 0
for face_num in faces_index: for face_num in faces_index:
target_face = get_face_single(target_img, face_index=face_num) if len(source_faces_index) > 1 and source_face_idx > 0:
if target_face is not None: source_face = get_face_single(source_img, face_index=source_faces_index[source_face_idx])
result = face_swapper.get(result, target_face, source_face) source_face_idx += 1
if source_face is not None:
target_face = get_face_single(target_img, face_index=face_num)
if target_face is not None:
result = face_swapper.get(result, target_face, source_face)
else:
logger.info(f"No target face found for {face_num}")
else: else:
logger.info(f"No target face found for {face_num}") logger.info(f"No source face found for face number {source_face_idx}.")
result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
if upscale_options is not None: if upscale_options is not None:
result_image = upscale_image(result_image, upscale_options) result_image = upscale_image(result_image, upscale_options)
else: else:
logger.info("No source face found") logger.info("No source face(s) found")
result_image.save(fn.name) result_image.save(fn.name)
return ImageResult(path=fn.name) return ImageResult(path=fn.name)