From 1b706db7677cb591b21695bf885035560c0c29b6 Mon Sep 17 00:00:00 2001 From: Art Gourieff <85128026+Gourieff@users.noreply.github.com> Date: Sun, 7 Apr 2024 01:17:13 +0700 Subject: [PATCH] FIX: Insightface default_providers patch --- scripts/console_log_patch.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/scripts/console_log_patch.py b/scripts/console_log_patch.py index 362dec6..ca4f8e9 100644 --- a/scripts/console_log_patch.py +++ b/scripts/console_log_patch.py @@ -2,7 +2,7 @@ import os.path as osp import glob import logging import insightface -from insightface.model_zoo.model_zoo import ModelRouter, PickableInferenceSession +from insightface.model_zoo.model_zoo import ModelRouter, PickableInferenceSession, get_default_providers from insightface.model_zoo.retinaface import RetinaFace from insightface.model_zoo.landmark import Landmark from insightface.model_zoo.attribute import Attribute @@ -97,15 +97,20 @@ def patched_inswapper_init(self, model_file=None, session=None): self.input_size = tuple(input_shape[2:4][::-1]) -def patch_insightface(get_model, faceanalysis_init, faceanalysis_prepare, inswapper_init): +def patched_get_default_providers(): + return ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'] + + +def patch_insightface(get_default_providers, get_model, faceanalysis_init, faceanalysis_prepare, inswapper_init): + insightface.model_zoo.model_zoo.get_default_providers = get_default_providers insightface.model_zoo.model_zoo.ModelRouter.get_model = get_model insightface.app.FaceAnalysis.__init__ = faceanalysis_init insightface.app.FaceAnalysis.prepare = faceanalysis_prepare insightface.model_zoo.inswapper.INSwapper.__init__ = inswapper_init -original_functions = [ModelRouter.get_model, FaceAnalysis.__init__, FaceAnalysis.prepare, INSwapper.__init__] -patched_functions = [patched_get_model, patched_faceanalysis_init, patched_faceanalysis_prepare, patched_inswapper_init] +original_functions = [patched_get_default_providers, ModelRouter.get_model, FaceAnalysis.__init__, FaceAnalysis.prepare, INSwapper.__init__] +patched_functions = [patched_get_default_providers, patched_get_model, patched_faceanalysis_init, patched_faceanalysis_prepare, patched_inswapper_init] def apply_logging_patch(console_logging_level):