diff --git a/scripts/console_log_patch.py b/scripts/console_log_patch.py index 362dec6..ca4f8e9 100644 --- a/scripts/console_log_patch.py +++ b/scripts/console_log_patch.py @@ -2,7 +2,7 @@ import os.path as osp import glob import logging import insightface -from insightface.model_zoo.model_zoo import ModelRouter, PickableInferenceSession +from insightface.model_zoo.model_zoo import ModelRouter, PickableInferenceSession, get_default_providers from insightface.model_zoo.retinaface import RetinaFace from insightface.model_zoo.landmark import Landmark from insightface.model_zoo.attribute import Attribute @@ -97,15 +97,20 @@ def patched_inswapper_init(self, model_file=None, session=None): self.input_size = tuple(input_shape[2:4][::-1]) -def patch_insightface(get_model, faceanalysis_init, faceanalysis_prepare, inswapper_init): +def patched_get_default_providers(): + return ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'] + + +def patch_insightface(get_default_providers, get_model, faceanalysis_init, faceanalysis_prepare, inswapper_init): + insightface.model_zoo.model_zoo.get_default_providers = get_default_providers insightface.model_zoo.model_zoo.ModelRouter.get_model = get_model insightface.app.FaceAnalysis.__init__ = faceanalysis_init insightface.app.FaceAnalysis.prepare = faceanalysis_prepare insightface.model_zoo.inswapper.INSwapper.__init__ = inswapper_init -original_functions = [ModelRouter.get_model, FaceAnalysis.__init__, FaceAnalysis.prepare, INSwapper.__init__] -patched_functions = [patched_get_model, patched_faceanalysis_init, patched_faceanalysis_prepare, patched_inswapper_init] +original_functions = [patched_get_default_providers, ModelRouter.get_model, FaceAnalysis.__init__, FaceAnalysis.prepare, INSwapper.__init__] +patched_functions = [patched_get_default_providers, patched_get_model, patched_faceanalysis_init, patched_faceanalysis_prepare, patched_inswapper_init] def apply_logging_patch(console_logging_level):