UPDATE: async api try-1, XYZ init

This commit is contained in:
Art Gourieff 2024-04-07 01:21:40 +07:00
parent 1b706db767
commit 97598387b1
2 changed files with 47 additions and 16 deletions

View File

@ -1,7 +1,7 @@
''' '''
Thanks SpenserCai for the original version of the roop api script Thanks SpenserCai for the original version of the roop api script
----------------------------------- -----------------------------------
--- ReActor External API v1.0.7 --- --- ReActor External API v1.0.8a ---
----------------------------------- -----------------------------------
''' '''
import os, glob import os, glob
@ -13,6 +13,11 @@ from fastapi import FastAPI, Body
# import base64 # import base64
# import numpy as np # import numpy as np
# import cv2 # import cv2
import asyncio
from concurrent.futures import ThreadPoolExecutor
# from concurrent.futures.process import ProcessPoolExecutor
# from contextlib import asynccontextmanager
# import multiprocessing
# from modules.api.models import * # from modules.api.models import *
from modules import scripts, shared from modules import scripts, shared
@ -24,14 +29,28 @@ from scripts.reactor_swapper import EnhancementOptions, swap_face, DetectionOpti
from scripts.reactor_logger import logger from scripts.reactor_logger import logger
from scripts.reactor_helpers import get_facemodels from scripts.reactor_helpers import get_facemodels
# XYZ init:
from scripts.reactor_xyz import run # @asynccontextmanager
try: # async def lifespan(app: FastAPI):
import modules.script_callbacks as script_callbacks # app.state.executor = ProcessPoolExecutor(max_workers=4)
script_callbacks.on_before_ui(run) # yield
# script_callbacks.on_app_started(reactor_api) # app.state.executor.shutdown()
except:
pass # app = FastAPI(lifespan=lifespan)
# def run_app(a: FastAPI):
# global app
# a = app
# return a
# _executor_tp = ThreadPoolExecutor(max_workers=8)
# def entry_point():
# _executor_pp = ProcessPoolExecutor(max_workers=8)
# pool = multiprocessing.Pool(4)
async def run_event(app, fn, *args):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(app.state.executor, fn, *args)
def default_file_path(): def default_file_path():
@ -82,6 +101,7 @@ def get_full_model(model_name):
# raise HTTPException(status_code=500, detail="Invalid encoded image") from e # raise HTTPException(status_code=500, detail="Invalid encoded image") from e
def reactor_api(_: gr.Blocks, app: FastAPI): def reactor_api(_: gr.Blocks, app: FastAPI):
app.state.executor = ThreadPoolExecutor(max_workers=8)
@app.post("/reactor/image") @app.post("/reactor/image")
async def reactor_image( async def reactor_image(
source_image: str = Body("",title="Source Face Image"), source_image: str = Body("",title="Source Face Image"),
@ -131,22 +151,26 @@ def reactor_api(_: gr.Blocks, app: FastAPI):
use_model = get_full_model(model) use_model = get_full_model(model)
if use_model is None: if use_model is None:
Exception("Model not found") Exception("Model not found")
result = swap_face(s_image, t_image, use_model, sf_index, f_index, up_options, gender_s, gender_t, True, True, device, mask_face, select_source, face_model, source_folder, None, random_image,det_options)
result_img = result[0] args = [s_image, t_image, use_model, sf_index, f_index, up_options, gender_s, gender_t, True, True, device, mask_face, select_source, face_model, source_folder, None, random_image,det_options]
# result,_,_ = pool.map(swap_face, *args)
result,_,_ = await run_event(app,swap_face,*args)
# result,_,_ = swap_face(s_image, t_image, use_model, sf_index, f_index, up_options, gender_s, gender_t, True, True, device, mask_face, select_source, face_model, source_folder, None, random_image,det_options)
if alpha is not None: if alpha is not None:
result_img = result_img.convert("RGBA") result = result.convert("RGBA")
result_img.putalpha(alpha) result.putalpha(alpha)
if save_to_file == 1: if save_to_file == 1:
if result_file_path == "": if result_file_path == "":
result_file_path = default_file_path() result_file_path = default_file_path()
try: try:
result_img.save(result_file_path, format='PNG') file_format = os.path.split(result_file_path)[1].split(".")[1]
result.save(result_file_path, format=file_format)
logger.status("Result has been saved to: %s", result_file_path) logger.status("Result has been saved to: %s", result_file_path)
except Exception as e: except Exception as e:
logger.error("Error while saving result: %s",e) logger.error("Error while saving result: %s",e)
return {"image": api.encode_pil_to_base64(result_img)} return {"image": api.encode_pil_to_base64(result)}
@app.get("/reactor/models") @app.get("/reactor/models")
async def reactor_models(): async def reactor_models():
@ -165,7 +189,6 @@ def reactor_api(_: gr.Blocks, app: FastAPI):
try: try:
import modules.script_callbacks as script_callbacks import modules.script_callbacks as script_callbacks
script_callbacks.on_app_started(reactor_api) script_callbacks.on_app_started(reactor_api)
except: except:
pass pass

View File

@ -10,6 +10,7 @@ from scripts.reactor_helpers import (
get_facemodels get_facemodels
) )
# xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module # xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module
def find_module(module_names): def find_module(module_names):
@ -84,3 +85,10 @@ def run():
xyz_grid = find_module("xyz_grid.py, xy_grid.py") xyz_grid = find_module("xyz_grid.py, xy_grid.py")
if xyz_grid: if xyz_grid:
add_axis_options(xyz_grid) add_axis_options(xyz_grid)
# XYZ init:
try:
import modules.script_callbacks as script_callbacks
script_callbacks.on_before_ui(run)
except:
pass