UPDATE: API RGBA support, GET facemodels show

FR: #394
D: #384
This commit is contained in:
Art Gourieff 2024-03-12 20:25:33 +07:00
parent 0185d7a2af
commit 9989610fa7
2 changed files with 41 additions and 4 deletions

View File

@ -1,12 +1,18 @@
''' '''
Thanks SpenserCai for the original version of the roop api script Thanks SpenserCai for the original version of the roop api script
----------------------------------- -----------------------------------
--- ReActor External API v1.0.6 --- --- ReActor External API v1.0.7 ---
----------------------------------- -----------------------------------
''' '''
import os, glob import os, glob
from datetime import datetime, date from datetime import datetime, date
from fastapi import FastAPI, Body from fastapi import FastAPI, Body
# from fastapi.exceptions import HTTPException
# from io import BytesIO
# from PIL import Image
# import base64
# import numpy as np
# import cv2
# from modules.api.models import * # from modules.api.models import *
from modules import scripts, shared from modules import scripts, shared
@ -16,6 +22,7 @@ import gradio as gr
from scripts.reactor_swapper import EnhancementOptions, swap_face, DetectionOptions from scripts.reactor_swapper import EnhancementOptions, swap_face, DetectionOptions
from scripts.reactor_logger import logger from scripts.reactor_logger import logger
from scripts.reactor_helpers import get_facemodels
# XYZ init: # XYZ init:
from scripts.reactor_xyz import run from scripts.reactor_xyz import run
@ -61,6 +68,19 @@ def get_full_model(model_name):
return model return model
return None return None
# def decode_base64_to_image_rgba(encoding):
# if encoding.startswith("data:image/"):
# encoding = encoding.split(";")[1].split(",")[1]
# try:
# im_bytes = base64.b64decode(encoding)
# im_arr = np.frombuffer(im_bytes, dtype=np.uint8) # im_arr is one-dim Numpy array
# img = cv2.imdecode(im_arr, flags=cv2.IMREAD_UNCHANGED)
# img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
# image = Image.fromarray(img, mode="RGBA")
# return image
# except Exception as e:
# raise HTTPException(status_code=500, detail="Invalid encoded image") from e
def reactor_api(_: gr.Blocks, app: FastAPI): def reactor_api(_: gr.Blocks, app: FastAPI):
@app.post("/reactor/image") @app.post("/reactor/image")
async def reactor_image( async def reactor_image(
@ -92,6 +112,12 @@ def reactor_api(_: gr.Blocks, app: FastAPI):
): ):
s_image = api.decode_base64_to_image(source_image) if select_source == 0 else None s_image = api.decode_base64_to_image(source_image) if select_source == 0 else None
t_image = api.decode_base64_to_image(target_image) t_image = api.decode_base64_to_image(target_image)
if t_image.mode == 'RGBA':
_, _, _, alpha = t_image.split()
else:
alpha = None
sf_index = source_faces_index sf_index = source_faces_index
f_index = face_index f_index = face_index
gender_s = gender_source gender_s = gender_source
@ -106,15 +132,21 @@ def reactor_api(_: gr.Blocks, app: FastAPI):
if use_model is None: if use_model is None:
Exception("Model not found") Exception("Model not found")
result = swap_face(s_image, t_image, use_model, sf_index, f_index, up_options, gender_s, gender_t, True, True, device, mask_face, select_source, face_model, source_folder, None, random_image,det_options) result = swap_face(s_image, t_image, use_model, sf_index, f_index, up_options, gender_s, gender_t, True, True, device, mask_face, select_source, face_model, source_folder, None, random_image,det_options)
result_img = result[0]
if alpha is not None:
result_img = result_img.convert("RGBA")
result_img.putalpha(alpha)
if save_to_file == 1: if save_to_file == 1:
if result_file_path == "": if result_file_path == "":
result_file_path = default_file_path() result_file_path = default_file_path()
try: try:
result[0].save(result_file_path, format='PNG') result_img.save(result_file_path, format='PNG')
logger.status("Result has been saved to: %s", result_file_path) logger.status("Result has been saved to: %s", result_file_path)
except Exception as e: except Exception as e:
logger.error("Error while saving result: %s",e) logger.error("Error while saving result: %s",e)
return {"image": api.encode_pil_to_base64(result[0])} return {"image": api.encode_pil_to_base64(result_img)}
@app.get("/reactor/models") @app.get("/reactor/models")
async def reactor_models(): async def reactor_models():
@ -126,6 +158,11 @@ def reactor_api(_: gr.Blocks, app: FastAPI):
names = [upscaler.name for upscaler in shared.sd_upscalers] names = [upscaler.name for upscaler in shared.sd_upscalers]
return {"upscalers": names} return {"upscalers": names}
@app.get("/reactor/facemodels")
async def reactor_facemodels():
facemodels = [os.path.split(model)[1].split(".")[0] for model in get_facemodels()]
return {"facemodels": facemodels}
try: try:
import modules.script_callbacks as script_callbacks import modules.script_callbacks as script_callbacks