From 9989610fa777ff5b0930811a66242744a9101b22 Mon Sep 17 00:00:00 2001 From: Art Gourieff <85128026+Gourieff@users.noreply.github.com> Date: Tue, 12 Mar 2024 20:25:33 +0700 Subject: [PATCH] UPDATE: API RGBA support, GET facemodels show FR: #394 D: #384 --- example/api_example.py | 2 +- scripts/reactor_api.py | 43 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/example/api_example.py b/example/api_example.py index d3bc64a..e69766b 100644 --- a/example/api_example.py +++ b/example/api_example.py @@ -17,7 +17,7 @@ finally: print(im) img_bytes = io.BytesIO() -im.save(img_bytes, format='PNG') +im.save(img_bytes, format='PNG') img_base64 = base64.b64encode(img_bytes.getvalue()).decode('utf-8') # ReActor arguments: diff --git a/scripts/reactor_api.py b/scripts/reactor_api.py index 20bb925..5fb9969 100644 --- a/scripts/reactor_api.py +++ b/scripts/reactor_api.py @@ -1,12 +1,18 @@ ''' Thanks SpenserCai for the original version of the roop api script ----------------------------------- ---- ReActor External API v1.0.6 --- +--- ReActor External API v1.0.7 --- ----------------------------------- ''' import os, glob from datetime import datetime, date from fastapi import FastAPI, Body +# from fastapi.exceptions import HTTPException +# from io import BytesIO +# from PIL import Image +# import base64 +# import numpy as np +# import cv2 # from modules.api.models import * from modules import scripts, shared @@ -16,6 +22,7 @@ import gradio as gr from scripts.reactor_swapper import EnhancementOptions, swap_face, DetectionOptions from scripts.reactor_logger import logger +from scripts.reactor_helpers import get_facemodels # XYZ init: from scripts.reactor_xyz import run @@ -61,6 +68,19 @@ def get_full_model(model_name): return model return None +# def decode_base64_to_image_rgba(encoding): +# if encoding.startswith("data:image/"): +# encoding = encoding.split(";")[1].split(",")[1] +# try: +# im_bytes = base64.b64decode(encoding) +# im_arr = np.frombuffer(im_bytes, dtype=np.uint8) # im_arr is one-dim Numpy array +# img = cv2.imdecode(im_arr, flags=cv2.IMREAD_UNCHANGED) +# img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA) +# image = Image.fromarray(img, mode="RGBA") +# return image +# except Exception as e: +# raise HTTPException(status_code=500, detail="Invalid encoded image") from e + def reactor_api(_: gr.Blocks, app: FastAPI): @app.post("/reactor/image") async def reactor_image( @@ -92,6 +112,12 @@ def reactor_api(_: gr.Blocks, app: FastAPI): ): s_image = api.decode_base64_to_image(source_image) if select_source == 0 else None t_image = api.decode_base64_to_image(target_image) + + if t_image.mode == 'RGBA': + _, _, _, alpha = t_image.split() + else: + alpha = None + sf_index = source_faces_index f_index = face_index gender_s = gender_source @@ -106,15 +132,21 @@ def reactor_api(_: gr.Blocks, app: FastAPI): if use_model is None: Exception("Model not found") result = swap_face(s_image, t_image, use_model, sf_index, f_index, up_options, gender_s, gender_t, True, True, device, mask_face, select_source, face_model, source_folder, None, random_image,det_options) + result_img = result[0] + + if alpha is not None: + result_img = result_img.convert("RGBA") + result_img.putalpha(alpha) + if save_to_file == 1: if result_file_path == "": result_file_path = default_file_path() try: - result[0].save(result_file_path, format='PNG') + result_img.save(result_file_path, format='PNG') logger.status("Result has been saved to: %s", result_file_path) except Exception as e: logger.error("Error while saving result: %s",e) - return {"image": api.encode_pil_to_base64(result[0])} + return {"image": api.encode_pil_to_base64(result_img)} @app.get("/reactor/models") async def reactor_models(): @@ -125,6 +157,11 @@ def reactor_api(_: gr.Blocks, app: FastAPI): async def reactor_upscalers(): names = [upscaler.name for upscaler in shared.sd_upscalers] return {"upscalers": names} + + @app.get("/reactor/facemodels") + async def reactor_facemodels(): + facemodels = [os.path.split(model)[1].split(".")[0] for model in get_facemodels()] + return {"facemodels": facemodels} try: import modules.script_callbacks as script_callbacks