diff --git a/scripts/reactor_api.py b/scripts/reactor_api.py index 5fb9969..f7d3553 100644 --- a/scripts/reactor_api.py +++ b/scripts/reactor_api.py @@ -1,7 +1,7 @@ ''' Thanks SpenserCai for the original version of the roop api script ----------------------------------- ---- ReActor External API v1.0.7 --- +--- ReActor External API v1.0.8a --- ----------------------------------- ''' import os, glob @@ -13,6 +13,11 @@ from fastapi import FastAPI, Body # import base64 # import numpy as np # import cv2 +import asyncio +from concurrent.futures import ThreadPoolExecutor +# from concurrent.futures.process import ProcessPoolExecutor +# from contextlib import asynccontextmanager +# import multiprocessing # from modules.api.models import * from modules import scripts, shared @@ -24,14 +29,28 @@ from scripts.reactor_swapper import EnhancementOptions, swap_face, DetectionOpti from scripts.reactor_logger import logger from scripts.reactor_helpers import get_facemodels -# XYZ init: -from scripts.reactor_xyz import run -try: - import modules.script_callbacks as script_callbacks - script_callbacks.on_before_ui(run) - # script_callbacks.on_app_started(reactor_api) -except: - pass + +# @asynccontextmanager +# async def lifespan(app: FastAPI): +# app.state.executor = ProcessPoolExecutor(max_workers=4) +# yield +# app.state.executor.shutdown() + +# app = FastAPI(lifespan=lifespan) + +# def run_app(a: FastAPI): +# global app +# a = app +# return a + +# _executor_tp = ThreadPoolExecutor(max_workers=8) +# def entry_point(): +# _executor_pp = ProcessPoolExecutor(max_workers=8) +# pool = multiprocessing.Pool(4) + +async def run_event(app, fn, *args): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(app.state.executor, fn, *args) def default_file_path(): @@ -82,6 +101,7 @@ def get_full_model(model_name): # raise HTTPException(status_code=500, detail="Invalid encoded image") from e def reactor_api(_: gr.Blocks, app: FastAPI): + app.state.executor = ThreadPoolExecutor(max_workers=8) @app.post("/reactor/image") async def reactor_image( source_image: str = Body("",title="Source Face Image"), @@ -131,22 +151,26 @@ def reactor_api(_: gr.Blocks, app: FastAPI): use_model = get_full_model(model) if use_model is None: Exception("Model not found") - result = swap_face(s_image, t_image, use_model, sf_index, f_index, up_options, gender_s, gender_t, True, True, device, mask_face, select_source, face_model, source_folder, None, random_image,det_options) - result_img = result[0] + + args = [s_image, t_image, use_model, sf_index, f_index, up_options, gender_s, gender_t, True, True, device, mask_face, select_source, face_model, source_folder, None, random_image,det_options] + # result,_,_ = pool.map(swap_face, *args) + result,_,_ = await run_event(app,swap_face,*args) + # result,_,_ = swap_face(s_image, t_image, use_model, sf_index, f_index, up_options, gender_s, gender_t, True, True, device, mask_face, select_source, face_model, source_folder, None, random_image,det_options) if alpha is not None: - result_img = result_img.convert("RGBA") - result_img.putalpha(alpha) + result = result.convert("RGBA") + result.putalpha(alpha) if save_to_file == 1: if result_file_path == "": result_file_path = default_file_path() try: - result_img.save(result_file_path, format='PNG') + file_format = os.path.split(result_file_path)[1].split(".")[1] + result.save(result_file_path, format=file_format) logger.status("Result has been saved to: %s", result_file_path) except Exception as e: logger.error("Error while saving result: %s",e) - return {"image": api.encode_pil_to_base64(result_img)} + return {"image": api.encode_pil_to_base64(result)} @app.get("/reactor/models") async def reactor_models(): @@ -165,7 +189,6 @@ def reactor_api(_: gr.Blocks, app: FastAPI): try: import modules.script_callbacks as script_callbacks - script_callbacks.on_app_started(reactor_api) except: pass diff --git a/scripts/reactor_xyz.py b/scripts/reactor_xyz.py index 75c7e57..0b101f3 100644 --- a/scripts/reactor_xyz.py +++ b/scripts/reactor_xyz.py @@ -10,6 +10,7 @@ from scripts.reactor_helpers import ( get_facemodels ) + # xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module def find_module(module_names): @@ -84,3 +85,10 @@ def run(): xyz_grid = find_module("xyz_grid.py, xy_grid.py") if xyz_grid: add_axis_options(xyz_grid) + +# XYZ init: +try: + import modules.script_callbacks as script_callbacks + script_callbacks.on_before_ui(run) +except: + pass