diff options
author | Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> | 2025-01-17 11:06:44 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-01-17 11:06:44 +0000 |
commit | e6bd5af6a8e306a1cdef63402a77a980a04ad6e1 (patch) | |
tree | d0732226bbc22feedad9e834b2218d7d0b0eff54 /r_facelib/utils | |
parent | 495ffc4777522e40941753e3b1b79c02f84b25b4 (diff) | |
download | Comfyui-reactor-node-e6bd5af6a8e306a1cdef63402a77a980a04ad6e1.tar.gz |
Diffstat (limited to 'r_facelib/utils')
-rw-r--r-- | r_facelib/utils/__init__.py | 7 | ||||
-rw-r--r-- | r_facelib/utils/face_restoration_helper.py | 455 | ||||
-rw-r--r-- | r_facelib/utils/face_utils.py | 248 | ||||
-rw-r--r-- | r_facelib/utils/misc.py | 143 |
4 files changed, 853 insertions, 0 deletions
diff --git a/r_facelib/utils/__init__.py b/r_facelib/utils/__init__.py new file mode 100644 index 0000000..3397bda --- /dev/null +++ b/r_facelib/utils/__init__.py @@ -0,0 +1,7 @@ +from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
+from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir
+
+__all__ = [
+ 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url',
+ 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir'
+]
diff --git a/r_facelib/utils/face_restoration_helper.py b/r_facelib/utils/face_restoration_helper.py new file mode 100644 index 0000000..1db75c9 --- /dev/null +++ b/r_facelib/utils/face_restoration_helper.py @@ -0,0 +1,455 @@ +import cv2
+import numpy as np
+import os
+import torch
+from torchvision.transforms.functional import normalize
+
+from r_facelib.detection import init_detection_model
+from r_facelib.parsing import init_parsing_model
+from r_facelib.utils.misc import img2tensor, imwrite
+
+
+def get_largest_face(det_faces, h, w):
+
+ def get_location(val, length):
+ if val < 0:
+ return 0
+ elif val > length:
+ return length
+ else:
+ return val
+
+ face_areas = []
+ for det_face in det_faces:
+ left = get_location(det_face[0], w)
+ right = get_location(det_face[2], w)
+ top = get_location(det_face[1], h)
+ bottom = get_location(det_face[3], h)
+ face_area = (right - left) * (bottom - top)
+ face_areas.append(face_area)
+ largest_idx = face_areas.index(max(face_areas))
+ return det_faces[largest_idx], largest_idx
+
+
+def get_center_face(det_faces, h=0, w=0, center=None):
+ if center is not None:
+ center = np.array(center)
+ else:
+ center = np.array([w / 2, h / 2])
+ center_dist = []
+ for det_face in det_faces:
+ face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
+ dist = np.linalg.norm(face_center - center)
+ center_dist.append(dist)
+ center_idx = center_dist.index(min(center_dist))
+ return det_faces[center_idx], center_idx
+
+
+class FaceRestoreHelper(object):
+ """Helper for the face restoration pipeline (base class)."""
+
+ def __init__(self,
+ upscale_factor,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ template_3points=False,
+ pad_blur=False,
+ use_parse=False,
+ device=None):
+ self.template_3points = template_3points # improve robustness
+ self.upscale_factor = upscale_factor
+ # the cropped face ratio based on the square face
+ self.crop_ratio = crop_ratio # (h, w)
+ assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
+ self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+
+ if self.template_3points:
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+ else:
+ # standard 5 landmarks for FFHQ faces with 512 x 512
+ # facexlib
+ self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
+ [201.26117, 371.41043], [313.08905, 371.15118]])
+
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
+
+
+ self.face_template = self.face_template * (face_size / 512.0)
+ if self.crop_ratio[0] > 1:
+ self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
+ if self.crop_ratio[1] > 1:
+ self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
+ self.save_ext = save_ext
+ self.pad_blur = pad_blur
+ if self.pad_blur is True:
+ self.template_3points = False
+
+ self.all_landmarks_5 = []
+ self.det_faces = []
+ self.affine_matrices = []
+ self.inverse_affine_matrices = []
+ self.cropped_faces = []
+ self.restored_faces = []
+ self.pad_input_imgs = []
+
+ if device is None:
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ else:
+ self.device = device
+
+ # init face detection model
+ self.face_det = init_detection_model(det_model, half=False, device=self.device)
+
+ # init face parsing model
+ self.use_parse = use_parse
+ self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
+
+ def set_upscale_factor(self, upscale_factor):
+ self.upscale_factor = upscale_factor
+
+ def read_image(self, img):
+ """img can be image path or cv2 loaded image."""
+ # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
+ if isinstance(img, str):
+ img = cv2.imread(img)
+
+ if np.max(img) > 256: # 16-bit image
+ img = img / 65535 * 255
+ if len(img.shape) == 2: # gray image
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif img.shape[2] == 4: # BGRA image with alpha channel
+ img = img[:, :, 0:3]
+
+ self.input_img = img
+
+ if min(self.input_img.shape[:2])<512:
+ f = 512.0/min(self.input_img.shape[:2])
+ self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
+
+ def get_face_landmarks_5(self,
+ only_keep_largest=False,
+ only_center_face=False,
+ resize=None,
+ blur_ratio=0.01,
+ eye_dist_threshold=None):
+ if resize is None:
+ scale = 1
+ input_img = self.input_img
+ else:
+ h, w = self.input_img.shape[0:2]
+ scale = resize / min(h, w)
+ scale = max(1, scale) # always scale up
+ h, w = int(h * scale), int(w * scale)
+ interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
+ input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
+
+ with torch.no_grad():
+ bboxes = self.face_det.detect_faces(input_img)
+
+ if bboxes is None or bboxes.shape[0] == 0:
+ return 0
+ else:
+ bboxes = bboxes / scale
+
+ for bbox in bboxes:
+ # remove faces with too small eye distance: side faces or too small faces
+ eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
+ if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
+ continue
+
+ if self.template_3points:
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
+ else:
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
+ self.all_landmarks_5.append(landmark)
+ self.det_faces.append(bbox[0:5])
+
+ if len(self.det_faces) == 0:
+ return 0
+ if only_keep_largest:
+ h, w, _ = self.input_img.shape
+ self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
+ elif only_center_face:
+ h, w, _ = self.input_img.shape
+ self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
+
+ # pad blurry images
+ if self.pad_blur:
+ self.pad_input_imgs = []
+ for landmarks in self.all_landmarks_5:
+ # get landmarks
+ eye_left = landmarks[0, :]
+ eye_right = landmarks[1, :]
+ eye_avg = (eye_left + eye_right) * 0.5
+ mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
+ eye_to_eye = eye_right - eye_left
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Get the oriented crop rectangle
+ # x: half width of the oriented crop rectangle
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+ # norm with the hypotenuse: get the direction
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
+ rect_scale = 1.5
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+ # y: half height of the oriented crop rectangle
+ y = np.flipud(x) * [-1, 1]
+
+ # c: center
+ c = eye_avg + eye_to_mouth * 0.1
+ # quad: (left_top, left_bottom, right_bottom, right_top)
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ # qsize: side length of the square
+ qsize = np.hypot(*x) * 2
+ border = max(int(np.rint(qsize * 0.1)), 3)
+
+ # get pad
+ # pad: (width_left, height_top, width_right, height_bottom)
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = [
+ max(-pad[0] + border, 1),
+ max(-pad[1] + border, 1),
+ max(pad[2] - self.input_img.shape[0] + border, 1),
+ max(pad[3] - self.input_img.shape[1] + border, 1)
+ ]
+
+ if max(pad) > 1:
+ # pad image
+ pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ # modify landmark coords
+ landmarks[:, 0] += pad[0]
+ landmarks[:, 1] += pad[1]
+ # blur pad images
+ h, w, _ = pad_img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = int(qsize * blur_ratio)
+ if blur % 2 == 0:
+ blur += 1
+ blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
+ # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
+
+ pad_img = pad_img.astype('float32')
+ pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
+ pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
+ self.pad_input_imgs.append(pad_img)
+ else:
+ self.pad_input_imgs.append(np.copy(self.input_img))
+
+ return len(self.all_landmarks_5)
+
+ def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
+ """Align and warp faces with face template.
+ """
+ if self.pad_blur:
+ assert len(self.pad_input_imgs) == len(
+ self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
+ for idx, landmark in enumerate(self.all_landmarks_5):
+ # use 5 landmarks to get affine matrix
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
+ self.affine_matrices.append(affine_matrix)
+ # warp and crop faces
+ if border_mode == 'constant':
+ border_mode = cv2.BORDER_CONSTANT
+ elif border_mode == 'reflect101':
+ border_mode = cv2.BORDER_REFLECT101
+ elif border_mode == 'reflect':
+ border_mode = cv2.BORDER_REFLECT
+ if self.pad_blur:
+ input_img = self.pad_input_imgs[idx]
+ else:
+ input_img = self.input_img
+ cropped_face = cv2.warpAffine(
+ input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
+ self.cropped_faces.append(cropped_face)
+ # save the cropped face
+ if save_cropped_path is not None:
+ path = os.path.splitext(save_cropped_path)[0]
+ save_path = f'{path}_{idx:02d}.{self.save_ext}'
+ imwrite(cropped_face, save_path)
+
+ def get_inverse_affine(self, save_inverse_affine_path=None):
+ """Get inverse affine matrix."""
+ for idx, affine_matrix in enumerate(self.affine_matrices):
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
+ inverse_affine *= self.upscale_factor
+ self.inverse_affine_matrices.append(inverse_affine)
+ # save inverse affine matrices
+ if save_inverse_affine_path is not None:
+ path, _ = os.path.splitext(save_inverse_affine_path)
+ save_path = f'{path}_{idx:02d}.pth'
+ torch.save(inverse_affine, save_path)
+
+
+ def add_restored_face(self, face):
+ self.restored_faces.append(face)
+
+
+ def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
+ h, w, _ = self.input_img.shape
+ h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
+
+ if upsample_img is None:
+ # simply resize the background
+ # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+ upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
+ else:
+ upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+
+ assert len(self.restored_faces) == len(
+ self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
+
+ inv_mask_borders = []
+ for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
+ if face_upsampler is not None:
+ restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
+ inverse_affine /= self.upscale_factor
+ inverse_affine[:, 2] *= self.upscale_factor
+ face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor)
+ else:
+ # Add an offset to inverse affine matrix, for more precise back alignment
+ if self.upscale_factor > 1:
+ extra_offset = 0.5 * self.upscale_factor
+ else:
+ extra_offset = 0
+ inverse_affine[:, 2] += extra_offset
+ face_size = self.face_size
+ inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
+
+ # if draw_box or not self.use_parse: # use square parse maps
+ # mask = np.ones(face_size, dtype=np.float32)
+ # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # # remove the black borders
+ # inv_mask_erosion = cv2.erode(
+ # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ # pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ # total_face_area = np.sum(inv_mask_erosion) # // 3
+ # # add border
+ # if draw_box:
+ # h, w = face_size
+ # mask_border = np.ones((h, w, 3), dtype=np.float32)
+ # border = int(1400/np.sqrt(total_face_area))
+ # mask_border[border:h-border, border:w-border,:] = 0
+ # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+ # inv_mask_borders.append(inv_mask_border)
+ # if not self.use_parse:
+ # # compute the fusion edge based on the area of face
+ # w_edge = int(total_face_area**0.5) // 20
+ # erosion_radius = w_edge * 2
+ # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ # blur_size = w_edge * 2
+ # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ # if len(upsample_img.shape) == 2: # upsample_img is gray image
+ # upsample_img = upsample_img[:, :, None]
+ # inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # always use square mask
+ mask = np.ones(face_size, dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(
+ inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) # // 3
+ # add border
+ if draw_box:
+ h, w = face_size
+ mask_border = np.ones((h, w, 3), dtype=np.float32)
+ border = int(1400/np.sqrt(total_face_area))
+ mask_border[border:h-border, border:w-border,:] = 0
+ inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+ inv_mask_borders.append(inv_mask_border)
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area**0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ if len(upsample_img.shape) == 2: # upsample_img is gray image
+ upsample_img = upsample_img[:, :, None]
+ inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # parse mask
+ if self.use_parse:
+ # inference
+ face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
+ normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ face_input = torch.unsqueeze(face_input, 0).to(self.device)
+ with torch.no_grad():
+ out = self.face_parse(face_input)[0]
+ out = out.argmax(dim=1).squeeze().cpu().numpy()
+
+ parse_mask = np.zeros(out.shape)
+ MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
+ for idx, color in enumerate(MASK_COLORMAP):
+ parse_mask[out == idx] = color
+ # blur the mask
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ # remove the black borders
+ thres = 10
+ parse_mask[:thres, :] = 0
+ parse_mask[-thres:, :] = 0
+ parse_mask[:, :thres] = 0
+ parse_mask[:, -thres:] = 0
+ parse_mask = parse_mask / 255.
+
+ parse_mask = cv2.resize(parse_mask, face_size)
+ parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
+ inv_soft_parse_mask = parse_mask[:, :, None]
+ # pasted_face = inv_restored
+ fuse_mask = (inv_soft_parse_mask<inv_soft_mask).astype('int')
+ inv_soft_mask = inv_soft_parse_mask*fuse_mask + inv_soft_mask*(1-fuse_mask)
+
+ if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
+ alpha = upsample_img[:, :, 3:]
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
+ upsample_img = np.concatenate((upsample_img, alpha), axis=2)
+ else:
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
+
+ if np.max(upsample_img) > 256: # 16-bit image
+ upsample_img = upsample_img.astype(np.uint16)
+ else:
+ upsample_img = upsample_img.astype(np.uint8)
+
+ # draw bounding box
+ if draw_box:
+ # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
+ img_color = np.ones([*upsample_img.shape], dtype=np.float32)
+ img_color[:,:,0] = 0
+ img_color[:,:,1] = 255
+ img_color[:,:,2] = 0
+ for inv_mask_border in inv_mask_borders:
+ upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
+ # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
+
+ if save_path is not None:
+ path = os.path.splitext(save_path)[0]
+ save_path = f'{path}.{self.save_ext}'
+ imwrite(upsample_img, save_path)
+ return upsample_img
+
+ def clean_all(self):
+ self.all_landmarks_5 = []
+ self.restored_faces = []
+ self.affine_matrices = []
+ self.cropped_faces = []
+ self.inverse_affine_matrices = []
+ self.det_faces = []
+ self.pad_input_imgs = []
diff --git a/r_facelib/utils/face_utils.py b/r_facelib/utils/face_utils.py new file mode 100644 index 0000000..657ad25 --- /dev/null +++ b/r_facelib/utils/face_utils.py @@ -0,0 +1,248 @@ +import cv2
+import numpy as np
+import torch
+
+
+def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
+ left, top, right, bot = bbox
+ width = right - left
+ height = bot - top
+
+ if preserve_aspect:
+ width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
+ height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
+ else:
+ width_increase = height_increase = increase_area
+ left = int(left - width_increase * width)
+ top = int(top - height_increase * height)
+ right = int(right + width_increase * width)
+ bot = int(bot + height_increase * height)
+ return (left, top, right, bot)
+
+
+def get_valid_bboxes(bboxes, h, w):
+ left = max(bboxes[0], 0)
+ top = max(bboxes[1], 0)
+ right = min(bboxes[2], w)
+ bottom = min(bboxes[3], h)
+ return (left, top, right, bottom)
+
+
+def align_crop_face_landmarks(img,
+ landmarks,
+ output_size,
+ transform_size=None,
+ enable_padding=True,
+ return_inverse_affine=False,
+ shrink_ratio=(1, 1)):
+ """Align and crop face with landmarks.
+
+ The output_size and transform_size are based on width. The height is
+ adjusted based on shrink_ratio_h/shring_ration_w.
+
+ Modified from:
+ https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
+
+ Args:
+ img (Numpy array): Input image.
+ landmarks (Numpy array): 5 or 68 or 98 landmarks.
+ output_size (int): Output face size.
+ transform_size (ing): Transform size. Usually the four time of
+ output_size.
+ enable_padding (float): Default: True.
+ shrink_ratio (float | tuple[float] | list[float]): Shring the whole
+ face for height and width (crop larger area). Default: (1, 1).
+
+ Returns:
+ (Numpy array): Cropped face.
+ """
+ lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5
+
+ if isinstance(shrink_ratio, (float, int)):
+ shrink_ratio = (shrink_ratio, shrink_ratio)
+ if transform_size is None:
+ transform_size = output_size * 4
+
+ # Parse landmarks
+ lm = np.array(landmarks)
+ if lm.shape[0] == 5 and lm_type == 'retinaface_5':
+ eye_left = lm[0]
+ eye_right = lm[1]
+ mouth_avg = (lm[3] + lm[4]) * 0.5
+ elif lm.shape[0] == 5 and lm_type == 'dlib_5':
+ lm_eye_left = lm[2:4]
+ lm_eye_right = lm[0:2]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = lm[4]
+ elif lm.shape[0] == 68:
+ lm_eye_left = lm[36:42]
+ lm_eye_right = lm[42:48]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = (lm[48] + lm[54]) * 0.5
+ elif lm.shape[0] == 98:
+ lm_eye_left = lm[60:68]
+ lm_eye_right = lm[68:76]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = (lm[76] + lm[82]) * 0.5
+
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Get the oriented crop rectangle
+ # x: half width of the oriented crop rectangle
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+ # norm with the hypotenuse: get the direction
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
+ rect_scale = 1 # TODO: you can edit it to get larger rect
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+ # y: half height of the oriented crop rectangle
+ y = np.flipud(x) * [-1, 1]
+
+ x *= shrink_ratio[1] # width
+ y *= shrink_ratio[0] # height
+
+ # c: center
+ c = eye_avg + eye_to_mouth * 0.1
+ # quad: (left_top, left_bottom, right_bottom, right_top)
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ # qsize: side length of the square
+ qsize = np.hypot(*x) * 2
+
+ quad_ori = np.copy(quad)
+ # Shrink, for large face
+ # TODO: do we really need shrink
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ h, w = img.shape[0:2]
+ rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
+ img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
+ quad /= shrink
+ qsize /= shrink
+
+ # Crop
+ h, w = img.shape[0:2]
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
+ if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
+ img = img[crop[1]:crop[3], crop[0]:crop[2], :]
+ quad -= crop[0:2]
+
+ # Pad
+ # pad: (width_left, height_top, width_right, height_bottom)
+ h, w = img.shape[0:2]
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
+ if enable_padding and max(pad) > border - 4:
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ h, w = img.shape[0:2]
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = int(qsize * 0.02)
+ if blur % 2 == 0:
+ blur += 1
+ blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
+
+ img = img.astype('float32')
+ img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ img = np.clip(img, 0, 255) # float32, [0, 255]
+ quad += pad[:2]
+
+ # Transform use cv2
+ h_ratio = shrink_ratio[0] / shrink_ratio[1]
+ dst_h, dst_w = int(transform_size * h_ratio), transform_size
+ template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
+ cropped_face = cv2.warpAffine(
+ img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray
+
+ if output_size < transform_size:
+ cropped_face = cv2.resize(
+ cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)
+
+ if return_inverse_affine:
+ dst_h, dst_w = int(output_size * h_ratio), output_size
+ template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(
+ quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0]
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
+ else:
+ inverse_affine = None
+ return cropped_face, inverse_affine
+
+
+def paste_face_back(img, face, inverse_affine):
+ h, w = img.shape[0:2]
+ face_h, face_w = face.shape[0:2]
+ inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
+ mask = np.ones((face_h, face_w, 3), dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
+ inv_restored_remove_border = inv_mask_erosion * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) // 3
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area**0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
+ # float32, [0, 255]
+ return img
+
+
+if __name__ == '__main__':
+ import os
+
+ from custom_nodes.facerestore.facelib.detection import init_detection_model
+ from custom_nodes.facerestore.facelib.utils.face_restoration_helper import get_largest_face
+
+ img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png'
+ img_name = os.splitext(os.path.basename(img_path))[0]
+
+ # initialize model
+ det_net = init_detection_model('retinaface_resnet50', half=False)
+ img_ori = cv2.imread(img_path)
+ h, w = img_ori.shape[0:2]
+ # if larger than 800, scale it
+ scale = max(h / 800, w / 800)
+ if scale > 1:
+ img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
+
+ with torch.no_grad():
+ bboxes = det_net.detect_faces(img, 0.97)
+ if scale > 1:
+ bboxes *= scale # the score is incorrect
+ bboxes = get_largest_face(bboxes, h, w)[0]
+
+ landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
+
+ cropped_face, inverse_affine = align_crop_face_landmarks(
+ img_ori,
+ landmarks,
+ output_size=512,
+ transform_size=None,
+ enable_padding=True,
+ return_inverse_affine=True,
+ shrink_ratio=(1, 1))
+
+ cv2.imwrite(f'tmp/{img_name}_cropeed_face.png', cropped_face)
+ img = paste_face_back(img_ori, cropped_face, inverse_affine)
+ cv2.imwrite(f'tmp/{img_name}_back.png', img)
diff --git a/r_facelib/utils/misc.py b/r_facelib/utils/misc.py new file mode 100644 index 0000000..6ea7c65 --- /dev/null +++ b/r_facelib/utils/misc.py @@ -0,0 +1,143 @@ +import cv2
+import os
+import os.path as osp
+import torch
+from torch.hub import download_url_to_file, get_dir
+from urllib.parse import urlparse
+# from basicsr.utils.download_util import download_file_from_google_drive
+#import gdown
+
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+def download_pretrained_models(file_ids, save_path_root):
+ os.makedirs(save_path_root, exist_ok=True)
+
+ for file_name, file_id in file_ids.items():
+ file_url = 'https://drive.google.com/uc?id='+file_id
+ save_path = osp.abspath(osp.join(save_path_root, file_name))
+ if osp.exists(save_path):
+ user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
+ if user_response.lower() == 'y':
+ print(f'Covering {file_name} to {save_path}')
+ print("skipping gdown in facelib/utils/misc.py "+file_url)
+ #gdown.download(file_url, save_path, quiet=False)
+ # download_file_from_google_drive(file_id, save_path)
+ elif user_response.lower() == 'n':
+ print(f'Skipping {file_name}')
+ else:
+ raise ValueError('Wrong input. Only accepts Y/N.')
+ else:
+ print(f'Downloading {file_name} to {save_path}')
+ print("skipping gdown in facelib/utils/misc.py "+file_url)
+ #gdown.download(file_url, save_path, quiet=False)
+ # download_file_from_google_drive(file_id, save_path)
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+ """
+ if model_dir is None:
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+ Returns:
+ A generator for all the interested files with relative paths.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
|