From e6bd5af6a8e306a1cdef63402a77a980a04ad6e1 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:06:44 +0000 Subject: Add files via upload --- reactor_patcher.py | 135 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 reactor_patcher.py (limited to 'reactor_patcher.py') diff --git a/reactor_patcher.py b/reactor_patcher.py new file mode 100644 index 0000000..1818def --- /dev/null +++ b/reactor_patcher.py @@ -0,0 +1,135 @@ +import os.path as osp +import glob +import logging +import insightface +from insightface.model_zoo.model_zoo import ModelRouter, PickableInferenceSession +from insightface.model_zoo.retinaface import RetinaFace +from insightface.model_zoo.landmark import Landmark +from insightface.model_zoo.attribute import Attribute +from insightface.model_zoo.inswapper import INSwapper +from insightface.model_zoo.arcface_onnx import ArcFaceONNX +from insightface.app import FaceAnalysis +from insightface.utils import DEFAULT_MP_NAME, ensure_available +from insightface.model_zoo import model_zoo +import onnxruntime +import onnx +from onnx import numpy_helper +from scripts.reactor_logger import logger + + +def patched_get_model(self, **kwargs): + session = PickableInferenceSession(self.onnx_file, **kwargs) + inputs = session.get_inputs() + input_cfg = inputs[0] + input_shape = input_cfg.shape + outputs = session.get_outputs() + + if len(outputs) >= 5: + return RetinaFace(model_file=self.onnx_file, session=session) + elif input_shape[2] == 192 and input_shape[3] == 192: + return Landmark(model_file=self.onnx_file, session=session) + elif input_shape[2] == 96 and input_shape[3] == 96: + return Attribute(model_file=self.onnx_file, session=session) + elif len(inputs) == 2 and input_shape[2] == 128 and input_shape[3] == 128: + return INSwapper(model_file=self.onnx_file, session=session) + elif input_shape[2] == input_shape[3] and input_shape[2] >= 112 and input_shape[2] % 16 == 0: + return ArcFaceONNX(model_file=self.onnx_file, session=session) + else: + return None + + +def patched_faceanalysis_init(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs): + onnxruntime.set_default_logger_severity(3) + self.models = {} + self.model_dir = ensure_available('models', name, root=root) + onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx')) + onnx_files = sorted(onnx_files) + for onnx_file in onnx_files: + model = model_zoo.get_model(onnx_file, **kwargs) + if model is None: + print('model not recognized:', onnx_file) + elif allowed_modules is not None and model.taskname not in allowed_modules: + print('model ignore:', onnx_file, model.taskname) + del model + elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules): + self.models[model.taskname] = model + else: + print('duplicated model task type, ignore:', onnx_file, model.taskname) + del model + assert 'detection' in self.models + self.det_model = self.models['detection'] + + +def patched_faceanalysis_prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)): + self.det_thresh = det_thresh + assert det_size is not None + self.det_size = det_size + for taskname, model in self.models.items(): + if taskname == 'detection': + model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh) + else: + model.prepare(ctx_id) + + +def patched_inswapper_init(self, model_file=None, session=None): + self.model_file = model_file + self.session = session + model = onnx.load(self.model_file) + graph = model.graph + self.emap = numpy_helper.to_array(graph.initializer[-1]) + self.input_mean = 0.0 + self.input_std = 255.0 + if self.session is None: + self.session = onnxruntime.InferenceSession(self.model_file, None) + inputs = self.session.get_inputs() + self.input_names = [] + for inp in inputs: + self.input_names.append(inp.name) + outputs = self.session.get_outputs() + output_names = [] + for out in outputs: + output_names.append(out.name) + self.output_names = output_names + assert len(self.output_names) == 1 + input_cfg = inputs[0] + input_shape = input_cfg.shape + self.input_shape = input_shape + self.input_size = tuple(input_shape[2:4][::-1]) + + +def pathced_retinaface_prepare(self, ctx_id, **kwargs): + if ctx_id<0: + self.session.set_providers(['CPUExecutionProvider']) + nms_thresh = kwargs.get('nms_thresh', None) + if nms_thresh is not None: + self.nms_thresh = nms_thresh + det_thresh = kwargs.get('det_thresh', None) + if det_thresh is not None: + self.det_thresh = det_thresh + input_size = kwargs.get('input_size', None) + if input_size is not None and self.input_size is None: + self.input_size = input_size + + +def patch_insightface(get_model, faceanalysis_init, faceanalysis_prepare, inswapper_init, retinaface_prepare): + insightface.model_zoo.model_zoo.ModelRouter.get_model = get_model + insightface.app.FaceAnalysis.__init__ = faceanalysis_init + insightface.app.FaceAnalysis.prepare = faceanalysis_prepare + insightface.model_zoo.inswapper.INSwapper.__init__ = inswapper_init + insightface.model_zoo.retinaface.RetinaFace.prepare = retinaface_prepare + + +original_functions = [ModelRouter.get_model, FaceAnalysis.__init__, FaceAnalysis.prepare, INSwapper.__init__, RetinaFace.prepare] +patched_functions = [patched_get_model, patched_faceanalysis_init, patched_faceanalysis_prepare, patched_inswapper_init, pathced_retinaface_prepare] + + +def apply_patch(console_log_level): + if console_log_level == 0: + patch_insightface(*patched_functions) + logger.setLevel(logging.WARNING) + elif console_log_level == 1: + patch_insightface(*patched_functions) + logger.setLevel(logging.STATUS) + elif console_log_level == 2: + patch_insightface(*original_functions) + logger.setLevel(logging.INFO) -- cgit v1.2.3