Diffstat (limited to 'scripts/r_faceboost/restorer.py')
-rw-r--r--  scripts/r_faceboost/restorer.py  130
1 files changed, 130 insertions, 0 deletions
diff --git a/scripts/r_faceboost/restorer.py b/scripts/r_faceboost/restorer.py
new file mode 100644
index 0000000..60c0165
--- /dev/null
+++ b/scripts/r_faceboost/restorer.py
@@ -0,0 +1,130 @@
+import sys
+import cv2
+import numpy as np
+import torch
+from torchvision.transforms.functional import normalize
+
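+# Guard the CUDA import so the provider selection below still works when torch has no CUDA support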
+try:
+    import torch.cuda as cuda
+except Exception:
+    cuda = None
+
+import comfy.utils
+import folder_paths
+import comfy.model_management as model_management
+
+from scripts.reactor_logger import logger
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from r_chainner import model_loading
+from reactor_utils import (
+    tensor2img,
+    img2tensor,
+    set_ort_session,
+    prepare_cropped_face,
+    normalize_cropped_face
+)
+
+
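+# Pick ONNX Runtime execution providers: prefer CUDA when torch reports it, otherwise fall back to CPU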
+if cuda is not None:
+    if cuda.is_available():
+        providers = ["CUDAExecutionProvider"]
+    else:
+        providers = ["CPUExecutionProvider"]
+else:
+    providers = ["CPUExecutionProvider"]
+
+
+def get_restored_face(cropped_face,
+                      face_restore_model,
+                      face_restore_visibility,
+                      codeformer_weight,
+                      interpolation: str = "Bicubic"):
+
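+    # Map the requested interpolation name to the matching OpenCV flag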
+    if interpolation == "Bicubic":
+        interpolate = cv2.INTER_CUBIC
+    elif interpolation == "Bilinear":
+        interpolate = cv2.INTER_LINEAR
+    elif interpolation == "Nearest":
+        interpolate = cv2.INTER_NEAREST
+    elif interpolation == "Lanczos":
+        interpolate = cv2.INTER_LANCZOS4
+    else:
+        # Unknown interpolation name: fall back to bicubic so `interpolate` is always defined
+        interpolate = cv2.INTER_CUBIC
+
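+    # Infer the restorer's native face size from the model name (512 unless it advertises 1024 or 2048)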
+    face_size = 512
+    if "1024" in face_restore_model.lower():
+        face_size = 1024
+    elif "2048" in face_restore_model.lower():
+        face_size = 2048
+
+    scale = face_size / cropped_face.shape[0]
+
+    logger.status(f"Boosting the Face with {face_restore_model} | Face Size is set to {face_size} with Scale Factor = {scale} and '{interpolation}' interpolation")
+
+    cropped_face = cv2.resize(cropped_face, (face_size, face_size), interpolation=interpolate)
+
+    # For upscaling the base 128px face, I found bicubic interpolation to be the best compromise between antialiasing
+    # and detail preservation. Nearest is predictably unusable, Bilinear produces too much aliasing, and Lanczos produces
+    # too many hallucinations and artifacts/fringing.
+
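+    # Resolve the checkpoint path and target device, then convert the BGR face to a normalized [-1, 1] RGB tensor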
+    model_path = folder_paths.get_full_path("facerestore_models", face_restore_model)
+    device = model_management.get_torch_device()
+
+    cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+    normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+    cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+
+    try:
+
+        with torch.no_grad():
+
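+            # ONNX checkpoints are run through ONNX Runtime with the providers selected at import time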
+ if ".onnx" in face_restore_model: # ONNX models
+
+ ort_session = set_ort_session(model_path, providers=providers)
+ ort_session_inputs = {}
+ facerestore_model = ort_session
+
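+                # Populate only the inputs the session declares: the prepared face image and, if present, a scalar "weight"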
+                for ort_session_input in ort_session.get_inputs():
+                    if ort_session_input.name == "input":
+                        cropped_face_prep = prepare_cropped_face(cropped_face)
+                        ort_session_inputs[ort_session_input.name] = cropped_face_prep
+                    if ort_session_input.name == "weight":
+                        weight = np.array([1], dtype=np.double)
+                        ort_session_inputs[ort_session_input.name] = weight
+
+                output = ort_session.run(None, ort_session_inputs)[0][0]
+                restored_face = normalize_cropped_face(output)
+
+            else:  # PTH models
+
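+                # CodeFormer needs its architecture instantiated explicitly; other .pth checkpoints go through the chainner loader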
+ if "codeformer" in face_restore_model.lower():
+ codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
+ dim_embd=512,
+ codebook_size=1024,
+ n_head=8,
+ n_layers=9,
+ connect_list=["32", "64", "128", "256"],
+ ).to(device)
+ checkpoint = torch.load(model_path)["params_ema"]
+ codeformer_net.load_state_dict(checkpoint)
+ facerestore_model = codeformer_net.eval()
+ else:
+ sd = comfy.utils.load_torch_file(model_path, safe_load=True)
+ facerestore_model = model_loading.load_state_dict(sd).eval()
+ facerestore_model.to(device)
+
+                if "codeformer" in face_restore_model.lower():
+                    output = facerestore_model(cropped_face_t, w=codeformer_weight)[0]
+                else:
+                    output = facerestore_model(cropped_face_t)[0]
+                restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+
+            del output
+            torch.cuda.empty_cache()
+
+    except Exception as error:
+
+        print(f"\tFailed inference: {error}", file=sys.stderr)
+        restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+
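+    # Blend the restored face with the upscaled input face according to the visibility setting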
+    if face_restore_visibility < 1:
+        restored_face = cropped_face * (1 - face_restore_visibility) + restored_face * face_restore_visibility
+
+    restored_face = restored_face.astype("uint8")
+    return restored_face, scale