summaryrefslogtreecommitdiffstats
path: root/reactor_patcher.py
diff options
context:
space:
mode:
authorGrafting Rayman <156515434+GraftingRayman@users.noreply.github.com>2025-01-17 11:06:44 +0000
committerGitHub <noreply@github.com>2025-01-17 11:06:44 +0000
commite6bd5af6a8e306a1cdef63402a77a980a04ad6e1 (patch)
treed0732226bbc22feedad9e834b2218d7d0b0eff54 /reactor_patcher.py
parent495ffc4777522e40941753e3b1b79c02f84b25b4 (diff)
downloadComfyui-reactor-node-main.tar.gz
Add files via uploadHEADmain
Diffstat (limited to 'reactor_patcher.py')
-rw-r--r--reactor_patcher.py135
1 files changed, 135 insertions, 0 deletions
diff --git a/reactor_patcher.py b/reactor_patcher.py
new file mode 100644
index 0000000..1818def
--- /dev/null
+++ b/reactor_patcher.py
@@ -0,0 +1,135 @@
+import os.path as osp
+import glob
+import logging
+import insightface
+from insightface.model_zoo.model_zoo import ModelRouter, PickableInferenceSession
+from insightface.model_zoo.retinaface import RetinaFace
+from insightface.model_zoo.landmark import Landmark
+from insightface.model_zoo.attribute import Attribute
+from insightface.model_zoo.inswapper import INSwapper
+from insightface.model_zoo.arcface_onnx import ArcFaceONNX
+from insightface.app import FaceAnalysis
+from insightface.utils import DEFAULT_MP_NAME, ensure_available
+from insightface.model_zoo import model_zoo
+import onnxruntime
+import onnx
+from onnx import numpy_helper
+from scripts.reactor_logger import logger
+
+
+def patched_get_model(self, **kwargs):
+ session = PickableInferenceSession(self.onnx_file, **kwargs)
+ inputs = session.get_inputs()
+ input_cfg = inputs[0]
+ input_shape = input_cfg.shape
+ outputs = session.get_outputs()
+
+ if len(outputs) >= 5:
+ return RetinaFace(model_file=self.onnx_file, session=session)
+ elif input_shape[2] == 192 and input_shape[3] == 192:
+ return Landmark(model_file=self.onnx_file, session=session)
+ elif input_shape[2] == 96 and input_shape[3] == 96:
+ return Attribute(model_file=self.onnx_file, session=session)
+ elif len(inputs) == 2 and input_shape[2] == 128 and input_shape[3] == 128:
+ return INSwapper(model_file=self.onnx_file, session=session)
+ elif input_shape[2] == input_shape[3] and input_shape[2] >= 112 and input_shape[2] % 16 == 0:
+ return ArcFaceONNX(model_file=self.onnx_file, session=session)
+ else:
+ return None
+
+
+def patched_faceanalysis_init(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs):
+ onnxruntime.set_default_logger_severity(3)
+ self.models = {}
+ self.model_dir = ensure_available('models', name, root=root)
+ onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx'))
+ onnx_files = sorted(onnx_files)
+ for onnx_file in onnx_files:
+ model = model_zoo.get_model(onnx_file, **kwargs)
+ if model is None:
+ print('model not recognized:', onnx_file)
+ elif allowed_modules is not None and model.taskname not in allowed_modules:
+ print('model ignore:', onnx_file, model.taskname)
+ del model
+ elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules):
+ self.models[model.taskname] = model
+ else:
+ print('duplicated model task type, ignore:', onnx_file, model.taskname)
+ del model
+ assert 'detection' in self.models
+ self.det_model = self.models['detection']
+
+
+def patched_faceanalysis_prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
+ self.det_thresh = det_thresh
+ assert det_size is not None
+ self.det_size = det_size
+ for taskname, model in self.models.items():
+ if taskname == 'detection':
+ model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh)
+ else:
+ model.prepare(ctx_id)
+
+
+def patched_inswapper_init(self, model_file=None, session=None):
+ self.model_file = model_file
+ self.session = session
+ model = onnx.load(self.model_file)
+ graph = model.graph
+ self.emap = numpy_helper.to_array(graph.initializer[-1])
+ self.input_mean = 0.0
+ self.input_std = 255.0
+ if self.session is None:
+ self.session = onnxruntime.InferenceSession(self.model_file, None)
+ inputs = self.session.get_inputs()
+ self.input_names = []
+ for inp in inputs:
+ self.input_names.append(inp.name)
+ outputs = self.session.get_outputs()
+ output_names = []
+ for out in outputs:
+ output_names.append(out.name)
+ self.output_names = output_names
+ assert len(self.output_names) == 1
+ input_cfg = inputs[0]
+ input_shape = input_cfg.shape
+ self.input_shape = input_shape
+ self.input_size = tuple(input_shape[2:4][::-1])
+
+
+def pathced_retinaface_prepare(self, ctx_id, **kwargs):
+ if ctx_id<0:
+ self.session.set_providers(['CPUExecutionProvider'])
+ nms_thresh = kwargs.get('nms_thresh', None)
+ if nms_thresh is not None:
+ self.nms_thresh = nms_thresh
+ det_thresh = kwargs.get('det_thresh', None)
+ if det_thresh is not None:
+ self.det_thresh = det_thresh
+ input_size = kwargs.get('input_size', None)
+ if input_size is not None and self.input_size is None:
+ self.input_size = input_size
+
+
+def patch_insightface(get_model, faceanalysis_init, faceanalysis_prepare, inswapper_init, retinaface_prepare):
+ insightface.model_zoo.model_zoo.ModelRouter.get_model = get_model
+ insightface.app.FaceAnalysis.__init__ = faceanalysis_init
+ insightface.app.FaceAnalysis.prepare = faceanalysis_prepare
+ insightface.model_zoo.inswapper.INSwapper.__init__ = inswapper_init
+ insightface.model_zoo.retinaface.RetinaFace.prepare = retinaface_prepare
+
+
+original_functions = [ModelRouter.get_model, FaceAnalysis.__init__, FaceAnalysis.prepare, INSwapper.__init__, RetinaFace.prepare]
+patched_functions = [patched_get_model, patched_faceanalysis_init, patched_faceanalysis_prepare, patched_inswapper_init, pathced_retinaface_prepare]
+
+
+def apply_patch(console_log_level):
+ if console_log_level == 0:
+ patch_insightface(*patched_functions)
+ logger.setLevel(logging.WARNING)
+ elif console_log_level == 1:
+ patch_insightface(*patched_functions)
+ logger.setLevel(logging.STATUS)
+ elif console_log_level == 2:
+ patch_insightface(*original_functions)
+ logger.setLevel(logging.INFO)