From e6bd5af6a8e306a1cdef63402a77a980a04ad6e1 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:06:44 +0000 Subject: Add files via upload --- scripts/__init__.py | 0 scripts/r_archs/__init__.py | 0 scripts/r_archs/codeformer_arch.py | 278 ++++++++++++++++ scripts/r_archs/vqgan_arch.py | 437 +++++++++++++++++++++++++ scripts/r_faceboost/__init__.py | 0 scripts/r_faceboost/restorer.py | 130 ++++++++ scripts/r_faceboost/swapper.py | 42 +++ scripts/r_masking/__init__.py | 0 scripts/r_masking/core.py | 647 +++++++++++++++++++++++++++++++++++++ scripts/r_masking/segs.py | 22 ++ scripts/r_masking/subcore.py | 117 +++++++ scripts/reactor_faceswap.py | 185 +++++++++++ scripts/reactor_logger.py | 47 +++ scripts/reactor_swapper.py | 572 ++++++++++++++++++++++++++++++++ scripts/reactor_version.py | 13 + 15 files changed, 2490 insertions(+) create mode 100644 scripts/__init__.py create mode 100644 scripts/r_archs/__init__.py create mode 100644 scripts/r_archs/codeformer_arch.py create mode 100644 scripts/r_archs/vqgan_arch.py create mode 100644 scripts/r_faceboost/__init__.py create mode 100644 scripts/r_faceboost/restorer.py create mode 100644 scripts/r_faceboost/swapper.py create mode 100644 scripts/r_masking/__init__.py create mode 100644 scripts/r_masking/core.py create mode 100644 scripts/r_masking/segs.py create mode 100644 scripts/r_masking/subcore.py create mode 100644 scripts/reactor_faceswap.py create mode 100644 scripts/reactor_logger.py create mode 100644 scripts/reactor_swapper.py create mode 100644 scripts/reactor_version.py (limited to 'scripts') diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/r_archs/__init__.py b/scripts/r_archs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/r_archs/codeformer_arch.py b/scripts/r_archs/codeformer_arch.py new file mode 100644 index 0000000..588ef69 --- /dev/null +++ b/scripts/r_archs/codeformer_arch.py @@ -0,0 +1,278 @@ +import math +import numpy as np +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from typing import Optional, List + +from scripts.r_archs.vqgan_arch import * +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.registry import ARCH_REGISTRY + + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + """Adaptive instance normalization. + + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. 
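+
+    Returns:
+        Tensor: The reference feature re-normalized to the channel-wise mean
+            and std of the degraded feature (a standard AdaIN transfer).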
+ """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class TransformerSALayer(nn.Module): + def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + # Implementation of Feedforward model - MLP + self.linear1 = nn.Linear(embed_dim, dim_mlp) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_mlp, embed_dim) + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + + # self attention + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + + # ffn + tgt2 = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout2(tgt2) + return tgt + +class Fuse_sft_block(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.encode_enc = ResBlock(2*in_ch, 
out_ch) + + self.scale = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + self.shift = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + def forward(self, enc_feat, dec_feat, w=1): + enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) + scale = self.scale(enc_feat) + shift = self.shift(enc_feat) + residual = w * (dec_feat * scale + shift) + out = dec_feat + residual + return out + + +@ARCH_REGISTRY.register() +class CodeFormer(VQAutoEncoder): + def __init__(self, dim_embd=512, n_head=8, n_layers=9, + codebook_size=1024, latent_size=256, + connect_list=['32', '64', '128', '256'], + fix_modules=['quantize','generator']): + super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) + + if fix_modules is not None: + for module in fix_modules: + for param in getattr(self, module).parameters(): + param.requires_grad = False + + self.connect_list = connect_list + self.n_layers = n_layers + self.dim_embd = dim_embd + self.dim_mlp = dim_embd*2 + + self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) + self.feat_emb = nn.Linear(256, self.dim_embd) + + # transformer + self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) + for _ in range(self.n_layers)]) + + # logits_predict head + self.idx_pred_layer = nn.Sequential( + nn.LayerNorm(dim_embd), + nn.Linear(dim_embd, codebook_size, bias=False)) + + self.channels = { + '16': 512, + '32': 256, + '64': 256, + '128': 128, + '256': 128, + '512': 64, + } + + # after second residual block for > 16, before attn layer for ==16 + self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} + # after first residual block for > 16, before attn layer for ==16 + self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} + + # fuse_convs_dict + self.fuse_convs_dict = nn.ModuleDict() + for f_size in self.connect_list: + in_ch = self.channels[f_size] + self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): + # ################### Encoder ##################### + enc_feat_dict = {} + out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] + for i, block in enumerate(self.encoder.blocks): + x = block(x) + if i in out_list: + enc_feat_dict[str(x.shape[-1])] = x.clone() + + lq_feat = x + # ################# Transformer ################### + # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) + pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) + # BCHW -> BC(HW) -> (HW)BC + feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) + query_emb = feat_emb + # Transformer encoder + for layer in self.ft_layers: + query_emb = layer(query_emb, query_pos=pos_emb) + + # output logits + logits = self.idx_pred_layer(query_emb) # (hw)bn + logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n + + if code_only: # for training stage II + # logits 
doesn't need softmax before cross_entropy loss + return logits, lq_feat + + # ################# Quantization ################### + # if self.training: + # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) + # # b(hw)c -> bc(hw) -> bchw + # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) + # ------------ + soft_one_hot = F.softmax(logits, dim=2) + _, top_idx = torch.topk(soft_one_hot, 1, dim=2) + quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) + # preserve gradients + # quant_feat = lq_feat + (quant_feat - lq_feat).detach() + + if detach_16: + quant_feat = quant_feat.detach() # for training stage III + if adain: + quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) + + # ################## Generator #################### + x = quant_feat + fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] + + for i, block in enumerate(self.generator.blocks): + x = block(x) + if i in fuse_list: # fuse after i-th block + f_size = str(x.shape[-1]) + if w>0: + x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) + out = x + # logits doesn't need softmax before cross_entropy loss + return out, logits, lq_feat + \ No newline at end of file diff --git a/scripts/r_archs/vqgan_arch.py b/scripts/r_archs/vqgan_arch.py new file mode 100644 index 0000000..50b3712 --- /dev/null +++ b/scripts/r_archs/vqgan_arch.py @@ -0,0 +1,437 @@ +''' +VQGAN code, adapted from the original created by the Unleashing Transformers authors: +https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py + +''' +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from r_basicsr.utils import get_root_logger +from r_basicsr.utils.registry import ARCH_REGISTRY + + +def normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.jit.script +def swish(x): + return x*torch.sigmoid(x) + + +# Define VQVAE classes +class VectorQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, beta): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) + self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.emb_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ + 2 * torch.matmul(z_flattened, self.embedding.weight.t()) + + mean_distance = torch.mean(d) + # find closest encodings + # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) + # [0-1], higher score, higher confidence + min_encoding_scores = torch.exp(-min_encoding_scores/10) + + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + # compute loss for embedding + loss = 
torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, { + "perplexity": perplexity, + "min_encodings": min_encodings, + "min_encoding_indices": min_encoding_indices, + "min_encoding_scores": min_encoding_scores, + "mean_distance": mean_distance + } + + def get_codebook_feat(self, indices, shape): + # input indices: batch*token_num -> (batch*token_num)*1 + # shape: batch, height, width, channel + indices = indices.view(-1,1) + min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) + min_encodings.scatter_(1, indices, 1) + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: # reshape back to match original input shape + z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): + super().__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits + self.embed = nn.Embedding(codebook_size, emb_dim) + + def forward(self, z): + hard = self.straight_through if self.training else True + + logits = self.proj(z) + + soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) + + z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() + min_encoding_indices = soft_one_hot.argmax(dim=1) + + return z_q, diff, { + "min_encoding_indices": min_encoding_indices + } + + +class Downsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + + return x + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None): + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = normalize(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x_in): + x = x_in + x = self.norm1(x) + x = swish(x) + x = 
self.conv1(x) + x = self.norm2(x) + x = swish(x) + x = self.conv2(x) + if self.in_channels != self.out_channels: + x_in = self.conv_out(x_in) + + return x + x_in + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h*w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c)**(-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h*w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Encoder(nn.Module): + def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): + super().__init__() + self.nf = nf + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.attn_resolutions = attn_resolutions + + curr_res = self.resolution + in_ch_mult = (1,)+tuple(ch_mult) + + blocks = [] + # initial convultion + blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) + + # residual and downsampling blocks, with attention on smaller res (16x16) + for i in range(self.num_resolutions): + block_in_ch = nf * in_ch_mult[i] + block_out_ch = nf * ch_mult[i] + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + if curr_res in attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != self.num_resolutions - 1: + blocks.append(Downsample(block_in_ch)) + curr_res = curr_res // 2 + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + # normalise and convert to latent size + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) + self.blocks = nn.ModuleList(blocks) + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +class Generator(nn.Module): + def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): + super().__init__() + self.nf = nf + self.ch_mult = ch_mult + self.num_resolutions = len(self.ch_mult) + self.num_res_blocks = res_blocks + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.in_channels = emb_dim + self.out_channels = 3 + block_in_ch = self.nf * self.ch_mult[-1] + curr_res = self.resolution // 2 ** (self.num_resolutions-1) + + blocks = [] + # initial conv + blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + for i in reversed(range(self.num_resolutions)): + block_out_ch = self.nf * 
self.ch_mult[i] + + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + + if curr_res in self.attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != 0: + blocks.append(Upsample(block_in_ch)) + curr_res = curr_res * 2 + + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) + + self.blocks = nn.ModuleList(blocks) + + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +@ARCH_REGISTRY.register() +class VQAutoEncoder(nn.Module): + def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, + beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): + super().__init__() + logger = get_root_logger() + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks + self.codebook_size = codebook_size + self.embed_dim = emb_dim + self.ch_mult = ch_mult + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.quantizer_type = quantizer + self.encoder = Encoder( + self.in_channels, + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + if self.quantizer_type == "nearest": + self.beta = beta #0.25 + self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) + elif self.quantizer_type == "gumbel": + self.gumbel_num_hiddens = emb_dim + self.straight_through = gumbel_straight_through + self.kl_weight = gumbel_kl_weight + self.quantize = GumbelQuantizer( + self.codebook_size, + self.embed_dim, + self.gumbel_num_hiddens, + self.straight_through, + self.kl_weight + ) + self.generator = Generator( + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_ema' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) + logger.info(f'vqgan is loaded from: {model_path} [params_ema]') + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + logger.info(f'vqgan is loaded from: {model_path} [params]') + else: + raise ValueError(f'Wrong params!') + + + def forward(self, x): + x = self.encoder(x) + quant, codebook_loss, quant_stats = self.quantize(x) + x = self.generator(quant) + return x, codebook_loss, quant_stats + + + +# patch based discriminator +@ARCH_REGISTRY.register() +class VQGANDiscriminator(nn.Module): + def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): + super().__init__() + + layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] + ndf_mult = 1 + ndf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n, 8) + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n_layers, 8) + + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + layers += [ + nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map + self.main = 
nn.Sequential(*layers) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_d' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + else: + raise ValueError(f'Wrong params!') + + def forward(self, x): + return self.main(x) + \ No newline at end of file diff --git a/scripts/r_faceboost/__init__.py b/scripts/r_faceboost/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/r_faceboost/restorer.py b/scripts/r_faceboost/restorer.py new file mode 100644 index 0000000..60c0165 --- /dev/null +++ b/scripts/r_faceboost/restorer.py @@ -0,0 +1,130 @@ +import sys +import cv2 +import numpy as np +import torch +from torchvision.transforms.functional import normalize + +try: + import torch.cuda as cuda +except: + cuda = None + +import comfy.utils +import folder_paths +import comfy.model_management as model_management + +from scripts.reactor_logger import logger +from r_basicsr.utils.registry import ARCH_REGISTRY +from r_chainner import model_loading +from reactor_utils import ( + tensor2img, + img2tensor, + set_ort_session, + prepare_cropped_face, + normalize_cropped_face +) + + +if cuda is not None: + if cuda.is_available(): + providers = ["CUDAExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] +else: + providers = ["CPUExecutionProvider"] + + +def get_restored_face(cropped_face, + face_restore_model, + face_restore_visibility, + codeformer_weight, + interpolation: str = "Bicubic"): + + if interpolation == "Bicubic": + interpolate = cv2.INTER_CUBIC + elif interpolation == "Bilinear": + interpolate = cv2.INTER_LINEAR + elif interpolation == "Nearest": + interpolate = cv2.INTER_NEAREST + elif interpolation == "Lanczos": + interpolate = cv2.INTER_LANCZOS4 + + face_size = 512 + if "1024" in face_restore_model.lower(): + face_size = 1024 + elif "2048" in face_restore_model.lower(): + face_size = 2048 + + scale = face_size / cropped_face.shape[0] + + logger.status(f"Boosting the Face with {face_restore_model} | Face Size is set to {face_size} with Scale Factor = {scale} and '{interpolation}' interpolation") + + cropped_face = cv2.resize(cropped_face, (face_size, face_size), interpolation=interpolate) + + # For upscaling the base 128px face, I found bicubic interpolation to be the best compromise targeting antialiasing + # and detail preservation. Nearest is predictably unusable, Linear produces too much aliasing, and Lanczos produces + # too many hallucinations and artifacts/fringing. 
+ + model_path = folder_paths.get_full_path("facerestore_models", face_restore_model) + device = model_management.get_torch_device() + + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(device) + + try: + + with torch.no_grad(): + + if ".onnx" in face_restore_model: # ONNX models + + ort_session = set_ort_session(model_path, providers=providers) + ort_session_inputs = {} + facerestore_model = ort_session + + for ort_session_input in ort_session.get_inputs(): + if ort_session_input.name == "input": + cropped_face_prep = prepare_cropped_face(cropped_face) + ort_session_inputs[ort_session_input.name] = cropped_face_prep + if ort_session_input.name == "weight": + weight = np.array([1], dtype=np.double) + ort_session_inputs[ort_session_input.name] = weight + + output = ort_session.run(None, ort_session_inputs)[0][0] + restored_face = normalize_cropped_face(output) + + else: # PTH models + + if "codeformer" in face_restore_model.lower(): + codeformer_net = ARCH_REGISTRY.get("CodeFormer")( + dim_embd=512, + codebook_size=1024, + n_head=8, + n_layers=9, + connect_list=["32", "64", "128", "256"], + ).to(device) + checkpoint = torch.load(model_path)["params_ema"] + codeformer_net.load_state_dict(checkpoint) + facerestore_model = codeformer_net.eval() + else: + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + facerestore_model = model_loading.load_state_dict(sd).eval() + facerestore_model.to(device) + + output = facerestore_model(cropped_face_t, w=codeformer_weight)[ + 0] if "codeformer" in face_restore_model.lower() else facerestore_model(cropped_face_t)[0] + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + + del output + torch.cuda.empty_cache() + + except Exception as error: + + print(f"\tFailed inference: {error}", file=sys.stderr) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + + if face_restore_visibility < 1: + restored_face = cropped_face * (1 - face_restore_visibility) + restored_face * face_restore_visibility + + restored_face = restored_face.astype("uint8") + return restored_face, scale diff --git a/scripts/r_faceboost/swapper.py b/scripts/r_faceboost/swapper.py new file mode 100644 index 0000000..e0467cf --- /dev/null +++ b/scripts/r_faceboost/swapper.py @@ -0,0 +1,42 @@ +import cv2 +import numpy as np + +# The following code is almost entirely copied from INSwapper; the only change here is that we want to use Lanczos +# interpolation for the warpAffine call. Now that the face has been restored, Lanczos represents a good compromise +# whether the restored face needs to be upscaled or downscaled. 
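+# A minimal usage sketch (hypothetical call site; `target_img`, `matrix` and the way the
+# affine matrix is rescaled are assumptions, not part of this module): the cropped face is
+# first restored at the boosted resolution, and the swapper's affine matrix is scaled by the
+# same factor before the paste-back performed below.
+#
+#   restored_face, scale = restorer.get_restored_face(
+#       cropped_face, face_restore_model, 1.0, 0.5, "Bicubic")
+#   result = in_swap(target_img, restored_face, matrix * scale)
+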
+def in_swap(img, bgr_fake, M): + target_img = img + IM = cv2.invertAffineTransform(M) + img_white = np.full((bgr_fake.shape[0], bgr_fake.shape[1]), 255, dtype=np.float32) + + # Note the use of bicubic here; this is functionally the only change from the source code + bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0, flags=cv2.INTER_CUBIC) + + img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) + img_white[img_white > 20] = 255 + img_mask = img_white + mask_h_inds, mask_w_inds = np.where(img_mask == 255) + mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) + mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) + mask_size = int(np.sqrt(mask_h * mask_w)) + k = max(mask_size // 10, 10) + # k = max(mask_size//20, 6) + # k = 6 + kernel = np.ones((k, k), np.uint8) + img_mask = cv2.erode(img_mask, kernel, iterations=1) + kernel = np.ones((2, 2), np.uint8) + k = max(mask_size // 20, 5) + # k = 3 + # k = 3 + kernel_size = (k, k) + blur_size = tuple(2 * i + 1 for i in kernel_size) + img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) + k = 5 + kernel_size = (k, k) + blur_size = tuple(2 * i + 1 for i in kernel_size) + img_mask /= 255 + # img_mask = fake_diff + img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1]) + fake_merged = img_mask * bgr_fake + (1 - img_mask) * target_img.astype(np.float32) + fake_merged = fake_merged.astype(np.uint8) + return fake_merged diff --git a/scripts/r_masking/__init__.py b/scripts/r_masking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/r_masking/core.py b/scripts/r_masking/core.py new file mode 100644 index 0000000..36862e1 --- /dev/null +++ b/scripts/r_masking/core.py @@ -0,0 +1,647 @@ +import numpy as np +import cv2 +import torch +import torchvision.transforms.functional as TF + +import sys as _sys +from keyword import iskeyword as _iskeyword +from operator import itemgetter as _itemgetter + +from segment_anything import SamPredictor + +from comfy import model_management + + +################################################################################ +### namedtuple +################################################################################ + +try: + from _collections import _tuplegetter +except ImportError: + _tuplegetter = lambda index, doc: property(_itemgetter(index), doc=doc) + +def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None): + """Returns a new subclass of tuple with named fields. + + >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point.__doc__ # docstring for the new class + 'Point(x, y)' + >>> p = Point(11, y=22) # instantiate with positional args or keywords + >>> p[0] + p[1] # indexable like a plain tuple + 33 + >>> x, y = p # unpack like a regular tuple + >>> x, y + (11, 22) + >>> p.x + p.y # fields also accessible by name + 33 + >>> d = p._asdict() # convert to a dictionary + >>> d['x'] + 11 + >>> Point(**d) # convert from a dictionary + Point(x=11, y=22) + >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields + Point(x=100, y=22) + + """ + + # Validate the field names. At the user's option, either generate an error + # message or automatically replace the field name with a valid name. 
+ if isinstance(field_names, str): + field_names = field_names.replace(',', ' ').split() + field_names = list(map(str, field_names)) + typename = _sys.intern(str(typename)) + + if rename: + seen = set() + for index, name in enumerate(field_names): + if (not name.isidentifier() + or _iskeyword(name) + or name.startswith('_') + or name in seen): + field_names[index] = f'_{index}' + seen.add(name) + + for name in [typename] + field_names: + if type(name) is not str: + raise TypeError('Type names and field names must be strings') + if not name.isidentifier(): + raise ValueError('Type names and field names must be valid ' + f'identifiers: {name!r}') + if _iskeyword(name): + raise ValueError('Type names and field names cannot be a ' + f'keyword: {name!r}') + + seen = set() + for name in field_names: + if name.startswith('_') and not rename: + raise ValueError('Field names cannot start with an underscore: ' + f'{name!r}') + if name in seen: + raise ValueError(f'Encountered duplicate field name: {name!r}') + seen.add(name) + + field_defaults = {} + if defaults is not None: + defaults = tuple(defaults) + if len(defaults) > len(field_names): + raise TypeError('Got more default values than field names') + field_defaults = dict(reversed(list(zip(reversed(field_names), + reversed(defaults))))) + + # Variables used in the methods and docstrings + field_names = tuple(map(_sys.intern, field_names)) + num_fields = len(field_names) + arg_list = ', '.join(field_names) + if num_fields == 1: + arg_list += ',' + repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')' + tuple_new = tuple.__new__ + _dict, _tuple, _len, _map, _zip = dict, tuple, len, map, zip + + # Create all the named tuple methods to be added to the class namespace + + namespace = { + '_tuple_new': tuple_new, + '__builtins__': {}, + '__name__': f'namedtuple_{typename}', + } + code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))' + __new__ = eval(code, namespace) + __new__.__name__ = '__new__' + __new__.__doc__ = f'Create new instance of {typename}({arg_list})' + if defaults is not None: + __new__.__defaults__ = defaults + + @classmethod + def _make(cls, iterable): + result = tuple_new(cls, iterable) + if _len(result) != num_fields: + raise TypeError(f'Expected {num_fields} arguments, got {len(result)}') + return result + + _make.__func__.__doc__ = (f'Make a new {typename} object from a sequence ' + 'or iterable') + + def _replace(self, /, **kwds): + result = self._make(_map(kwds.pop, field_names, self)) + if kwds: + raise ValueError(f'Got unexpected field names: {list(kwds)!r}') + return result + + _replace.__doc__ = (f'Return a new {typename} object replacing specified ' + 'fields with new values') + + def __repr__(self): + 'Return a nicely formatted representation string' + return self.__class__.__name__ + repr_fmt % self + + def _asdict(self): + 'Return a new dict which maps field names to their values.' + return _dict(_zip(self._fields, self)) + + def __getnewargs__(self): + 'Return self as a plain tuple. Used by copy and pickle.' 
+ return _tuple(self) + + # Modify function metadata to help with introspection and debugging + for method in ( + __new__, + _make.__func__, + _replace, + __repr__, + _asdict, + __getnewargs__, + ): + method.__qualname__ = f'{typename}.{method.__name__}' + + # Build-up the class namespace dictionary + # and use type() to build the result class + class_namespace = { + '__doc__': f'{typename}({arg_list})', + '__slots__': (), + '_fields': field_names, + '_field_defaults': field_defaults, + '__new__': __new__, + '_make': _make, + '_replace': _replace, + '__repr__': __repr__, + '_asdict': _asdict, + '__getnewargs__': __getnewargs__, + '__match_args__': field_names, + } + for index, name in enumerate(field_names): + doc = _sys.intern(f'Alias for field number {index}') + class_namespace[name] = _tuplegetter(index, doc) + + result = type(typename, (tuple,), class_namespace) + + # For pickling to work, the __module__ variable needs to be set to the frame + # where the named tuple is created. Bypass this step in environments where + # sys._getframe is not defined (Jython for example) or sys._getframe is not + # defined for arguments greater than 0 (IronPython), or where the user has + # specified a particular module. + if module is None: + try: + module = _sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + pass + if module is not None: + result.__module__ = module + + return result + + +SEG = namedtuple("SEG", + ['cropped_image', 'cropped_mask', 'confidence', 'crop_region', 'bbox', 'label', 'control_net_wrapper'], + defaults=[None]) + +def crop_ndarray4(npimg, crop_region): + x1 = crop_region[0] + y1 = crop_region[1] + x2 = crop_region[2] + y2 = crop_region[3] + + cropped = npimg[:, y1:y2, x1:x2, :] + + return cropped + +crop_tensor4 = crop_ndarray4 + +def crop_ndarray2(npimg, crop_region): + x1 = crop_region[0] + y1 = crop_region[1] + x2 = crop_region[2] + y2 = crop_region[3] + + cropped = npimg[y1:y2, x1:x2] + + return cropped + +def crop_image(image, crop_region): + return crop_tensor4(image, crop_region) + +def normalize_region(limit, startp, size): + if startp < 0: + new_endp = min(limit, size) + new_startp = 0 + elif startp + size > limit: + new_startp = max(0, limit - size) + new_endp = limit + else: + new_startp = startp + new_endp = min(limit, startp+size) + + return int(new_startp), int(new_endp) + +def make_crop_region(w, h, bbox, crop_factor, crop_min_size=None): + x1 = bbox[0] + y1 = bbox[1] + x2 = bbox[2] + y2 = bbox[3] + + bbox_w = x2 - x1 + bbox_h = y2 - y1 + + crop_w = bbox_w * crop_factor + crop_h = bbox_h * crop_factor + + if crop_min_size is not None: + crop_w = max(crop_min_size, crop_w) + crop_h = max(crop_min_size, crop_h) + + kernel_x = x1 + bbox_w / 2 + kernel_y = y1 + bbox_h / 2 + + new_x1 = int(kernel_x - crop_w / 2) + new_y1 = int(kernel_y - crop_h / 2) + + # make sure position in (w,h) + new_x1, new_x2 = normalize_region(w, new_x1, crop_w) + new_y1, new_y2 = normalize_region(h, new_y1, crop_h) + + return [new_x1, new_y1, new_x2, new_y2] + +def create_segmasks(results): + bboxs = results[1] + segms = results[2] + confidence = results[3] + + results = [] + for i in range(len(segms)): + item = (bboxs[i], segms[i].astype(np.float32), confidence[i]) + results.append(item) + return results + +def dilate_masks(segmasks, dilation_factor, iter=1): + if dilation_factor == 0: + return segmasks + + dilated_masks = [] + kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8) + + kernel = cv2.UMat(kernel) + + for i in 
range(len(segmasks)): + cv2_mask = segmasks[i][1] + + cv2_mask = cv2.UMat(cv2_mask) + + if dilation_factor > 0: + dilated_mask = cv2.dilate(cv2_mask, kernel, iter) + else: + dilated_mask = cv2.erode(cv2_mask, kernel, iter) + + dilated_mask = dilated_mask.get() + + item = (segmasks[i][0], dilated_mask, segmasks[i][2]) + dilated_masks.append(item) + + return dilated_masks + +def is_same_device(a, b): + a_device = torch.device(a) if isinstance(a, str) else a + b_device = torch.device(b) if isinstance(b, str) else b + return a_device.type == b_device.type and a_device.index == b_device.index + +class SafeToGPU: + def __init__(self, size): + self.size = size + + def to_device(self, obj, device): + if is_same_device(device, 'cpu'): + obj.to(device) + else: + if is_same_device(obj.device, 'cpu'): # cpu to gpu + model_management.free_memory(self.size * 1.3, device) + if model_management.get_free_memory(device) > self.size * 1.3: + try: + obj.to(device) + except: + print(f"WARN: The model is not moved to the '{device}' due to insufficient memory. [1]") + else: + print(f"WARN: The model is not moved to the '{device}' due to insufficient memory. [2]") + +def center_of_bbox(bbox): + w, h = bbox[2] - bbox[0], bbox[3] - bbox[1] + return bbox[0] + w/2, bbox[1] + h/2 + +def sam_predict(predictor, points, plabs, bbox, threshold): + point_coords = None if not points else np.array(points) + point_labels = None if not plabs else np.array(plabs) + + box = np.array([bbox]) if bbox is not None else None + + cur_masks, scores, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box) + + total_masks = [] + + selected = False + max_score = 0 + max_mask = None + for idx in range(len(scores)): + if scores[idx] > max_score: + max_score = scores[idx] + max_mask = cur_masks[idx] + + if scores[idx] >= threshold: + selected = True + total_masks.append(cur_masks[idx]) + else: + pass + + if not selected and max_mask is not None: + total_masks.append(max_mask) + + return total_masks + +def make_2d_mask(mask): + if len(mask.shape) == 4: + return mask.squeeze(0).squeeze(0) + + elif len(mask.shape) == 3: + return mask.squeeze(0) + + return mask + +def gen_detection_hints_from_mask_area(x, y, mask, threshold, use_negative): + mask = make_2d_mask(mask) + + points = [] + plabs = [] + + # minimum sampling step >= 3 + y_step = max(3, int(mask.shape[0] / 20)) + x_step = max(3, int(mask.shape[1] / 20)) + + for i in range(0, len(mask), y_step): + for j in range(0, len(mask[i]), x_step): + if mask[i][j] > threshold: + points.append((x + j, y + i)) + plabs.append(1) + elif use_negative and mask[i][j] == 0: + points.append((x + j, y + i)) + plabs.append(0) + + return points, plabs + +def gen_negative_hints(w, h, x1, y1, x2, y2): + npoints = [] + nplabs = [] + + # minimum sampling step >= 3 + y_step = max(3, int(w / 20)) + x_step = max(3, int(h / 20)) + + for i in range(10, h - 10, y_step): + for j in range(10, w - 10, x_step): + if not (x1 - 10 <= j and j <= x2 + 10 and y1 - 10 <= i and i <= y2 + 10): + npoints.append((j, i)) + nplabs.append(0) + + return npoints, nplabs + +def generate_detection_hints(image, seg, center, detection_hint, dilated_bbox, mask_hint_threshold, use_small_negative, + mask_hint_use_negative): + [x1, y1, x2, y2] = dilated_bbox + + points = [] + plabs = [] + if detection_hint == "center-1": + points.append(center) + plabs = [1] # 1 = foreground point, 0 = background point + + elif detection_hint == "horizontal-2": + gap = (x2 - x1) / 3 + points.append((x1 + gap, center[1])) + 
points.append((x1 + gap * 2, center[1])) + plabs = [1, 1] + + elif detection_hint == "vertical-2": + gap = (y2 - y1) / 3 + points.append((center[0], y1 + gap)) + points.append((center[0], y1 + gap * 2)) + plabs = [1, 1] + + elif detection_hint == "rect-4": + x_gap = (x2 - x1) / 3 + y_gap = (y2 - y1) / 3 + points.append((x1 + x_gap, center[1])) + points.append((x1 + x_gap * 2, center[1])) + points.append((center[0], y1 + y_gap)) + points.append((center[0], y1 + y_gap * 2)) + plabs = [1, 1, 1, 1] + + elif detection_hint == "diamond-4": + x_gap = (x2 - x1) / 3 + y_gap = (y2 - y1) / 3 + points.append((x1 + x_gap, y1 + y_gap)) + points.append((x1 + x_gap * 2, y1 + y_gap)) + points.append((x1 + x_gap, y1 + y_gap * 2)) + points.append((x1 + x_gap * 2, y1 + y_gap * 2)) + plabs = [1, 1, 1, 1] + + elif detection_hint == "mask-point-bbox": + center = center_of_bbox(seg.bbox) + points.append(center) + plabs = [1] + + elif detection_hint == "mask-area": + points, plabs = gen_detection_hints_from_mask_area(seg.crop_region[0], seg.crop_region[1], + seg.cropped_mask, + mask_hint_threshold, use_small_negative) + + if mask_hint_use_negative == "Outter": + npoints, nplabs = gen_negative_hints(image.shape[0], image.shape[1], + seg.crop_region[0], seg.crop_region[1], + seg.crop_region[2], seg.crop_region[3]) + + points += npoints + plabs += nplabs + + return points, plabs + +def combine_masks2(masks): + if len(masks) == 0: + return None + else: + initial_cv2_mask = np.array(masks[0]).astype(np.uint8) + combined_cv2_mask = initial_cv2_mask + + for i in range(1, len(masks)): + cv2_mask = np.array(masks[i]).astype(np.uint8) + + if combined_cv2_mask.shape == cv2_mask.shape: + combined_cv2_mask = cv2.bitwise_or(combined_cv2_mask, cv2_mask) + else: + # do nothing - incompatible mask + pass + + mask = torch.from_numpy(combined_cv2_mask) + return mask + +def dilate_mask(mask, dilation_factor, iter=1): + if dilation_factor == 0: + return make_2d_mask(mask) + + mask = make_2d_mask(mask) + + kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8) + + mask = cv2.UMat(mask) + kernel = cv2.UMat(kernel) + + if dilation_factor > 0: + result = cv2.dilate(mask, kernel, iter) + else: + result = cv2.erode(mask, kernel, iter) + + return result.get() + +def convert_and_stack_masks(masks): + if len(masks) == 0: + return None + + mask_tensors = [] + for mask in masks: + mask_array = np.array(mask, dtype=np.uint8) + mask_tensor = torch.from_numpy(mask_array) + mask_tensors.append(mask_tensor) + + stacked_masks = torch.stack(mask_tensors, dim=0) + stacked_masks = stacked_masks.unsqueeze(1) + + return stacked_masks + +def merge_and_stack_masks(stacked_masks, group_size): + if stacked_masks is None: + return None + + num_masks = stacked_masks.size(0) + merged_masks = [] + + for i in range(0, num_masks, group_size): + subset_masks = stacked_masks[i:i + group_size] + merged_mask = torch.any(subset_masks, dim=0) + merged_masks.append(merged_mask) + + if len(merged_masks) > 0: + merged_masks = torch.stack(merged_masks, dim=0) + + return merged_masks + +def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation, + threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative): + if sam_model.is_auto_mode: + device = model_management.get_torch_device() + sam_model.safe_to.to_device(sam_model, device=device) + + try: + predictor = SamPredictor(sam_model) + image = np.clip(255. 
* image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) + predictor.set_image(image, "RGB") + + total_masks = [] + + use_small_negative = mask_hint_use_negative == "Small" + + # seg_shape = segs[0] + segs = segs[1] + if detection_hint == "mask-points": + points = [] + plabs = [] + + for i in range(len(segs)): + bbox = segs[i].bbox + center = center_of_bbox(bbox) + points.append(center) + + # small point is background, big point is foreground + if use_small_negative and bbox[2] - bbox[0] < 10: + plabs.append(0) + else: + plabs.append(1) + + detected_masks = sam_predict(predictor, points, plabs, None, threshold) + total_masks += detected_masks + + else: + for i in range(len(segs)): + bbox = segs[i].bbox + center = center_of_bbox(bbox) + x1 = max(bbox[0] - bbox_expansion, 0) + y1 = max(bbox[1] - bbox_expansion, 0) + x2 = min(bbox[2] + bbox_expansion, image.shape[1]) + y2 = min(bbox[3] + bbox_expansion, image.shape[0]) + + dilated_bbox = [x1, y1, x2, y2] + + points, plabs = generate_detection_hints(image, segs[i], center, detection_hint, dilated_bbox, + mask_hint_threshold, use_small_negative, + mask_hint_use_negative) + + detected_masks = sam_predict(predictor, points, plabs, dilated_bbox, threshold) + + total_masks += detected_masks + + # merge every collected masks + mask = combine_masks2(total_masks) + + finally: + if sam_model.is_auto_mode: + sam_model.cpu() + + pass + + mask_working_device = torch.device("cpu") + + if mask is not None: + mask = mask.float() + mask = dilate_mask(mask.cpu().numpy(), dilation) + mask = torch.from_numpy(mask) + mask = mask.to(device=mask_working_device) + else: + # Extracting batch, height and width + height, width, _ = image.shape + mask = torch.zeros( + (height, width), dtype=torch.float32, device=mask_working_device + ) # empty mask + + stacked_masks = convert_and_stack_masks(total_masks) + + return (mask, merge_and_stack_masks(stacked_masks, group_size=3)) + +def tensor2mask(t: torch.Tensor) -> torch.Tensor: + size = t.size() + if (len(size) < 4): + return t + if size[3] == 1: + return t[:,:,:,0] + elif size[3] == 4: + # Not sure what the right thing to do here is. 
Going to try to be a little smart and use alpha unless all alpha is 1 in case we'll fallback to RGB behavior + if torch.min(t[:, :, :, 3]).item() != 1.: + return t[:,:,:,3] + return TF.rgb_to_grayscale(tensor2rgb(t).permute(0,3,1,2), num_output_channels=1)[:,0,:,:] + +def tensor2rgb(t: torch.Tensor) -> torch.Tensor: + size = t.size() + if (len(size) < 4): + return t.unsqueeze(3).repeat(1, 1, 1, 3) + if size[3] == 1: + return t.repeat(1, 1, 1, 3) + elif size[3] == 4: + return t[:, :, :, :3] + else: + return t + +def tensor2rgba(t: torch.Tensor) -> torch.Tensor: + size = t.size() + if (len(size) < 4): + return t.unsqueeze(3).repeat(1, 1, 1, 4) + elif size[3] == 1: + return t.repeat(1, 1, 1, 4) + elif size[3] == 3: + alpha_tensor = torch.ones((size[0], size[1], size[2], 1)) + return torch.cat((t, alpha_tensor), dim=3) + else: + return t diff --git a/scripts/r_masking/segs.py b/scripts/r_masking/segs.py new file mode 100644 index 0000000..60c22d7 --- /dev/null +++ b/scripts/r_masking/segs.py @@ -0,0 +1,22 @@ +def filter(segs, labels): + labels = set([label.strip() for label in labels]) + + if 'all' in labels: + return (segs, (segs[0], []), ) + else: + res_segs = [] + remained_segs = [] + + for x in segs[1]: + if x.label in labels: + res_segs.append(x) + elif 'eyes' in labels and x.label in ['left_eye', 'right_eye']: + res_segs.append(x) + elif 'eyebrows' in labels and x.label in ['left_eyebrow', 'right_eyebrow']: + res_segs.append(x) + elif 'pupils' in labels and x.label in ['left_pupil', 'right_pupil']: + res_segs.append(x) + else: + remained_segs.append(x) + + return ((segs[0], res_segs), (segs[0], remained_segs), ) diff --git a/scripts/r_masking/subcore.py b/scripts/r_masking/subcore.py new file mode 100644 index 0000000..cf7bf7d --- /dev/null +++ b/scripts/r_masking/subcore.py @@ -0,0 +1,117 @@ +import numpy as np +import cv2 +from PIL import Image + +import scripts.r_masking.core as core +from reactor_utils import tensor_to_pil + +try: + from ultralytics import YOLO +except Exception as e: + print(e) + + +def load_yolo(model_path: str): + try: + return YOLO(model_path) + except ModuleNotFoundError: + # https://github.com/ultralytics/ultralytics/issues/3856 + YOLO("yolov8n.pt") + return YOLO(model_path) + +def inference_bbox( + model, + image: Image.Image, + confidence: float = 0.3, + device: str = "", +): + pred = model(image, conf=confidence, device=device) + + bboxes = pred[0].boxes.xyxy.cpu().numpy() + cv2_image = np.array(image) + if len(cv2_image.shape) == 3: + cv2_image = cv2_image[:, :, ::-1].copy() # Convert RGB to BGR for cv2 processing + else: + # Handle the grayscale image here + # For example, you might want to convert it to a 3-channel grayscale image for consistency: + cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_GRAY2BGR) + cv2_gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY) + + segms = [] + for x0, y0, x1, y1 in bboxes: + cv2_mask = np.zeros(cv2_gray.shape, np.uint8) + cv2.rectangle(cv2_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1) + cv2_mask_bool = cv2_mask.astype(bool) + segms.append(cv2_mask_bool) + + n, m = bboxes.shape + if n == 0: + return [[], [], [], []] + + results = [[], [], [], []] + for i in range(len(bboxes)): + results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())]) + results[1].append(bboxes[i]) + results[2].append(segms[i]) + results[3].append(pred[0].boxes[i].conf.cpu().numpy()) + + return results + + +class UltraBBoxDetector: + bbox_model = None + + def __init__(self, bbox_model): + self.bbox_model = bbox_model + + def detect(self, 
image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None): + drop_size = max(drop_size, 1) + detected_results = inference_bbox(self.bbox_model, tensor_to_pil(image), threshold) + segmasks = core.create_segmasks(detected_results) + + if dilation > 0: + segmasks = core.dilate_masks(segmasks, dilation) + + items = [] + h = image.shape[1] + w = image.shape[2] + + for x, label in zip(segmasks, detected_results[0]): + item_bbox = x[0] + item_mask = x[1] + + y1, x1, y2, x2 = item_bbox + + if x2 - x1 > drop_size and y2 - y1 > drop_size: # minimum dimension must be (2,2) to avoid squeeze issue + crop_region = core.make_crop_region(w, h, item_bbox, crop_factor) + + if detailer_hook is not None: + crop_region = detailer_hook.post_crop_region(w, h, item_bbox, crop_region) + + cropped_image = core.crop_image(image, crop_region) + cropped_mask = core.crop_ndarray2(item_mask, crop_region) + confidence = x[2] + # bbox_size = (item_bbox[2]-item_bbox[0],item_bbox[3]-item_bbox[1]) # (w,h) + + item = core.SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, label, None) + + items.append(item) + + shape = image.shape[1], image.shape[2] + segs = shape, items + + if detailer_hook is not None and hasattr(detailer_hook, "post_detection"): + segs = detailer_hook.post_detection(segs) + + return segs + + def detect_combined(self, image, threshold, dilation): + detected_results = inference_bbox(self.bbox_model, core.tensor2pil(image), threshold) + segmasks = core.create_segmasks(detected_results) + if dilation > 0: + segmasks = core.dilate_masks(segmasks, dilation) + + return core.combine_masks(segmasks) + + def setAux(self, x): + pass diff --git a/scripts/reactor_faceswap.py b/scripts/reactor_faceswap.py new file mode 100644 index 0000000..7e6c03e --- /dev/null +++ b/scripts/reactor_faceswap.py @@ -0,0 +1,185 @@ +import os, glob + +from PIL import Image + +import modules.scripts as scripts +# from modules.upscaler import Upscaler, UpscalerData +from modules import scripts, scripts_postprocessing +from modules.processing import ( + StableDiffusionProcessing, + StableDiffusionProcessingImg2Img, +) +from modules.shared import state +from scripts.reactor_logger import logger +from scripts.reactor_swapper import ( + swap_face, + swap_face_many, + get_current_faces_model, + analyze_faces, + half_det_size, + providers +) +import folder_paths +import comfy.model_management as model_management + + +def get_models(): + models_path = os.path.join(folder_paths.models_dir,"insightface/*") + models = glob.glob(models_path) + models = [x for x in models if x.endswith(".onnx") or x.endswith(".pth")] + return models + + +class FaceSwapScript(scripts.Script): + + def process( + self, + p: StableDiffusionProcessing, + img, + enable, + source_faces_index, + faces_index, + model, + swap_in_source, + swap_in_generated, + gender_source, + gender_target, + face_model, + faces_order, + face_boost_enabled, + face_restore_model, + face_restore_visibility, + codeformer_weight, + interpolation, + ): + self.enable = enable + if self.enable: + + self.source = img + self.swap_in_generated = swap_in_generated + self.gender_source = gender_source + self.gender_target = gender_target + self.model = model + self.face_model = face_model + self.faces_order = faces_order + self.face_boost_enabled = face_boost_enabled + self.face_restore_model = face_restore_model + self.face_restore_visibility = face_restore_visibility + self.codeformer_weight = codeformer_weight + self.interpolation = interpolation + 
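+            # Parse the comma-separated index strings (e.g. "0,1,2") into lists of ints;
+            # non-numeric entries are dropped and an empty result falls back to index 0.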
self.source_faces_index = [ + int(x) for x in source_faces_index.strip(",").split(",") if x.isnumeric() + ] + self.faces_index = [ + int(x) for x in faces_index.strip(",").split(",") if x.isnumeric() + ] + if len(self.source_faces_index) == 0: + self.source_faces_index = [0] + if len(self.faces_index) == 0: + self.faces_index = [0] + + if self.gender_source is None or self.gender_source == "no": + self.gender_source = 0 + elif self.gender_source == "female": + self.gender_source = 1 + elif self.gender_source == "male": + self.gender_source = 2 + + if self.gender_target is None or self.gender_target == "no": + self.gender_target = 0 + elif self.gender_target == "female": + self.gender_target = 1 + elif self.gender_target == "male": + self.gender_target = 2 + + # if self.source is not None: + if isinstance(p, StableDiffusionProcessingImg2Img) and swap_in_source: + logger.status(f"Working: source face index %s, target face index %s", self.source_faces_index, self.faces_index) + + if len(p.init_images) == 1: + + result = swap_face( + self.source, + p.init_images[0], + source_faces_index=self.source_faces_index, + faces_index=self.faces_index, + model=self.model, + gender_source=self.gender_source, + gender_target=self.gender_target, + face_model=self.face_model, + faces_order=self.faces_order, + face_boost_enabled=self.face_boost_enabled, + face_restore_model=self.face_restore_model, + face_restore_visibility=self.face_restore_visibility, + codeformer_weight=self.codeformer_weight, + interpolation=self.interpolation, + ) + p.init_images[0] = result + + # for i in range(len(p.init_images)): + # if state.interrupted or model_management.processing_interrupted(): + # logger.status("Interrupted by User") + # break + # if len(p.init_images) > 1: + # logger.status(f"Swap in %s", i) + # result = swap_face( + # self.source, + # p.init_images[i], + # source_faces_index=self.source_faces_index, + # faces_index=self.faces_index, + # model=self.model, + # gender_source=self.gender_source, + # gender_target=self.gender_target, + # face_model=self.face_model, + # ) + # p.init_images[i] = result + + elif len(p.init_images) > 1: + result = swap_face_many( + self.source, + p.init_images, + source_faces_index=self.source_faces_index, + faces_index=self.faces_index, + model=self.model, + gender_source=self.gender_source, + gender_target=self.gender_target, + face_model=self.face_model, + faces_order=self.faces_order, + face_boost_enabled=self.face_boost_enabled, + face_restore_model=self.face_restore_model, + face_restore_visibility=self.face_restore_visibility, + codeformer_weight=self.codeformer_weight, + interpolation=self.interpolation, + ) + p.init_images = result + + logger.status("--Done!--") + # else: + # logger.error(f"Please provide a source face") + + def postprocess_batch(self, p, *args, **kwargs): + if self.enable: + images = kwargs["images"] + + def postprocess_image(self, p, script_pp: scripts.PostprocessImageArgs, *args): + if self.enable and self.swap_in_generated: + if self.source is not None: + logger.status(f"Working: source face index %s, target face index %s", self.source_faces_index, self.faces_index) + image: Image.Image = script_pp.image + result = swap_face( + self.source, + image, + source_faces_index=self.source_faces_index, + faces_index=self.faces_index, + model=self.model, + upscale_options=self.upscale_options, + gender_source=self.gender_source, + gender_target=self.gender_target, + ) + try: + pp = scripts_postprocessing.PostprocessedImage(result) + pp.info = {} + 
+                p.extra_generation_params.update(pp.info)
+                script_pp.image = pp.image
+            except Exception:
+                logger.error("Cannot create a result image")
diff --git a/scripts/reactor_logger.py b/scripts/reactor_logger.py
new file mode 100644
index 0000000..f64e433
--- /dev/null
+++ b/scripts/reactor_logger.py
@@ -0,0 +1,47 @@
+import logging
+import copy
+import sys
+
+from modules import shared
+from reactor_utils import addLoggingLevel
+
+
+class ColoredFormatter(logging.Formatter):
+    COLORS = {
+        "DEBUG": "\033[0;36m",       # CYAN
+        "STATUS": "\033[38;5;173m",  # Calm ORANGE
+        "INFO": "\033[0;32m",        # GREEN
+        "WARNING": "\033[0;33m",     # YELLOW
+        "ERROR": "\033[0;31m",       # RED
+        "CRITICAL": "\033[0;37;41m", # WHITE ON RED
+        "RESET": "\033[0m",          # RESET COLOR
+    }
+
+    def format(self, record):
+        colored_record = copy.copy(record)
+        levelname = colored_record.levelname
+        seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+        colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+        return super().format(colored_record)
+
+
+# Create a new logger
+logger = logging.getLogger("ReActor")
+logger.propagate = False
+
+# Add Custom Level
+# logging.addLevelName(logging.INFO, "STATUS")
+addLoggingLevel("STATUS", logging.INFO + 5)
+
+# Add handler if we don't have one.
+if not logger.handlers:
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(
+        ColoredFormatter("[%(name)s] %(asctime)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S")
+    )
+    logger.addHandler(handler)
+
+# Configure logger
+loglevel_string = getattr(shared.cmd_opts, "reactor_loglevel", "INFO")
+loglevel = getattr(logging, loglevel_string.upper(), logging.INFO)
+logger.setLevel(loglevel)
diff --git a/scripts/reactor_swapper.py b/scripts/reactor_swapper.py
new file mode 100644
index 0000000..6db5dfc
--- /dev/null
+++ b/scripts/reactor_swapper.py
@@ -0,0 +1,572 @@
+import os
+import shutil
+from typing import List, Union
+
+import cv2
+import numpy as np
+from PIL import Image
+
+import insightface
+from insightface.app.common import Face
+# try:
+#     import torch.cuda as cuda
+# except:
+#     cuda = None
+import torch
+
+import folder_paths
+import comfy.model_management as model_management
+from modules.shared import state
+
+from scripts.reactor_logger import logger
+from reactor_utils import (
+    move_path,
+    get_image_md5hash,
+)
+from scripts.r_faceboost import swapper, restorer
+
+import warnings
+
+np.warnings = warnings
+np.warnings.filterwarnings('ignore')
+
+# PROVIDERS
+try:
+    if torch.cuda.is_available():
+        providers = ["CUDAExecutionProvider"]
+    elif torch.backends.mps.is_available():
+        providers = ["CoreMLExecutionProvider"]
+    elif hasattr(torch, 'dml') or hasattr(torch, 'privateuseone'):
+        providers = ["ROCMExecutionProvider"]
+    else:
+        providers = ["CPUExecutionProvider"]
+except Exception as e:
+    logger.debug(f"ExecutionProviderError: {e}.\nEP is set to CPU.")
+    providers = ["CPUExecutionProvider"]
+# if cuda is not None:
+#     if cuda.is_available():
+#         providers = ["CUDAExecutionProvider"]
+#     else:
+#         providers = ["CPUExecutionProvider"]
+# else:
+#     providers = ["CPUExecutionProvider"]
+
+models_path_old = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
+insightface_path_old = os.path.join(models_path_old, "insightface")
+insightface_models_path_old = os.path.join(insightface_path_old, "models")
+
+models_path = folder_paths.models_dir
+insightface_path = os.path.join(models_path, "insightface")
+insightface_models_path = os.path.join(insightface_path, "models")
+
+if os.path.exists(models_path_old):
move_path(insightface_models_path_old, insightface_models_path) + move_path(insightface_path_old, insightface_path) + move_path(models_path_old, models_path) +if os.path.exists(insightface_path) and os.path.exists(insightface_path_old): + shutil.rmtree(insightface_path_old) + shutil.rmtree(models_path_old) + + +FS_MODEL = None +CURRENT_FS_MODEL_PATH = None + +ANALYSIS_MODELS = { + "640": None, + "320": None, +} + +SOURCE_FACES = None +SOURCE_IMAGE_HASH = None +TARGET_FACES = None +TARGET_IMAGE_HASH = None +TARGET_FACES_LIST = [] +TARGET_IMAGE_LIST_HASH = [] + +def unload_model(model): + if model is not None: + # check if model has unload method + # if "unload" in model: + # model.unload() + # if "model_unload" in model: + # model.model_unload() + del model + return None + +def unload_all_models(): + global FS_MODEL, CURRENT_FS_MODEL_PATH + FS_MODEL = unload_model(FS_MODEL) + ANALYSIS_MODELS["320"] = unload_model(ANALYSIS_MODELS["320"]) + ANALYSIS_MODELS["640"] = unload_model(ANALYSIS_MODELS["640"]) + +def get_current_faces_model(): + global SOURCE_FACES + return SOURCE_FACES + +def getAnalysisModel(det_size = (640, 640)): + global ANALYSIS_MODELS + ANALYSIS_MODEL = ANALYSIS_MODELS[str(det_size[0])] + if ANALYSIS_MODEL is None: + ANALYSIS_MODEL = insightface.app.FaceAnalysis( + name="buffalo_l", providers=providers, root=insightface_path + ) + ANALYSIS_MODEL.prepare(ctx_id=0, det_size=det_size) + ANALYSIS_MODELS[str(det_size[0])] = ANALYSIS_MODEL + return ANALYSIS_MODEL + +def getFaceSwapModel(model_path: str): + global FS_MODEL, CURRENT_FS_MODEL_PATH + if FS_MODEL is None or CURRENT_FS_MODEL_PATH is None or CURRENT_FS_MODEL_PATH != model_path: + CURRENT_FS_MODEL_PATH = model_path + FS_MODEL = unload_model(FS_MODEL) + FS_MODEL = insightface.model_zoo.get_model(model_path, providers=providers) + + return FS_MODEL + + +def sort_by_order(face, order: str): + if order == "left-right": + return sorted(face, key=lambda x: x.bbox[0]) + if order == "right-left": + return sorted(face, key=lambda x: x.bbox[0], reverse = True) + if order == "top-bottom": + return sorted(face, key=lambda x: x.bbox[1]) + if order == "bottom-top": + return sorted(face, key=lambda x: x.bbox[1], reverse = True) + if order == "small-large": + return sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1])) + # if order == "large-small": + # return sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True) + # by default "large-small": + return sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True) + +def get_face_gender( + face, + face_index, + gender_condition, + operated: str, + order: str, +): + gender = [ + x.sex + for x in face + ] + gender.reverse() + # If index is outside of bounds, return None, avoid exception + if face_index >= len(gender): + logger.status("Requested face index (%s) is out of bounds (max available index is %s)", face_index, len(gender)) + return None, 0 + face_gender = gender[face_index] + logger.status("%s Face %s: Detected Gender -%s-", operated, face_index, face_gender) + if (gender_condition == 1 and face_gender == "F") or (gender_condition == 2 and face_gender == "M"): + logger.status("OK - Detected Gender matches Condition") + try: + faces_sorted = sort_by_order(face, order) + return faces_sorted[face_index], 0 + # return sorted(face, key=lambda x: x.bbox[0])[face_index], 0 + except IndexError: + return None, 0 + else: + logger.status("WRONG - Detected Gender doesn't match Condition") + faces_sorted 
= sort_by_order(face, order) + return faces_sorted[face_index], 1 + # return sorted(face, key=lambda x: x.bbox[0])[face_index], 1 + +def half_det_size(det_size): + logger.status("Trying to halve 'det_size' parameter") + return (det_size[0] // 2, det_size[1] // 2) + +def analyze_faces(img_data: np.ndarray, det_size=(640, 640)): + face_analyser = getAnalysisModel(det_size) + faces = face_analyser.get(img_data) + + # Try halving det_size if no faces are found + if len(faces) == 0 and det_size[0] > 320 and det_size[1] > 320: + det_size_half = half_det_size(det_size) + return analyze_faces(img_data, det_size_half) + + return faces + +def get_face_single(img_data: np.ndarray, face, face_index=0, det_size=(640, 640), gender_source=0, gender_target=0, order="large-small"): + + buffalo_path = os.path.join(insightface_models_path, "buffalo_l.zip") + if os.path.exists(buffalo_path): + os.remove(buffalo_path) + + if gender_source != 0: + if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320: + det_size_half = half_det_size(det_size) + return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target, order) + return get_face_gender(face,face_index,gender_source,"Source", order) + + if gender_target != 0: + if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320: + det_size_half = half_det_size(det_size) + return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target, order) + return get_face_gender(face,face_index,gender_target,"Target", order) + + if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320: + det_size_half = half_det_size(det_size) + return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target, order) + + try: + faces_sorted = sort_by_order(face, order) + return faces_sorted[face_index], 0 + # return sorted(face, key=lambda x: x.bbox[0])[face_index], 0 + except IndexError: + return None, 0 + + +def swap_face( + source_img: Union[Image.Image, None], + target_img: Image.Image, + model: Union[str, None] = None, + source_faces_index: List[int] = [0], + faces_index: List[int] = [0], + gender_source: int = 0, + gender_target: int = 0, + face_model: Union[Face, None] = None, + faces_order: List = ["large-small", "large-small"], + face_boost_enabled: bool = False, + face_restore_model = None, + face_restore_visibility: int = 1, + codeformer_weight: float = 0.5, + interpolation: str = "Bicubic", +): + global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH + result_image = target_img + + if model is not None: + + if isinstance(source_img, str): # source_img is a base64 string + import base64, io + if 'base64,' in source_img: # check if the base64 string has a data URL scheme + # split the base64 string to get the actual base64 encoded image data + base64_data = source_img.split('base64,')[-1] + # decode base64 string to bytes + img_bytes = base64.b64decode(base64_data) + else: + # if no data URL scheme, just decode + img_bytes = base64.b64decode(source_img) + + source_img = Image.open(io.BytesIO(img_bytes)) + + target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR) + + if source_img is not None: + + source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR) + + source_image_md5hash = get_image_md5hash(source_img) + + if SOURCE_IMAGE_HASH is None: + SOURCE_IMAGE_HASH = source_image_md5hash + source_image_same = False + else: + source_image_same = 
True if SOURCE_IMAGE_HASH == source_image_md5hash else False + if not source_image_same: + SOURCE_IMAGE_HASH = source_image_md5hash + + logger.info("Source Image MD5 Hash = %s", SOURCE_IMAGE_HASH) + logger.info("Source Image the Same? %s", source_image_same) + + if SOURCE_FACES is None or not source_image_same: + logger.status("Analyzing Source Image...") + source_faces = analyze_faces(source_img) + SOURCE_FACES = source_faces + elif source_image_same: + logger.status("Using Hashed Source Face(s) Model...") + source_faces = SOURCE_FACES + + elif face_model is not None: + + source_faces_index = [0] + logger.status("Using Loaded Source Face Model...") + source_face_model = [face_model] + source_faces = source_face_model + + else: + logger.error("Cannot detect any Source") + + if source_faces is not None: + + target_image_md5hash = get_image_md5hash(target_img) + + if TARGET_IMAGE_HASH is None: + TARGET_IMAGE_HASH = target_image_md5hash + target_image_same = False + else: + target_image_same = True if TARGET_IMAGE_HASH == target_image_md5hash else False + if not target_image_same: + TARGET_IMAGE_HASH = target_image_md5hash + + logger.info("Target Image MD5 Hash = %s", TARGET_IMAGE_HASH) + logger.info("Target Image the Same? %s", target_image_same) + + if TARGET_FACES is None or not target_image_same: + logger.status("Analyzing Target Image...") + target_faces = analyze_faces(target_img) + TARGET_FACES = target_faces + elif target_image_same: + logger.status("Using Hashed Target Face(s) Model...") + target_faces = TARGET_FACES + + # No use in trying to swap faces if no faces are found, enhancement + if len(target_faces) == 0: + logger.status("Cannot detect any Target, skipping swapping...") + return result_image + + if source_img is not None: + # separated management of wrong_gender between source and target, enhancement + source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[0], gender_source=gender_source, order=faces_order[1]) + else: + # source_face = sorted(source_faces, key=lambda x: x.bbox[0])[source_faces_index[0]] + source_face = sorted(source_faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True)[source_faces_index[0]] + src_wrong_gender = 0 + + if len(source_faces_index) != 0 and len(source_faces_index) != 1 and len(source_faces_index) != len(faces_index): + logger.status(f'Source Faces must have no entries (default=0), one entry, or same number of entries as target faces.') + elif source_face is not None: + result = target_img + model_path = model_path = os.path.join(insightface_path, model) + face_swapper = getFaceSwapModel(model_path) + + source_face_idx = 0 + + for face_num in faces_index: + # No use in trying to swap faces if no further faces are found, enhancement + if face_num >= len(target_faces): + logger.status("Checked all existing target faces, skipping swapping...") + break + + if len(source_faces_index) > 1 and source_face_idx > 0: + source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[source_face_idx], gender_source=gender_source, order=faces_order[1]) + source_face_idx += 1 + + if source_face is not None and src_wrong_gender == 0: + target_face, wrong_gender = get_face_single(target_img, target_faces, face_index=face_num, gender_target=gender_target, order=faces_order[0]) + if target_face is not None and wrong_gender == 0: + logger.status(f"Swapping...") + if face_boost_enabled: + logger.status(f"Face Boost is enabled") + bgr_fake, M = 
face_swapper.get(result, target_face, source_face, paste_back=False) + bgr_fake, scale = restorer.get_restored_face(bgr_fake, face_restore_model, face_restore_visibility, codeformer_weight, interpolation) + M *= scale + result = swapper.in_swap(target_img, bgr_fake, M) + else: + # logger.status(f"Swapping as-is") + result = face_swapper.get(result, target_face, source_face) + elif wrong_gender == 1: + wrong_gender = 0 + # Keep searching for other faces if wrong gender is detected, enhancement + #if source_face_idx == len(source_faces_index): + # result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) + # return result_image + logger.status("Wrong target gender detected") + continue + else: + logger.status(f"No target face found for {face_num}") + elif src_wrong_gender == 1: + src_wrong_gender = 0 + # Keep searching for other faces if wrong gender is detected, enhancement + #if source_face_idx == len(source_faces_index): + # result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) + # return result_image + logger.status("Wrong source gender detected") + continue + else: + logger.status(f"No source face found for face number {source_face_idx}.") + + result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) + + else: + logger.status("No source face(s) in the provided Index") + else: + logger.status("No source face(s) found") + return result_image + +def swap_face_many( + source_img: Union[Image.Image, None], + target_imgs: List[Image.Image], + model: Union[str, None] = None, + source_faces_index: List[int] = [0], + faces_index: List[int] = [0], + gender_source: int = 0, + gender_target: int = 0, + face_model: Union[Face, None] = None, + faces_order: List = ["large-small", "large-small"], + face_boost_enabled: bool = False, + face_restore_model = None, + face_restore_visibility: int = 1, + codeformer_weight: float = 0.5, + interpolation: str = "Bicubic", +): + global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH, TARGET_FACES_LIST, TARGET_IMAGE_LIST_HASH + result_images = target_imgs + + if model is not None: + + if isinstance(source_img, str): # source_img is a base64 string + import base64, io + if 'base64,' in source_img: # check if the base64 string has a data URL scheme + # split the base64 string to get the actual base64 encoded image data + base64_data = source_img.split('base64,')[-1] + # decode base64 string to bytes + img_bytes = base64.b64decode(base64_data) + else: + # if no data URL scheme, just decode + img_bytes = base64.b64decode(source_img) + + source_img = Image.open(io.BytesIO(img_bytes)) + + target_imgs = [cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR) for target_img in target_imgs] + + if source_img is not None: + + source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR) + + source_image_md5hash = get_image_md5hash(source_img) + + if SOURCE_IMAGE_HASH is None: + SOURCE_IMAGE_HASH = source_image_md5hash + source_image_same = False + else: + source_image_same = True if SOURCE_IMAGE_HASH == source_image_md5hash else False + if not source_image_same: + SOURCE_IMAGE_HASH = source_image_md5hash + + logger.info("Source Image MD5 Hash = %s", SOURCE_IMAGE_HASH) + logger.info("Source Image the Same? 
%s", source_image_same) + + if SOURCE_FACES is None or not source_image_same: + logger.status("Analyzing Source Image...") + source_faces = analyze_faces(source_img) + SOURCE_FACES = source_faces + elif source_image_same: + logger.status("Using Hashed Source Face(s) Model...") + source_faces = SOURCE_FACES + + elif face_model is not None: + + source_faces_index = [0] + logger.status("Using Loaded Source Face Model...") + source_face_model = [face_model] + source_faces = source_face_model + + else: + logger.error("Cannot detect any Source") + + if source_faces is not None: + + target_faces = [] + for i, target_img in enumerate(target_imgs): + if state.interrupted or model_management.processing_interrupted(): + logger.status("Interrupted by User") + break + + target_image_md5hash = get_image_md5hash(target_img) + if len(TARGET_IMAGE_LIST_HASH) == 0: + TARGET_IMAGE_LIST_HASH = [target_image_md5hash] + target_image_same = False + elif len(TARGET_IMAGE_LIST_HASH) == i: + TARGET_IMAGE_LIST_HASH.append(target_image_md5hash) + target_image_same = False + else: + target_image_same = True if TARGET_IMAGE_LIST_HASH[i] == target_image_md5hash else False + if not target_image_same: + TARGET_IMAGE_LIST_HASH[i] = target_image_md5hash + + logger.info("(Image %s) Target Image MD5 Hash = %s", i, TARGET_IMAGE_LIST_HASH[i]) + logger.info("(Image %s) Target Image the Same? %s", i, target_image_same) + + if len(TARGET_FACES_LIST) == 0: + logger.status(f"Analyzing Target Image {i}...") + target_face = analyze_faces(target_img) + TARGET_FACES_LIST = [target_face] + elif len(TARGET_FACES_LIST) == i and not target_image_same: + logger.status(f"Analyzing Target Image {i}...") + target_face = analyze_faces(target_img) + TARGET_FACES_LIST.append(target_face) + elif len(TARGET_FACES_LIST) != i and not target_image_same: + logger.status(f"Analyzing Target Image {i}...") + target_face = analyze_faces(target_img) + TARGET_FACES_LIST[i] = target_face + elif target_image_same: + logger.status("(Image %s) Using Hashed Target Face(s) Model...", i) + target_face = TARGET_FACES_LIST[i] + + + # logger.status(f"Analyzing Target Image {i}...") + # target_face = analyze_faces(target_img) + if target_face is not None: + target_faces.append(target_face) + + # No use in trying to swap faces if no faces are found, enhancement + if len(target_faces) == 0: + logger.status("Cannot detect any Target, skipping swapping...") + return result_images + + if source_img is not None: + # separated management of wrong_gender between source and target, enhancement + source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[0], gender_source=gender_source, order=faces_order[1]) + else: + # source_face = sorted(source_faces, key=lambda x: x.bbox[0])[source_faces_index[0]] + source_face = sorted(source_faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True)[source_faces_index[0]] + src_wrong_gender = 0 + + if len(source_faces_index) != 0 and len(source_faces_index) != 1 and len(source_faces_index) != len(faces_index): + logger.status(f'Source Faces must have no entries (default=0), one entry, or same number of entries as target faces.') + elif source_face is not None: + results = target_imgs + model_path = model_path = os.path.join(insightface_path, model) + face_swapper = getFaceSwapModel(model_path) + + source_face_idx = 0 + + for face_num in faces_index: + # No use in trying to swap faces if no further faces are found, enhancement + if face_num >= len(target_faces): + 
logger.status("Checked all existing target faces, skipping swapping...") + break + + if len(source_faces_index) > 1 and source_face_idx > 0: + source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[source_face_idx], gender_source=gender_source, order=faces_order[1]) + source_face_idx += 1 + + if source_face is not None and src_wrong_gender == 0: + # Reading results to make current face swap on a previous face result + for i, (target_img, target_face) in enumerate(zip(results, target_faces)): + target_face_single, wrong_gender = get_face_single(target_img, target_face, face_index=face_num, gender_target=gender_target, order=faces_order[0]) + if target_face_single is not None and wrong_gender == 0: + result = target_img + logger.status(f"Swapping {i}...") + if face_boost_enabled: + logger.status(f"Face Boost is enabled") + bgr_fake, M = face_swapper.get(target_img, target_face_single, source_face, paste_back=False) + bgr_fake, scale = restorer.get_restored_face(bgr_fake, face_restore_model, face_restore_visibility, codeformer_weight, interpolation) + M *= scale + result = swapper.in_swap(target_img, bgr_fake, M) + else: + # logger.status(f"Swapping as-is") + result = face_swapper.get(target_img, target_face_single, source_face) + results[i] = result + elif wrong_gender == 1: + wrong_gender = 0 + logger.status("Wrong target gender detected") + continue + else: + logger.status(f"No target face found for {face_num}") + elif src_wrong_gender == 1: + src_wrong_gender = 0 + logger.status("Wrong source gender detected") + continue + else: + logger.status(f"No source face found for face number {source_face_idx}.") + + result_images = [Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) for result in results] + + else: + logger.status("No source face(s) in the provided Index") + else: + logger.status("No source face(s) found") + return result_images diff --git a/scripts/reactor_version.py b/scripts/reactor_version.py new file mode 100644 index 0000000..b4e6267 --- /dev/null +++ b/scripts/reactor_version.py @@ -0,0 +1,13 @@ +app_title = "ReActor Node for ComfyUI" +version_flag = "v0.5.2-a2" + +COLORS = { + "CYAN": "\033[0;36m", # CYAN + "ORANGE": "\033[38;5;173m", # Calm ORANGE + "GREEN": "\033[0;32m", # GREEN + "YELLOW": "\033[0;33m", # YELLOW + "RED": "\033[0;91m", # RED + "0": "\033[0m", # RESET COLOR +} + +print(f"{COLORS['YELLOW']}[ReActor]{COLORS['0']} - {COLORS['ORANGE']}STATUS{COLORS['0']} - {COLORS['GREEN']}Running {version_flag} in ComfyUI{COLORS['0']}") -- cgit v1.2.3