FIX: Insightface default_providers patch
This commit is contained in:
parent
8651bc2639
commit
1b706db767
@ -2,7 +2,7 @@ import os.path as osp
|
||||
import glob
|
||||
import logging
|
||||
import insightface
|
||||
from insightface.model_zoo.model_zoo import ModelRouter, PickableInferenceSession
|
||||
from insightface.model_zoo.model_zoo import ModelRouter, PickableInferenceSession, get_default_providers
|
||||
from insightface.model_zoo.retinaface import RetinaFace
|
||||
from insightface.model_zoo.landmark import Landmark
|
||||
from insightface.model_zoo.attribute import Attribute
|
||||
@ -97,15 +97,20 @@ def patched_inswapper_init(self, model_file=None, session=None):
|
||||
self.input_size = tuple(input_shape[2:4][::-1])
|
||||
|
||||
|
||||
def patch_insightface(get_model, faceanalysis_init, faceanalysis_prepare, inswapper_init):
|
||||
def patched_get_default_providers():
|
||||
return ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
|
||||
|
||||
def patch_insightface(get_default_providers, get_model, faceanalysis_init, faceanalysis_prepare, inswapper_init):
|
||||
insightface.model_zoo.model_zoo.get_default_providers = get_default_providers
|
||||
insightface.model_zoo.model_zoo.ModelRouter.get_model = get_model
|
||||
insightface.app.FaceAnalysis.__init__ = faceanalysis_init
|
||||
insightface.app.FaceAnalysis.prepare = faceanalysis_prepare
|
||||
insightface.model_zoo.inswapper.INSwapper.__init__ = inswapper_init
|
||||
|
||||
|
||||
original_functions = [ModelRouter.get_model, FaceAnalysis.__init__, FaceAnalysis.prepare, INSwapper.__init__]
|
||||
patched_functions = [patched_get_model, patched_faceanalysis_init, patched_faceanalysis_prepare, patched_inswapper_init]
|
||||
original_functions = [patched_get_default_providers, ModelRouter.get_model, FaceAnalysis.__init__, FaceAnalysis.prepare, INSwapper.__init__]
|
||||
patched_functions = [patched_get_default_providers, patched_get_model, patched_faceanalysis_init, patched_faceanalysis_prepare, patched_inswapper_init]
|
||||
|
||||
|
||||
def apply_logging_patch(console_logging_level):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user