author     Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com>  2025-01-17 11:06:44 +0000
committer  GitHub <noreply@github.com>  2025-01-17 11:06:44 +0000
commit     e6bd5af6a8e306a1cdef63402a77a980a04ad6e1 (patch)
tree       d0732226bbc22feedad9e834b2218d7d0b0eff54 /scripts
parent     495ffc4777522e40941753e3b1b79c02f84b25b4 (diff)
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/__init__.py | 0
-rw-r--r-- | scripts/r_archs/__init__.py | 0
-rw-r--r-- | scripts/r_archs/codeformer_arch.py | 278
-rw-r--r-- | scripts/r_archs/vqgan_arch.py | 437
-rw-r--r-- | scripts/r_faceboost/__init__.py | 0
-rw-r--r-- | scripts/r_faceboost/restorer.py | 130
-rw-r--r-- | scripts/r_faceboost/swapper.py | 42
-rw-r--r-- | scripts/r_masking/__init__.py | 0
-rw-r--r-- | scripts/r_masking/core.py | 647
-rw-r--r-- | scripts/r_masking/segs.py | 22
-rw-r--r-- | scripts/r_masking/subcore.py | 117
-rw-r--r-- | scripts/reactor_faceswap.py | 185
-rw-r--r-- | scripts/reactor_logger.py | 47
-rw-r--r-- | scripts/reactor_swapper.py | 572
-rw-r--r-- | scripts/reactor_version.py | 13
15 files changed, 2490 insertions, 0 deletions
diff --git a/scripts/__init__.py b/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/scripts/__init__.py
diff --git a/scripts/r_archs/__init__.py b/scripts/r_archs/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/scripts/r_archs/__init__.py
diff --git a/scripts/r_archs/codeformer_arch.py b/scripts/r_archs/codeformer_arch.py
new file mode 100644
index 0000000..588ef69
--- /dev/null
+++ b/scripts/r_archs/codeformer_arch.py
@@ -0,0 +1,278 @@
+import math
+import numpy as np
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from typing import Optional, List
+
+from scripts.r_archs.vqgan_arch import *
+from r_basicsr.utils import get_root_logger
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+ return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+ """Adaptive instance normalization.
+
+ Adjust the reference features to have similar color and illumination
+ to those of the degraded features.
+
+ Args:
+ content_feat (Tensor): The reference feature.
+ style_feat (Tensor): The degraded features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
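+# Usage sketch (illustrative, comment only; assumes 4-D BCHW tensors):
+#   content = torch.randn(1, 256, 16, 16)   # e.g. quantized codebook features
+#   style = torch.randn(1, 256, 16, 16)     # e.g. degraded (low-quality) features
+#   out = adaptive_instance_normalization(content, style)
+#   # `out` keeps the spatial layout of `content` but carries the per-channel
+#   # mean/std of `style`; CodeFormer applies this when forward() is called
+#   # with adain=True to re-inject the input's colour statistics.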
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
+
+
+class TransformerSALayer(nn.Module):
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
+ # Implementation of Feedforward model - MLP
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
+
+ self.norm1 = nn.LayerNorm(embed_dim)
+ self.norm2 = nn.LayerNorm(embed_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward(self, tgt,
+ tgt_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+
+ # self attention
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+
+ # ffn
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
+
+class Fuse_sft_block(nn.Module):
+ def __init__(self, in_ch, out_ch):
+ super().__init__()
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
+
+ self.scale = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ self.shift = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ def forward(self, enc_feat, dec_feat, w=1):
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
+ scale = self.scale(enc_feat)
+ shift = self.shift(enc_feat)
+ residual = w * (dec_feat * scale + shift)
+ out = dec_feat + residual
+ return out
+
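+# Usage sketch (illustrative, comment only; in_ch == out_ch, as used below):
+#   fuse = Fuse_sft_block(in_ch=256, out_ch=256)
+#   out = fuse(enc_feat, dec_feat, w=0.5)
+#   # w scales the SFT residual: w=0 returns dec_feat unchanged, larger w
+#   # blends in more of the encoder (input) features.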
+
+@ARCH_REGISTRY.register()
+class CodeFormer(VQAutoEncoder):
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
+ codebook_size=1024, latent_size=256,
+ connect_list=['32', '64', '128', '256'],
+ fix_modules=['quantize','generator']):
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
+
+ if fix_modules is not None:
+ for module in fix_modules:
+ for param in getattr(self, module).parameters():
+ param.requires_grad = False
+
+ self.connect_list = connect_list
+ self.n_layers = n_layers
+ self.dim_embd = dim_embd
+ self.dim_mlp = dim_embd*2
+
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
+ self.feat_emb = nn.Linear(256, self.dim_embd)
+
+ # transformer
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
+ for _ in range(self.n_layers)])
+
+ # logits_predict head
+ self.idx_pred_layer = nn.Sequential(
+ nn.LayerNorm(dim_embd),
+ nn.Linear(dim_embd, codebook_size, bias=False))
+
+ self.channels = {
+ '16': 512,
+ '32': 256,
+ '64': 256,
+ '128': 128,
+ '256': 128,
+ '512': 64,
+ }
+
+ # after second residual block for > 16, before attn layer for ==16
+ self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
+ # after first residual block for > 16, before attn layer for ==16
+ self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
+
+ # fuse_convs_dict
+ self.fuse_convs_dict = nn.ModuleDict()
+ for f_size in self.connect_list:
+ in_ch = self.channels[f_size]
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
+ # ################### Encoder #####################
+ enc_feat_dict = {}
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
+ for i, block in enumerate(self.encoder.blocks):
+ x = block(x)
+ if i in out_list:
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
+
+ lq_feat = x
+ # ################# Transformer ###################
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
+ # BCHW -> BC(HW) -> (HW)BC
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
+ query_emb = feat_emb
+ # Transformer encoder
+ for layer in self.ft_layers:
+ query_emb = layer(query_emb, query_pos=pos_emb)
+
+ # output logits
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
+
+ if code_only: # for training stage II
+ # logits doesn't need softmax before cross_entropy loss
+ return logits, lq_feat
+
+ # ################# Quantization ###################
+ # if self.training:
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
+ # # b(hw)c -> bc(hw) -> bchw
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
+ # ------------
+ soft_one_hot = F.softmax(logits, dim=2)
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
+ # preserve gradients
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
+
+ if detach_16:
+ quant_feat = quant_feat.detach() # for training stage III
+ if adain:
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
+
+ # ################## Generator ####################
+ x = quant_feat
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
+
+ for i, block in enumerate(self.generator.blocks):
+ x = block(x)
+ if i in fuse_list: # fuse after i-th block
+ f_size = str(x.shape[-1])
+ if w>0:
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
+ out = x
+ # logits doesn't need softmax before cross_entropy loss
+ return out, logits, lq_feat
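+ # Inference sketch (illustrative, comment only; input normalized to [-1, 1]):
+ #   net = CodeFormer().eval()
+ #   out, logits, lq_feat = net(torch.randn(1, 3, 512, 512), w=0.5, adain=True)
+ #   # `w` scales the fused encoder features (w=0 skips fusion entirely);
+ #   # `out` is the restored 512x512 image tensor.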
+
\ No newline at end of file
diff --git a/scripts/r_archs/vqgan_arch.py b/scripts/r_archs/vqgan_arch.py
new file mode 100644
index 0000000..50b3712
--- /dev/null
+++ b/scripts/r_archs/vqgan_arch.py
@@ -0,0 +1,437 @@
+'''
+VQGAN code, adapted from the original created by the Unleashing Transformers authors:
+https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
+
+'''
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+from r_basicsr.utils import get_root_logger
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+
+def normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+@torch.jit.script
+def swish(x):
+ return x*torch.sigmoid(x)
+
+
+# Define VQVAE classes
+class VectorQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.emb_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
+
+ mean_distance = torch.mean(d)
+ # find closest encodings
+ # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+ min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
+ # [0-1], higher score, higher confidence
+ min_encoding_scores = torch.exp(-min_encoding_scores/10)
+
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, {
+ "perplexity": perplexity,
+ "min_encodings": min_encodings,
+ "min_encoding_indices": min_encoding_indices,
+ "min_encoding_scores": min_encoding_scores,
+ "mean_distance": mean_distance
+ }
+
+ def get_codebook_feat(self, indices, shape):
+ # input indices: batch*token_num -> (batch*token_num)*1
+ # shape: batch, height, width, channel
+ indices = indices.view(-1,1)
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
+ min_encodings.scatter_(1, indices, 1)
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None: # reshape back to match original input shape
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
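+# Round-trip sketch (illustrative, comment only; shapes assumed):
+#   vq = VectorQuantizer(codebook_size=1024, emb_dim=256, beta=0.25)
+#   z = torch.randn(1, 256, 16, 16)                      # encoder output, BCHW
+#   z_q, loss, stats = vq(z)                             # z_q has the same shape as z
+#   z_q2 = vq.get_codebook_feat(stats["min_encoding_indices"],
+#                               shape=[1, 16, 16, 256])  # back to (1, 256, 16, 16)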
+
+class GumbelQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
+ super().__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
+ self.embed = nn.Embedding(codebook_size, emb_dim)
+
+ def forward(self, z):
+ hard = self.straight_through if self.training else True
+
+ logits = self.proj(z)
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
+
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
+
+ return z_q, diff, {
+ "min_encoding_indices": min_encoding_indices
+ }
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+
+ return x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels, out_channels=None):
+ super(ResBlock, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.norm1 = normalize(in_channels)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = normalize(out_channels)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x_in):
+ x = x_in
+ x = self.norm1(x)
+ x = swish(x)
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = swish(x)
+ x = self.conv2(x)
+ if self.in_channels != self.out_channels:
+ x_in = self.conv_out(x_in)
+
+ return x + x_in
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h*w)
+ q = q.permute(0, 2, 1)
+ k = k.reshape(b, c, h*w)
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = F.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h*w)
+ w_ = w_.permute(0, 2, 1)
+ h_ = torch.bmm(v, w_)
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.attn_resolutions = attn_resolutions
+
+ curr_res = self.resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+
+ blocks = []
+ # initial convolution
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
+
+ # residual and downsampling blocks, with attention on smaller res (16x16)
+ for i in range(self.num_resolutions):
+ block_in_ch = nf * in_ch_mult[i]
+ block_out_ch = nf * ch_mult[i]
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+ if curr_res in attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != self.num_resolutions - 1:
+ blocks.append(Downsample(block_in_ch))
+ curr_res = curr_res // 2
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ # normalise and convert to latent size
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+class Generator(nn.Module):
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.ch_mult = ch_mult
+ self.num_resolutions = len(self.ch_mult)
+ self.num_res_blocks = res_blocks
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.in_channels = emb_dim
+ self.out_channels = 3
+ block_in_ch = self.nf * self.ch_mult[-1]
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
+
+ blocks = []
+ # initial conv
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ for i in reversed(range(self.num_resolutions)):
+ block_out_ch = self.nf * self.ch_mult[i]
+
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+
+ if curr_res in self.attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != 0:
+ blocks.append(Upsample(block_in_ch))
+ curr_res = curr_res * 2
+
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
+
+ self.blocks = nn.ModuleList(blocks)
+
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+@ARCH_REGISTRY.register()
+class VQAutoEncoder(nn.Module):
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
+ super().__init__()
+ logger = get_root_logger()
+ self.in_channels = 3
+ self.nf = nf
+ self.n_blocks = res_blocks
+ self.codebook_size = codebook_size
+ self.embed_dim = emb_dim
+ self.ch_mult = ch_mult
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.quantizer_type = quantizer
+ self.encoder = Encoder(
+ self.in_channels,
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+ if self.quantizer_type == "nearest":
+ self.beta = beta #0.25
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
+ elif self.quantizer_type == "gumbel":
+ self.gumbel_num_hiddens = emb_dim
+ self.straight_through = gumbel_straight_through
+ self.kl_weight = gumbel_kl_weight
+ self.quantize = GumbelQuantizer(
+ self.codebook_size,
+ self.embed_dim,
+ self.gumbel_num_hiddens,
+ self.straight_through,
+ self.kl_weight
+ )
+ self.generator = Generator(
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+
+ if model_path is not None:
+ chkpt = torch.load(model_path, map_location='cpu')
+ if 'params_ema' in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
+ logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
+ elif 'params' in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
+ logger.info(f'vqgan is loaded from: {model_path} [params]')
+ else:
+ raise ValueError(f'Wrong params!')
+
+
+ def forward(self, x):
+ x = self.encoder(x)
+ quant, codebook_loss, quant_stats = self.quantize(x)
+ x = self.generator(quant)
+ return x, codebook_loss, quant_stats
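+ # Round-trip sketch (illustrative, comment only): with the settings CodeFormer
+ # uses, VQAutoEncoder(512, 64, [1, 2, 2, 4, 4, 8]) maps a (B, 3, 512, 512)
+ # image down to a 16x16 latent, quantizes it against the codebook, and decodes
+ # it back to (B, 3, 512, 512), also returning the codebook loss and quant stats.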
+
+
+
+# patch based discriminator
+@ARCH_REGISTRY.register()
+class VQGANDiscriminator(nn.Module):
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
+ super().__init__()
+
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
+ ndf_mult = 1
+ ndf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n, 8)
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n_layers, 8)
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
+ self.main = nn.Sequential(*layers)
+
+ if model_path is not None:
+ chkpt = torch.load(model_path, map_location='cpu')
+ if 'params_d' in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
+ elif 'params' in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
+ else:
+ raise ValueError(f'Wrong params!')
+
+ def forward(self, x):
+ return self.main(x)
+
\ No newline at end of file
diff --git a/scripts/r_faceboost/__init__.py b/scripts/r_faceboost/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/scripts/r_faceboost/__init__.py
diff --git a/scripts/r_faceboost/restorer.py b/scripts/r_faceboost/restorer.py
new file mode 100644
index 0000000..60c0165
--- /dev/null
+++ b/scripts/r_faceboost/restorer.py
@@ -0,0 +1,130 @@
+import sys
+import cv2
+import numpy as np
+import torch
+from torchvision.transforms.functional import normalize
+
+try:
+ import torch.cuda as cuda
+except:
+ cuda = None
+
+import comfy.utils
+import folder_paths
+import comfy.model_management as model_management
+
+from scripts.reactor_logger import logger
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from r_chainner import model_loading
+from reactor_utils import (
+ tensor2img,
+ img2tensor,
+ set_ort_session,
+ prepare_cropped_face,
+ normalize_cropped_face
+)
+
+
+if cuda is not None:
+ if cuda.is_available():
+ providers = ["CUDAExecutionProvider"]
+ else:
+ providers = ["CPUExecutionProvider"]
+else:
+ providers = ["CPUExecutionProvider"]
+
+
+def get_restored_face(cropped_face,
+ face_restore_model,
+ face_restore_visibility,
+ codeformer_weight,
+ interpolation: str = "Bicubic"):
+
+ if interpolation == "Bicubic":
+ interpolate = cv2.INTER_CUBIC
+ elif interpolation == "Bilinear":
+ interpolate = cv2.INTER_LINEAR
+ elif interpolation == "Nearest":
+ interpolate = cv2.INTER_NEAREST
+ elif interpolation == "Lanczos":
+ interpolate = cv2.INTER_LANCZOS4
+
+ face_size = 512
+ if "1024" in face_restore_model.lower():
+ face_size = 1024
+ elif "2048" in face_restore_model.lower():
+ face_size = 2048
+
+ scale = face_size / cropped_face.shape[0]
+
+ logger.status(f"Boosting the Face with {face_restore_model} | Face Size is set to {face_size} with Scale Factor = {scale} and '{interpolation}' interpolation")
+
+ cropped_face = cv2.resize(cropped_face, (face_size, face_size), interpolation=interpolate)
+
+ # For upscaling the base 128px face, I found bicubic interpolation to be the best compromise targeting antialiasing
+ # and detail preservation. Nearest is predictably unusable, Linear produces too much aliasing, and Lanczos produces
+ # too many hallucinations and artifacts/fringing.
+
+ model_path = folder_paths.get_full_path("facerestore_models", face_restore_model)
+ device = model_management.get_torch_device()
+
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+
+ try:
+
+ with torch.no_grad():
+
+ if ".onnx" in face_restore_model: # ONNX models
+
+ ort_session = set_ort_session(model_path, providers=providers)
+ ort_session_inputs = {}
+ facerestore_model = ort_session
+
+ for ort_session_input in ort_session.get_inputs():
+ if ort_session_input.name == "input":
+ cropped_face_prep = prepare_cropped_face(cropped_face)
+ ort_session_inputs[ort_session_input.name] = cropped_face_prep
+ if ort_session_input.name == "weight":
+ weight = np.array([1], dtype=np.double)
+ ort_session_inputs[ort_session_input.name] = weight
+
+ output = ort_session.run(None, ort_session_inputs)[0][0]
+ restored_face = normalize_cropped_face(output)
+
+ else: # PTH models
+
+ if "codeformer" in face_restore_model.lower():
+ codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
+ dim_embd=512,
+ codebook_size=1024,
+ n_head=8,
+ n_layers=9,
+ connect_list=["32", "64", "128", "256"],
+ ).to(device)
+ checkpoint = torch.load(model_path)["params_ema"]
+ codeformer_net.load_state_dict(checkpoint)
+ facerestore_model = codeformer_net.eval()
+ else:
+ sd = comfy.utils.load_torch_file(model_path, safe_load=True)
+ facerestore_model = model_loading.load_state_dict(sd).eval()
+ facerestore_model.to(device)
+
+ output = facerestore_model(cropped_face_t, w=codeformer_weight)[
+ 0] if "codeformer" in face_restore_model.lower() else facerestore_model(cropped_face_t)[0]
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+
+ del output
+ torch.cuda.empty_cache()
+
+ except Exception as error:
+
+ print(f"\tFailed inference: {error}", file=sys.stderr)
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+
+ if face_restore_visibility < 1:
+ restored_face = cropped_face * (1 - face_restore_visibility) + restored_face * face_restore_visibility
+
+ restored_face = restored_face.astype("uint8")
+ return restored_face, scale
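+# Call sketch (illustrative, comment only; the model filename is an example):
+#   face, scale = get_restored_face(cropped_face,            # HxWx3 BGR array
+#                                   "codeformer-v0.1.0.pth",
+#                                   face_restore_visibility=1.0,
+#                                   codeformer_weight=0.5,
+#                                   interpolation="Bicubic")
+#   # `face` is the restored face at the model's face_size; `scale` is
+#   # face_size divided by the input crop's size.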
diff --git a/scripts/r_faceboost/swapper.py b/scripts/r_faceboost/swapper.py
new file mode 100644
index 0000000..e0467cf
--- /dev/null
+++ b/scripts/r_faceboost/swapper.py
@@ -0,0 +1,42 @@
+import cv2
+import numpy as np
+
+# The following code is almost entirely copied from INSwapper; the only change here is the interpolation
+# used for the warpAffine call (bicubic, see the note below). Now that the face has been restored, this
+# represents a good compromise whether the restored face needs to be upscaled or downscaled.
+def in_swap(img, bgr_fake, M):
+ target_img = img
+ IM = cv2.invertAffineTransform(M)
+ img_white = np.full((bgr_fake.shape[0], bgr_fake.shape[1]), 255, dtype=np.float32)
+
+ # Note the use of bicubic here; this is functionally the only change from the source code
+ bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0, flags=cv2.INTER_CUBIC)
+
+ img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
+ img_white[img_white > 20] = 255
+ img_mask = img_white
+ mask_h_inds, mask_w_inds = np.where(img_mask == 255)
+ mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
+ mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
+ mask_size = int(np.sqrt(mask_h * mask_w))
+ k = max(mask_size // 10, 10)
+ # k = max(mask_size//20, 6)
+ # k = 6
+ kernel = np.ones((k, k), np.uint8)
+ img_mask = cv2.erode(img_mask, kernel, iterations=1)
+ kernel = np.ones((2, 2), np.uint8)
+ k = max(mask_size // 20, 5)
+ # k = 3
+ # k = 3
+ kernel_size = (k, k)
+ blur_size = tuple(2 * i + 1 for i in kernel_size)
+ img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
+ k = 5
+ kernel_size = (k, k)
+ blur_size = tuple(2 * i + 1 for i in kernel_size)
+ img_mask /= 255
+ # img_mask = fake_diff
+ img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])
+ fake_merged = img_mask * bgr_fake + (1 - img_mask) * target_img.astype(np.float32)
+ fake_merged = fake_merged.astype(np.uint8)
+ return fake_merged
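+# Call sketch (illustrative, comment only): `img` is the full target frame (BGR),
+# `bgr_fake` is the swapped/restored face image, and `M` is the 2x3 affine that
+# aligned the face crop from `img`; in_swap inverts M, warps the face back with
+# bicubic interpolation, and blends it in with an eroded, blurred rectangle mask:
+#   merged = in_swap(target_frame, restored_face, M)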
diff --git a/scripts/r_masking/__init__.py b/scripts/r_masking/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/scripts/r_masking/__init__.py
diff --git a/scripts/r_masking/core.py b/scripts/r_masking/core.py
new file mode 100644
index 0000000..36862e1
--- /dev/null
+++ b/scripts/r_masking/core.py
@@ -0,0 +1,647 @@
+import numpy as np
+import cv2
+import torch
+import torchvision.transforms.functional as TF
+
+import sys as _sys
+from keyword import iskeyword as _iskeyword
+from operator import itemgetter as _itemgetter
+
+from segment_anything import SamPredictor
+
+from comfy import model_management
+
+
+################################################################################
+### namedtuple
+################################################################################
+
+try:
+ from _collections import _tuplegetter
+except ImportError:
+ _tuplegetter = lambda index, doc: property(_itemgetter(index), doc=doc)
+
+def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None):
+ """Returns a new subclass of tuple with named fields.
+
+ >>> Point = namedtuple('Point', ['x', 'y'])
+ >>> Point.__doc__ # docstring for the new class
+ 'Point(x, y)'
+ >>> p = Point(11, y=22) # instantiate with positional args or keywords
+ >>> p[0] + p[1] # indexable like a plain tuple
+ 33
+ >>> x, y = p # unpack like a regular tuple
+ >>> x, y
+ (11, 22)
+ >>> p.x + p.y # fields also accessible by name
+ 33
+ >>> d = p._asdict() # convert to a dictionary
+ >>> d['x']
+ 11
+ >>> Point(**d) # convert from a dictionary
+ Point(x=11, y=22)
+ >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields
+ Point(x=100, y=22)
+
+ """
+
+ # Validate the field names. At the user's option, either generate an error
+ # message or automatically replace the field name with a valid name.
+ if isinstance(field_names, str):
+ field_names = field_names.replace(',', ' ').split()
+ field_names = list(map(str, field_names))
+ typename = _sys.intern(str(typename))
+
+ if rename:
+ seen = set()
+ for index, name in enumerate(field_names):
+ if (not name.isidentifier()
+ or _iskeyword(name)
+ or name.startswith('_')
+ or name in seen):
+ field_names[index] = f'_{index}'
+ seen.add(name)
+
+ for name in [typename] + field_names:
+ if type(name) is not str:
+ raise TypeError('Type names and field names must be strings')
+ if not name.isidentifier():
+ raise ValueError('Type names and field names must be valid '
+ f'identifiers: {name!r}')
+ if _iskeyword(name):
+ raise ValueError('Type names and field names cannot be a '
+ f'keyword: {name!r}')
+
+ seen = set()
+ for name in field_names:
+ if name.startswith('_') and not rename:
+ raise ValueError('Field names cannot start with an underscore: '
+ f'{name!r}')
+ if name in seen:
+ raise ValueError(f'Encountered duplicate field name: {name!r}')
+ seen.add(name)
+
+ field_defaults = {}
+ if defaults is not None:
+ defaults = tuple(defaults)
+ if len(defaults) > len(field_names):
+ raise TypeError('Got more default values than field names')
+ field_defaults = dict(reversed(list(zip(reversed(field_names),
+ reversed(defaults)))))
+
+ # Variables used in the methods and docstrings
+ field_names = tuple(map(_sys.intern, field_names))
+ num_fields = len(field_names)
+ arg_list = ', '.join(field_names)
+ if num_fields == 1:
+ arg_list += ','
+ repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')'
+ tuple_new = tuple.__new__
+ _dict, _tuple, _len, _map, _zip = dict, tuple, len, map, zip
+
+ # Create all the named tuple methods to be added to the class namespace
+
+ namespace = {
+ '_tuple_new': tuple_new,
+ '__builtins__': {},
+ '__name__': f'namedtuple_{typename}',
+ }
+ code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))'
+ __new__ = eval(code, namespace)
+ __new__.__name__ = '__new__'
+ __new__.__doc__ = f'Create new instance of {typename}({arg_list})'
+ if defaults is not None:
+ __new__.__defaults__ = defaults
+
+ @classmethod
+ def _make(cls, iterable):
+ result = tuple_new(cls, iterable)
+ if _len(result) != num_fields:
+ raise TypeError(f'Expected {num_fields} arguments, got {len(result)}')
+ return result
+
+ _make.__func__.__doc__ = (f'Make a new {typename} object from a sequence '
+ 'or iterable')
+
+ def _replace(self, /, **kwds):
+ result = self._make(_map(kwds.pop, field_names, self))
+ if kwds:
+ raise ValueError(f'Got unexpected field names: {list(kwds)!r}')
+ return result
+
+ _replace.__doc__ = (f'Return a new {typename} object replacing specified '
+ 'fields with new values')
+
+ def __repr__(self):
+ 'Return a nicely formatted representation string'
+ return self.__class__.__name__ + repr_fmt % self
+
+ def _asdict(self):
+ 'Return a new dict which maps field names to their values.'
+ return _dict(_zip(self._fields, self))
+
+ def __getnewargs__(self):
+ 'Return self as a plain tuple. Used by copy and pickle.'
+ return _tuple(self)
+
+ # Modify function metadata to help with introspection and debugging
+ for method in (
+ __new__,
+ _make.__func__,
+ _replace,
+ __repr__,
+ _asdict,
+ __getnewargs__,
+ ):
+ method.__qualname__ = f'{typename}.{method.__name__}'
+
+ # Build-up the class namespace dictionary
+ # and use type() to build the result class
+ class_namespace = {
+ '__doc__': f'{typename}({arg_list})',
+ '__slots__': (),
+ '_fields': field_names,
+ '_field_defaults': field_defaults,
+ '__new__': __new__,
+ '_make': _make,
+ '_replace': _replace,
+ '__repr__': __repr__,
+ '_asdict': _asdict,
+ '__getnewargs__': __getnewargs__,
+ '__match_args__': field_names,
+ }
+ for index, name in enumerate(field_names):
+ doc = _sys.intern(f'Alias for field number {index}')
+ class_namespace[name] = _tuplegetter(index, doc)
+
+ result = type(typename, (tuple,), class_namespace)
+
+ # For pickling to work, the __module__ variable needs to be set to the frame
+ # where the named tuple is created. Bypass this step in environments where
+ # sys._getframe is not defined (Jython for example) or sys._getframe is not
+ # defined for arguments greater than 0 (IronPython), or where the user has
+ # specified a particular module.
+ if module is None:
+ try:
+ module = _sys._getframe(1).f_globals.get('__name__', '__main__')
+ except (AttributeError, ValueError):
+ pass
+ if module is not None:
+ result.__module__ = module
+
+ return result
+
+
+SEG = namedtuple("SEG",
+ ['cropped_image', 'cropped_mask', 'confidence', 'crop_region', 'bbox', 'label', 'control_net_wrapper'],
+ defaults=[None])
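+# Note (illustrative): each SEG bundles one detection, e.g.
+#   SEG(cropped_image, cropped_mask, confidence, crop_region, bbox, 'face', None)
+# where only the last field, `control_net_wrapper`, has a default (None).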
+
+def crop_ndarray4(npimg, crop_region):
+ x1 = crop_region[0]
+ y1 = crop_region[1]
+ x2 = crop_region[2]
+ y2 = crop_region[3]
+
+ cropped = npimg[:, y1:y2, x1:x2, :]
+
+ return cropped
+
+crop_tensor4 = crop_ndarray4
+
+def crop_ndarray2(npimg, crop_region):
+ x1 = crop_region[0]
+ y1 = crop_region[1]
+ x2 = crop_region[2]
+ y2 = crop_region[3]
+
+ cropped = npimg[y1:y2, x1:x2]
+
+ return cropped
+
+def crop_image(image, crop_region):
+ return crop_tensor4(image, crop_region)
+
+def normalize_region(limit, startp, size):
+ if startp < 0:
+ new_endp = min(limit, size)
+ new_startp = 0
+ elif startp + size > limit:
+ new_startp = max(0, limit - size)
+ new_endp = limit
+ else:
+ new_startp = startp
+ new_endp = min(limit, startp+size)
+
+ return int(new_startp), int(new_endp)
+
+def make_crop_region(w, h, bbox, crop_factor, crop_min_size=None):
+ x1 = bbox[0]
+ y1 = bbox[1]
+ x2 = bbox[2]
+ y2 = bbox[3]
+
+ bbox_w = x2 - x1
+ bbox_h = y2 - y1
+
+ crop_w = bbox_w * crop_factor
+ crop_h = bbox_h * crop_factor
+
+ if crop_min_size is not None:
+ crop_w = max(crop_min_size, crop_w)
+ crop_h = max(crop_min_size, crop_h)
+
+ kernel_x = x1 + bbox_w / 2
+ kernel_y = y1 + bbox_h / 2
+
+ new_x1 = int(kernel_x - crop_w / 2)
+ new_y1 = int(kernel_y - crop_h / 2)
+
+ # make sure position in (w,h)
+ new_x1, new_x2 = normalize_region(w, new_x1, crop_w)
+ new_y1, new_y2 = normalize_region(h, new_y1, crop_h)
+
+ return [new_x1, new_y1, new_x2, new_y2]
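+# Worked example (comment only): for a 640x480 image, bbox=[100, 120, 200, 220]
+# and crop_factor=2.0, the crop is 200x200 centred on the bbox (clamped to the
+# image borders when needed):
+#   make_crop_region(640, 480, [100, 120, 200, 220], 2.0)  # -> [50, 70, 250, 270]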
+
+def create_segmasks(results):
+ bboxs = results[1]
+ segms = results[2]
+ confidence = results[3]
+
+ results = []
+ for i in range(len(segms)):
+ item = (bboxs[i], segms[i].astype(np.float32), confidence[i])
+ results.append(item)
+ return results
+
+def dilate_masks(segmasks, dilation_factor, iter=1):
+ if dilation_factor == 0:
+ return segmasks
+
+ dilated_masks = []
+ kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)
+
+ kernel = cv2.UMat(kernel)
+
+ for i in range(len(segmasks)):
+ cv2_mask = segmasks[i][1]
+
+ cv2_mask = cv2.UMat(cv2_mask)
+
+ if dilation_factor > 0:
+ dilated_mask = cv2.dilate(cv2_mask, kernel, iter)
+ else:
+ dilated_mask = cv2.erode(cv2_mask, kernel, iter)
+
+ dilated_mask = dilated_mask.get()
+
+ item = (segmasks[i][0], dilated_mask, segmasks[i][2])
+ dilated_masks.append(item)
+
+ return dilated_masks
+
+def is_same_device(a, b):
+ a_device = torch.device(a) if isinstance(a, str) else a
+ b_device = torch.device(b) if isinstance(b, str) else b
+ return a_device.type == b_device.type and a_device.index == b_device.index
+
+class SafeToGPU:
+ def __init__(self, size):
+ self.size = size
+
+ def to_device(self, obj, device):
+ if is_same_device(device, 'cpu'):
+ obj.to(device)
+ else:
+ if is_same_device(obj.device, 'cpu'): # cpu to gpu
+ model_management.free_memory(self.size * 1.3, device)
+ if model_management.get_free_memory(device) > self.size * 1.3:
+ try:
+ obj.to(device)
+ except:
+ print(f"WARN: The model is not moved to the '{device}' due to insufficient memory. [1]")
+ else:
+ print(f"WARN: The model is not moved to the '{device}' due to insufficient memory. [2]")
+
+def center_of_bbox(bbox):
+ w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
+ return bbox[0] + w/2, bbox[1] + h/2
+
+def sam_predict(predictor, points, plabs, bbox, threshold):
+ point_coords = None if not points else np.array(points)
+ point_labels = None if not plabs else np.array(plabs)
+
+ box = np.array([bbox]) if bbox is not None else None
+
+ cur_masks, scores, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box)
+
+ total_masks = []
+
+ selected = False
+ max_score = 0
+ max_mask = None
+ for idx in range(len(scores)):
+ if scores[idx] > max_score:
+ max_score = scores[idx]
+ max_mask = cur_masks[idx]
+
+ if scores[idx] >= threshold:
+ selected = True
+ total_masks.append(cur_masks[idx])
+ else:
+ pass
+
+ if not selected and max_mask is not None:
+ total_masks.append(max_mask)
+
+ return total_masks
+
+def make_2d_mask(mask):
+ if len(mask.shape) == 4:
+ return mask.squeeze(0).squeeze(0)
+
+ elif len(mask.shape) == 3:
+ return mask.squeeze(0)
+
+ return mask
+
+def gen_detection_hints_from_mask_area(x, y, mask, threshold, use_negative):
+ mask = make_2d_mask(mask)
+
+ points = []
+ plabs = []
+
+ # minimum sampling step >= 3
+ y_step = max(3, int(mask.shape[0] / 20))
+ x_step = max(3, int(mask.shape[1] / 20))
+
+ for i in range(0, len(mask), y_step):
+ for j in range(0, len(mask[i]), x_step):
+ if mask[i][j] > threshold:
+ points.append((x + j, y + i))
+ plabs.append(1)
+ elif use_negative and mask[i][j] == 0:
+ points.append((x + j, y + i))
+ plabs.append(0)
+
+ return points, plabs
+
+def gen_negative_hints(w, h, x1, y1, x2, y2):
+ npoints = []
+ nplabs = []
+
+ # minimum sampling step >= 3
+ y_step = max(3, int(w / 20))
+ x_step = max(3, int(h / 20))
+
+ for i in range(10, h - 10, y_step):
+ for j in range(10, w - 10, x_step):
+ if not (x1 - 10 <= j and j <= x2 + 10 and y1 - 10 <= i and i <= y2 + 10):
+ npoints.append((j, i))
+ nplabs.append(0)
+
+ return npoints, nplabs
+
+def generate_detection_hints(image, seg, center, detection_hint, dilated_bbox, mask_hint_threshold, use_small_negative,
+ mask_hint_use_negative):
+ [x1, y1, x2, y2] = dilated_bbox
+
+ points = []
+ plabs = []
+ if detection_hint == "center-1":
+ points.append(center)
+ plabs = [1] # 1 = foreground point, 0 = background point
+
+ elif detection_hint == "horizontal-2":
+ gap = (x2 - x1) / 3
+ points.append((x1 + gap, center[1]))
+ points.append((x1 + gap * 2, center[1]))
+ plabs = [1, 1]
+
+ elif detection_hint == "vertical-2":
+ gap = (y2 - y1) / 3
+ points.append((center[0], y1 + gap))
+ points.append((center[0], y1 + gap * 2))
+ plabs = [1, 1]
+
+ elif detection_hint == "rect-4":
+ x_gap = (x2 - x1) / 3
+ y_gap = (y2 - y1) / 3
+ points.append((x1 + x_gap, center[1]))
+ points.append((x1 + x_gap * 2, center[1]))
+ points.append((center[0], y1 + y_gap))
+ points.append((center[0], y1 + y_gap * 2))
+ plabs = [1, 1, 1, 1]
+
+ elif detection_hint == "diamond-4":
+ x_gap = (x2 - x1) / 3
+ y_gap = (y2 - y1) / 3
+ points.append((x1 + x_gap, y1 + y_gap))
+ points.append((x1 + x_gap * 2, y1 + y_gap))
+ points.append((x1 + x_gap, y1 + y_gap * 2))
+ points.append((x1 + x_gap * 2, y1 + y_gap * 2))
+ plabs = [1, 1, 1, 1]
+
+ elif detection_hint == "mask-point-bbox":
+ center = center_of_bbox(seg.bbox)
+ points.append(center)
+ plabs = [1]
+
+ elif detection_hint == "mask-area":
+ points, plabs = gen_detection_hints_from_mask_area(seg.crop_region[0], seg.crop_region[1],
+ seg.cropped_mask,
+ mask_hint_threshold, use_small_negative)
+
+ if mask_hint_use_negative == "Outter":
+ npoints, nplabs = gen_negative_hints(image.shape[0], image.shape[1],
+ seg.crop_region[0], seg.crop_region[1],
+ seg.crop_region[2], seg.crop_region[3])
+
+ points += npoints
+ plabs += nplabs
+
+ return points, plabs
+
+def combine_masks2(masks):
+ if len(masks) == 0:
+ return None
+ else:
+ initial_cv2_mask = np.array(masks[0]).astype(np.uint8)
+ combined_cv2_mask = initial_cv2_mask
+
+ for i in range(1, len(masks)):
+ cv2_mask = np.array(masks[i]).astype(np.uint8)
+
+ if combined_cv2_mask.shape == cv2_mask.shape:
+ combined_cv2_mask = cv2.bitwise_or(combined_cv2_mask, cv2_mask)
+ else:
+ # do nothing - incompatible mask
+ pass
+
+ mask = torch.from_numpy(combined_cv2_mask)
+ return mask
+
+def dilate_mask(mask, dilation_factor, iter=1):
+ if dilation_factor == 0:
+ return make_2d_mask(mask)
+
+ mask = make_2d_mask(mask)
+
+ kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)
+
+ mask = cv2.UMat(mask)
+ kernel = cv2.UMat(kernel)
+
+ if dilation_factor > 0:
+ result = cv2.dilate(mask, kernel, iter)
+ else:
+ result = cv2.erode(mask, kernel, iter)
+
+ return result.get()
+
+def convert_and_stack_masks(masks):
+ if len(masks) == 0:
+ return None
+
+ mask_tensors = []
+ for mask in masks:
+ mask_array = np.array(mask, dtype=np.uint8)
+ mask_tensor = torch.from_numpy(mask_array)
+ mask_tensors.append(mask_tensor)
+
+ stacked_masks = torch.stack(mask_tensors, dim=0)
+ stacked_masks = stacked_masks.unsqueeze(1)
+
+ return stacked_masks
+
+def merge_and_stack_masks(stacked_masks, group_size):
+ if stacked_masks is None:
+ return None
+
+ num_masks = stacked_masks.size(0)
+ merged_masks = []
+
+ for i in range(0, num_masks, group_size):
+ subset_masks = stacked_masks[i:i + group_size]
+ merged_mask = torch.any(subset_masks, dim=0)
+ merged_masks.append(merged_mask)
+
+ if len(merged_masks) > 0:
+ merged_masks = torch.stack(merged_masks, dim=0)
+
+ return merged_masks
+
+def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation,
+ threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative):
+ if sam_model.is_auto_mode:
+ device = model_management.get_torch_device()
+ sam_model.safe_to.to_device(sam_model, device=device)
+
+ try:
+ predictor = SamPredictor(sam_model)
+ image = np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
+ predictor.set_image(image, "RGB")
+
+ total_masks = []
+
+ use_small_negative = mask_hint_use_negative == "Small"
+
+ # seg_shape = segs[0]
+ segs = segs[1]
+ if detection_hint == "mask-points":
+ points = []
+ plabs = []
+
+ for i in range(len(segs)):
+ bbox = segs[i].bbox
+ center = center_of_bbox(bbox)
+ points.append(center)
+
+ # small point is background, big point is foreground
+ if use_small_negative and bbox[2] - bbox[0] < 10:
+ plabs.append(0)
+ else:
+ plabs.append(1)
+
+ detected_masks = sam_predict(predictor, points, plabs, None, threshold)
+ total_masks += detected_masks
+
+ else:
+ for i in range(len(segs)):
+ bbox = segs[i].bbox
+ center = center_of_bbox(bbox)
+ x1 = max(bbox[0] - bbox_expansion, 0)
+ y1 = max(bbox[1] - bbox_expansion, 0)
+ x2 = min(bbox[2] + bbox_expansion, image.shape[1])
+ y2 = min(bbox[3] + bbox_expansion, image.shape[0])
+
+ dilated_bbox = [x1, y1, x2, y2]
+
+ points, plabs = generate_detection_hints(image, segs[i], center, detection_hint, dilated_bbox,
+ mask_hint_threshold, use_small_negative,
+ mask_hint_use_negative)
+
+ detected_masks = sam_predict(predictor, points, plabs, dilated_bbox, threshold)
+
+ total_masks += detected_masks
+
+ # merge every collected masks
+ mask = combine_masks2(total_masks)
+
+ finally:
+ if sam_model.is_auto_mode:
+ sam_model.cpu()
+
+ pass
+
+ mask_working_device = torch.device("cpu")
+
+ if mask is not None:
+ mask = mask.float()
+ mask = dilate_mask(mask.cpu().numpy(), dilation)
+ mask = torch.from_numpy(mask)
+ mask = mask.to(device=mask_working_device)
+ else:
+ # Extract the height and width of the image
+ height, width, _ = image.shape
+ mask = torch.zeros(
+ (height, width), dtype=torch.float32, device=mask_working_device
+ ) # empty mask
+
+ stacked_masks = convert_and_stack_masks(total_masks)
+
+ return (mask, merge_and_stack_masks(stacked_masks, group_size=3))
+
+def tensor2mask(t: torch.Tensor) -> torch.Tensor:
+ size = t.size()
+ if (len(size) < 4):
+ return t
+ if size[3] == 1:
+ return t[:,:,:,0]
+ elif size[3] == 4:
+ # Use the alpha channel as the mask unless alpha is 1 everywhere, in which case fall back to the RGB (grayscale) behavior
+ if torch.min(t[:, :, :, 3]).item() != 1.:
+ return t[:,:,:,3]
+ return TF.rgb_to_grayscale(tensor2rgb(t).permute(0,3,1,2), num_output_channels=1)[:,0,:,:]
+
+def tensor2rgb(t: torch.Tensor) -> torch.Tensor:
+ size = t.size()
+ if (len(size) < 4):
+ return t.unsqueeze(3).repeat(1, 1, 1, 3)
+ if size[3] == 1:
+ return t.repeat(1, 1, 1, 3)
+ elif size[3] == 4:
+ return t[:, :, :, :3]
+ else:
+ return t
+
+def tensor2rgba(t: torch.Tensor) -> torch.Tensor:
+ size = t.size()
+ if (len(size) < 4):
+ return t.unsqueeze(3).repeat(1, 1, 1, 4)
+ elif size[3] == 1:
+ return t.repeat(1, 1, 1, 4)
+ elif size[3] == 3:
+ alpha_tensor = torch.ones((size[0], size[1], size[2], 1))
+ return torch.cat((t, alpha_tensor), dim=3)
+ else:
+ return t
diff --git a/scripts/r_masking/segs.py b/scripts/r_masking/segs.py
new file mode 100644
index 0000000..60c22d7
--- /dev/null
+++ b/scripts/r_masking/segs.py
@@ -0,0 +1,22 @@
+def filter(segs, labels):
+ labels = set([label.strip() for label in labels])
+
+ if 'all' in labels:
+ return (segs, (segs[0], []), )
+ else:
+ res_segs = []
+ remained_segs = []
+
+ for x in segs[1]:
+ if x.label in labels:
+ res_segs.append(x)
+ elif 'eyes' in labels and x.label in ['left_eye', 'right_eye']:
+ res_segs.append(x)
+ elif 'eyebrows' in labels and x.label in ['left_eyebrow', 'right_eyebrow']:
+ res_segs.append(x)
+ elif 'pupils' in labels and x.label in ['left_pupil', 'right_pupil']:
+ res_segs.append(x)
+ else:
+ remained_segs.append(x)
+
+ return ((segs[0], res_segs), (segs[0], remained_segs), )
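+# Call sketch (illustrative; 'face' is an example label produced by the detector):
+#   kept, remained = filter(segs, ['face'])
+#   # `kept` and `remained` are both (shape, seg_list) pairs in the same format
+#   # as the input `segs`.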
diff --git a/scripts/r_masking/subcore.py b/scripts/r_masking/subcore.py
new file mode 100644
index 0000000..cf7bf7d
--- /dev/null
+++ b/scripts/r_masking/subcore.py
@@ -0,0 +1,117 @@
+import numpy as np
+import cv2
+from PIL import Image
+
+import scripts.r_masking.core as core
+from reactor_utils import tensor_to_pil
+
+try:
+ from ultralytics import YOLO
+except Exception as e:
+ print(e)
+
+
+def load_yolo(model_path: str):
+ try:
+ return YOLO(model_path)
+ except ModuleNotFoundError:
+ # https://github.com/ultralytics/ultralytics/issues/3856
+ YOLO("yolov8n.pt")
+ return YOLO(model_path)
+
+def inference_bbox(
+ model,
+ image: Image.Image,
+ confidence: float = 0.3,
+ device: str = "",
+):
+ pred = model(image, conf=confidence, device=device)
+
+ bboxes = pred[0].boxes.xyxy.cpu().numpy()
+ cv2_image = np.array(image)
+ if len(cv2_image.shape) == 3:
+ cv2_image = cv2_image[:, :, ::-1].copy() # Convert RGB to BGR for cv2 processing
+ else:
+ # Handle the grayscale image here
+ # For example, you might want to convert it to a 3-channel grayscale image for consistency:
+ cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_GRAY2BGR)
+ cv2_gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
+
+ segms = []
+ for x0, y0, x1, y1 in bboxes:
+ cv2_mask = np.zeros(cv2_gray.shape, np.uint8)
+ cv2.rectangle(cv2_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1)
+ cv2_mask_bool = cv2_mask.astype(bool)
+ segms.append(cv2_mask_bool)
+
+ n, m = bboxes.shape
+ if n == 0:
+ return [[], [], [], []]
+
+ results = [[], [], [], []]
+ for i in range(len(bboxes)):
+ results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())])
+ results[1].append(bboxes[i])
+ results[2].append(segms[i])
+ results[3].append(pred[0].boxes[i].conf.cpu().numpy())
+
+ return results
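+# Result layout (illustrative, comment only): inference_bbox returns four
+# parallel lists, e.g. for two detections:
+#   results[0]  # class labels, e.g. ['face', 'face'] (names depend on the model)
+#   results[1]  # xyxy bounding boxes
+#   results[2]  # boolean rectangle masks, one per box
+#   results[3]  # confidence scores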
+
+
+class UltraBBoxDetector:
+ bbox_model = None
+
+ def __init__(self, bbox_model):
+ self.bbox_model = bbox_model
+
+ def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
+ drop_size = max(drop_size, 1)
+ detected_results = inference_bbox(self.bbox_model, tensor_to_pil(image), threshold)
+ segmasks = core.create_segmasks(detected_results)
+
+ if dilation > 0:
+ segmasks = core.dilate_masks(segmasks, dilation)
+
+ items = []
+ h = image.shape[1]
+ w = image.shape[2]
+
+ for x, label in zip(segmasks, detected_results[0]):
+ item_bbox = x[0]
+ item_mask = x[1]
+
+ y1, x1, y2, x2 = item_bbox
+
+ if x2 - x1 > drop_size and y2 - y1 > drop_size: # minimum dimension must be (2,2) to avoid squeeze issue
+ crop_region = core.make_crop_region(w, h, item_bbox, crop_factor)
+
+ if detailer_hook is not None:
+ crop_region = detailer_hook.post_crop_region(w, h, item_bbox, crop_region)
+
+ cropped_image = core.crop_image(image, crop_region)
+ cropped_mask = core.crop_ndarray2(item_mask, crop_region)
+ confidence = x[2]
+ # bbox_size = (item_bbox[2]-item_bbox[0],item_bbox[3]-item_bbox[1]) # (w,h)
+
+ item = core.SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, label, None)
+
+ items.append(item)
+
+ shape = image.shape[1], image.shape[2]
+ segs = shape, items
+
+ if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
+ segs = detailer_hook.post_detection(segs)
+
+ return segs
+
+ def detect_combined(self, image, threshold, dilation):
+ detected_results = inference_bbox(self.bbox_model, core.tensor2pil(image), threshold)
+ segmasks = core.create_segmasks(detected_results)
+ if dilation > 0:
+ segmasks = core.dilate_masks(segmasks, dilation)
+
+ return core.combine_masks(segmasks)
+
+ def setAux(self, x):
+ pass
diff --git a/scripts/reactor_faceswap.py b/scripts/reactor_faceswap.py
new file mode 100644
index 0000000..7e6c03e
--- /dev/null
+++ b/scripts/reactor_faceswap.py
@@ -0,0 +1,185 @@
+import os, glob
+
+from PIL import Image
+
+import modules.scripts as scripts
+# from modules.upscaler import Upscaler, UpscalerData
+from modules import scripts, scripts_postprocessing
+from modules.processing import (
+ StableDiffusionProcessing,
+ StableDiffusionProcessingImg2Img,
+)
+from modules.shared import state
+from scripts.reactor_logger import logger
+from scripts.reactor_swapper import (
+ swap_face,
+ swap_face_many,
+ get_current_faces_model,
+ analyze_faces,
+ half_det_size,
+ providers
+)
+import folder_paths
+import comfy.model_management as model_management
+
+
+def get_models():
+ models_path = os.path.join(folder_paths.models_dir,"insightface/*")
+ models = glob.glob(models_path)
+ models = [x for x in models if x.endswith(".onnx") or x.endswith(".pth")]
+ return models
+
+
+class FaceSwapScript(scripts.Script):
+
+ def process(
+ self,
+ p: StableDiffusionProcessing,
+ img,
+ enable,
+ source_faces_index,
+ faces_index,
+ model,
+ swap_in_source,
+ swap_in_generated,
+ gender_source,
+ gender_target,
+ face_model,
+ faces_order,
+ face_boost_enabled,
+ face_restore_model,
+ face_restore_visibility,
+ codeformer_weight,
+ interpolation,
+ ):
+ self.enable = enable
+ if self.enable:
+
+ self.source = img
+ self.swap_in_generated = swap_in_generated
+ self.gender_source = gender_source
+ self.gender_target = gender_target
+ self.model = model
+ self.face_model = face_model
+ self.faces_order = faces_order
+ self.face_boost_enabled = face_boost_enabled
+ self.face_restore_model = face_restore_model
+ self.face_restore_visibility = face_restore_visibility
+ self.codeformer_weight = codeformer_weight
+ self.interpolation = interpolation
+ self.source_faces_index = [
+ int(x) for x in source_faces_index.strip(",").split(",") if x.isnumeric()
+ ]
+ self.faces_index = [
+ int(x) for x in faces_index.strip(",").split(",") if x.isnumeric()
+ ]
+ if len(self.source_faces_index) == 0:
+ self.source_faces_index = [0]
+ if len(self.faces_index) == 0:
+ self.faces_index = [0]
+
+ if self.gender_source is None or self.gender_source == "no":
+ self.gender_source = 0
+ elif self.gender_source == "female":
+ self.gender_source = 1
+ elif self.gender_source == "male":
+ self.gender_source = 2
+
+ if self.gender_target is None or self.gender_target == "no":
+ self.gender_target = 0
+ elif self.gender_target == "female":
+ self.gender_target = 1
+ elif self.gender_target == "male":
+ self.gender_target = 2
+
+ # if self.source is not None:
+ if isinstance(p, StableDiffusionProcessingImg2Img) and swap_in_source:
+ logger.status(f"Working: source face index %s, target face index %s", self.source_faces_index, self.faces_index)
+
+ if len(p.init_images) == 1:
+
+ result = swap_face(
+ self.source,
+ p.init_images[0],
+ source_faces_index=self.source_faces_index,
+ faces_index=self.faces_index,
+ model=self.model,
+ gender_source=self.gender_source,
+ gender_target=self.gender_target,
+ face_model=self.face_model,
+ faces_order=self.faces_order,
+ face_boost_enabled=self.face_boost_enabled,
+ face_restore_model=self.face_restore_model,
+ face_restore_visibility=self.face_restore_visibility,
+ codeformer_weight=self.codeformer_weight,
+ interpolation=self.interpolation,
+ )
+ p.init_images[0] = result
+
+ # for i in range(len(p.init_images)):
+ # if state.interrupted or model_management.processing_interrupted():
+ # logger.status("Interrupted by User")
+ # break
+ # if len(p.init_images) > 1:
+ # logger.status(f"Swap in %s", i)
+ # result = swap_face(
+ # self.source,
+ # p.init_images[i],
+ # source_faces_index=self.source_faces_index,
+ # faces_index=self.faces_index,
+ # model=self.model,
+ # gender_source=self.gender_source,
+ # gender_target=self.gender_target,
+ # face_model=self.face_model,
+ # )
+ # p.init_images[i] = result
+
+ elif len(p.init_images) > 1:
+ result = swap_face_many(
+ self.source,
+ p.init_images,
+ source_faces_index=self.source_faces_index,
+ faces_index=self.faces_index,
+ model=self.model,
+ gender_source=self.gender_source,
+ gender_target=self.gender_target,
+ face_model=self.face_model,
+ faces_order=self.faces_order,
+ face_boost_enabled=self.face_boost_enabled,
+ face_restore_model=self.face_restore_model,
+ face_restore_visibility=self.face_restore_visibility,
+ codeformer_weight=self.codeformer_weight,
+ interpolation=self.interpolation,
+ )
+ p.init_images = result
+
+ logger.status("--Done!--")
+ # else:
+ # logger.error(f"Please provide a source face")
+
+ def postprocess_batch(self, p, *args, **kwargs):
+ if self.enable:
+ images = kwargs["images"]
+
+ def postprocess_image(self, p, script_pp: scripts.PostprocessImageArgs, *args):
+ if self.enable and self.swap_in_generated:
+ if self.source is not None:
+ logger.status("Working: source face index %s, target face index %s", self.source_faces_index, self.faces_index)
+ image: Image.Image = script_pp.image
+ # pass the same swap and restore settings captured in process()
+ result = swap_face(
+ self.source,
+ image,
+ source_faces_index=self.source_faces_index,
+ faces_index=self.faces_index,
+ model=self.model,
+ gender_source=self.gender_source,
+ gender_target=self.gender_target,
+ face_model=self.face_model,
+ faces_order=self.faces_order,
+ face_boost_enabled=self.face_boost_enabled,
+ face_restore_model=self.face_restore_model,
+ face_restore_visibility=self.face_restore_visibility,
+ codeformer_weight=self.codeformer_weight,
+ interpolation=self.interpolation,
+ )
+ try:
+ pp = scripts_postprocessing.PostprocessedImage(result)
+ pp.info = {}
+ p.extra_generation_params.update(pp.info)
+ script_pp.image = pp.image
+ except Exception:
+ logger.error("Cannot create a result image")
diff --git a/scripts/reactor_logger.py b/scripts/reactor_logger.py new file mode 100644 index 0000000..f64e433 --- /dev/null +++ b/scripts/reactor_logger.py @@ -0,0 +1,47 @@ +import logging
+import copy
+import sys
+
+from modules import shared
+from reactor_utils import addLoggingLevel
+
+
+class ColoredFormatter(logging.Formatter):
+ COLORS = {
+ "DEBUG": "\033[0;36m", # CYAN
+ "STATUS": "\033[38;5;173m", # Calm ORANGE
+ "INFO": "\033[0;32m", # GREEN
+ "WARNING": "\033[0;33m", # YELLOW
+ "ERROR": "\033[0;31m", # RED
+ "CRITICAL": "\033[0;37;41m", # WHITE ON RED
+ "RESET": "\033[0m", # RESET COLOR
+ }
+
+ def format(self, record):
+ colored_record = copy.copy(record)
+ levelname = colored_record.levelname
+ seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+ colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+ return super().format(colored_record)
+
+
+# Create a new logger
+logger = logging.getLogger("ReActor")
+logger.propagate = False
+
+# Add Custom Level
+# logging.addLevelName(logging.INFO, "STATUS")
+addLoggingLevel("STATUS", logging.INFO + 5)
+
+# Add handler if we don't have one.
+if not logger.handlers:
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setFormatter(
+ ColoredFormatter("[%(name)s] %(asctime)s - %(levelname)s - %(message)s",datefmt="%H:%M:%S")
+ )
+ logger.addHandler(handler)
+
+# Configure logger
+loglevel_string = getattr(shared.cmd_opts, "reactor_loglevel", "INFO")
+loglevel = getattr(logging, loglevel_string.upper(), logging.INFO)
+logger.setLevel(loglevel)
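+
+# Usage note: the custom STATUS level sits between INFO (20) and WARNING (30),
+# so setting the logger level to STATUS keeps logger.status(...) messages
+# visible while hiding plain INFO output.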
diff --git a/scripts/reactor_swapper.py b/scripts/reactor_swapper.py new file mode 100644 index 0000000..6db5dfc --- /dev/null +++ b/scripts/reactor_swapper.py @@ -0,0 +1,572 @@ +import os
+import shutil
+from typing import List, Union
+
+import cv2
+import numpy as np
+from PIL import Image
+
+import insightface
+from insightface.app.common import Face
+# try:
+# import torch.cuda as cuda
+# except:
+# cuda = None
+import torch
+
+import folder_paths
+import comfy.model_management as model_management
+from modules.shared import state
+
+from scripts.reactor_logger import logger
+from reactor_utils import (
+ move_path,
+ get_image_md5hash,
+)
+from scripts.r_faceboost import swapper, restorer
+
+import warnings
+
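+# Some NumPy builds no longer expose the np.warnings alias; attaching the
+# warnings module here keeps the filterwarnings call below working on any version.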
+np.warnings = warnings
+np.warnings.filterwarnings('ignore')
+
+# PROVIDERS
+try:
+ if torch.cuda.is_available():
+ providers = ["CUDAExecutionProvider"]
+ elif torch.backends.mps.is_available():
+ providers = ["CoreMLExecutionProvider"]
+ elif hasattr(torch,'dml') or hasattr(torch,'privateuseone'):
+ providers = ["ROCMExecutionProvider"]
+ else:
+ providers = ["CPUExecutionProvider"]
+except Exception as e:
+ logger.debug(f"ExecutionProviderError: {e}.\nEP is set to CPU.")
+ providers = ["CPUExecutionProvider"]
+# if cuda is not None:
+# if cuda.is_available():
+# providers = ["CUDAExecutionProvider"]
+# else:
+# providers = ["CPUExecutionProvider"]
+# else:
+# providers = ["CPUExecutionProvider"]
+
+models_path_old = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
+insightface_path_old = os.path.join(models_path_old, "insightface")
+insightface_models_path_old = os.path.join(insightface_path_old, "models")
+
+models_path = folder_paths.models_dir
+insightface_path = os.path.join(models_path, "insightface")
+insightface_models_path = os.path.join(insightface_path, "models")
+
+if os.path.exists(models_path_old):
+ move_path(insightface_models_path_old, insightface_models_path)
+ move_path(insightface_path_old, insightface_path)
+ move_path(models_path_old, models_path)
+if os.path.exists(insightface_path) and os.path.exists(insightface_path_old):
+ shutil.rmtree(insightface_path_old)
+ shutil.rmtree(models_path_old)
+
+
+FS_MODEL = None
+CURRENT_FS_MODEL_PATH = None
+
+ANALYSIS_MODELS = {
+ "640": None,
+ "320": None,
+}
+
+SOURCE_FACES = None
+SOURCE_IMAGE_HASH = None
+TARGET_FACES = None
+TARGET_IMAGE_HASH = None
+TARGET_FACES_LIST = []
+TARGET_IMAGE_LIST_HASH = []
+
+def unload_model(model):
+ if model is not None:
+ # check if model has unload method
+ # if "unload" in model:
+ # model.unload()
+ # if "model_unload" in model:
+ # model.model_unload()
+ del model
+ return None
+
+def unload_all_models():
+ global FS_MODEL, CURRENT_FS_MODEL_PATH
+ FS_MODEL = unload_model(FS_MODEL)
+ ANALYSIS_MODELS["320"] = unload_model(ANALYSIS_MODELS["320"])
+ ANALYSIS_MODELS["640"] = unload_model(ANALYSIS_MODELS["640"])
+
+def get_current_faces_model():
+ global SOURCE_FACES
+ return SOURCE_FACES
+
+def getAnalysisModel(det_size = (640, 640)):
+ global ANALYSIS_MODELS
+ ANALYSIS_MODEL = ANALYSIS_MODELS[str(det_size[0])]
+ if ANALYSIS_MODEL is None:
+ ANALYSIS_MODEL = insightface.app.FaceAnalysis(
+ name="buffalo_l", providers=providers, root=insightface_path
+ )
+ ANALYSIS_MODEL.prepare(ctx_id=0, det_size=det_size)
+ ANALYSIS_MODELS[str(det_size[0])] = ANALYSIS_MODEL
+ return ANALYSIS_MODEL
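+
+# Sketch of the caching behaviour: analysis models are cached per detection
+# size, so repeated calls with the same det_size reuse the prepared model, e.g.
+#   getAnalysisModel()            # prepares and caches the (640, 640) detector
+#   getAnalysisModel((320, 320))  # prepares and caches the (320, 320) detector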
+
+def getFaceSwapModel(model_path: str):
+ global FS_MODEL, CURRENT_FS_MODEL_PATH
+ if FS_MODEL is None or CURRENT_FS_MODEL_PATH is None or CURRENT_FS_MODEL_PATH != model_path:
+ CURRENT_FS_MODEL_PATH = model_path
+ FS_MODEL = unload_model(FS_MODEL)
+ FS_MODEL = insightface.model_zoo.get_model(model_path, providers=providers)
+
+ return FS_MODEL
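+
+# The swapper model is cached globally and only reloaded when a different
+# model_path is requested, avoiding a model re-load on every swap.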
+
+
+def sort_by_order(face, order: str):
+ if order == "left-right":
+ return sorted(face, key=lambda x: x.bbox[0])
+ if order == "right-left":
+ return sorted(face, key=lambda x: x.bbox[0], reverse = True)
+ if order == "top-bottom":
+ return sorted(face, key=lambda x: x.bbox[1])
+ if order == "bottom-top":
+ return sorted(face, key=lambda x: x.bbox[1], reverse = True)
+ if order == "small-large":
+ return sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))
+ # if order == "large-small":
+ # return sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True)
+ # by default "large-small":
+ return sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True)
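+
+# Example: given faces A (bbox x0=10, area 900) and B (bbox x0=50, area 2500),
+# sort_by_order([A, B], "left-right") keeps [A, B], while the default
+# "large-small" ordering returns [B, A].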
+
+def get_face_gender(
+ face,
+ face_index,
+ gender_condition,
+ operated: str,
+ order: str,
+):
+ gender = [
+ x.sex
+ for x in face
+ ]
+ gender.reverse()
+ # If index is outside of bounds, return None, avoid exception
+ if face_index >= len(gender):
+ logger.status("Requested face index (%s) is out of bounds (max available index is %s)", face_index, len(gender) - 1)
+ return None, 0
+ face_gender = gender[face_index]
+ logger.status("%s Face %s: Detected Gender -%s-", operated, face_index, face_gender)
+ if (gender_condition == 1 and face_gender == "F") or (gender_condition == 2 and face_gender == "M"):
+ logger.status("OK - Detected Gender matches Condition")
+ try:
+ faces_sorted = sort_by_order(face, order)
+ return faces_sorted[face_index], 0
+ # return sorted(face, key=lambda x: x.bbox[0])[face_index], 0
+ except IndexError:
+ return None, 0
+ else:
+ logger.status("WRONG - Detected Gender doesn't match Condition")
+ faces_sorted = sort_by_order(face, order)
+ return faces_sorted[face_index], 1
+ # return sorted(face, key=lambda x: x.bbox[0])[face_index], 1
+
+def half_det_size(det_size):
+ logger.status("Trying to halve 'det_size' parameter")
+ return (det_size[0] // 2, det_size[1] // 2)
+
+def analyze_faces(img_data: np.ndarray, det_size=(640, 640)):
+ face_analyser = getAnalysisModel(det_size)
+ faces = face_analyser.get(img_data)
+
+ # Try halving det_size if no faces are found
+ if len(faces) == 0 and det_size[0] > 320 and det_size[1] > 320:
+ det_size_half = half_det_size(det_size)
+ return analyze_faces(img_data, det_size_half)
+
+ return faces
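+
+# Detection sketch: analyze_faces first runs the (640, 640) detector; if no
+# faces are found it retries once at half size (320, 320), which can help on
+# close-up portraits where the default size finds nothing.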
+
+def get_face_single(img_data: np.ndarray, face, face_index=0, det_size=(640, 640), gender_source=0, gender_target=0, order="large-small"):
+
+ buffalo_path = os.path.join(insightface_models_path, "buffalo_l.zip")
+ if os.path.exists(buffalo_path):
+ os.remove(buffalo_path)
+
+ if gender_source != 0:
+ if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
+ det_size_half = half_det_size(det_size)
+ return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target, order)
+ return get_face_gender(face,face_index,gender_source,"Source", order)
+
+ if gender_target != 0:
+ if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
+ det_size_half = half_det_size(det_size)
+ return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target, order)
+ return get_face_gender(face,face_index,gender_target,"Target", order)
+
+ if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
+ det_size_half = half_det_size(det_size)
+ return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target, order)
+
+ try:
+ faces_sorted = sort_by_order(face, order)
+ return faces_sorted[face_index], 0
+ # return sorted(face, key=lambda x: x.bbox[0])[face_index], 0
+ except IndexError:
+ return None, 0
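+
+# get_face_single returns a (face, wrong_gender) pair: wrong_gender is 1 when a
+# face exists at the requested index but its detected sex does not match the
+# requested gender filter, and 0 otherwise (the face element is None when the
+# index cannot be resolved).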
+
+
+def swap_face(
+ source_img: Union[Image.Image, None],
+ target_img: Image.Image,
+ model: Union[str, None] = None,
+ source_faces_index: List[int] = [0],
+ faces_index: List[int] = [0],
+ gender_source: int = 0,
+ gender_target: int = 0,
+ face_model: Union[Face, None] = None,
+ faces_order: List = ["large-small", "large-small"],
+ face_boost_enabled: bool = False,
+ face_restore_model = None,
+ face_restore_visibility: int = 1,
+ codeformer_weight: float = 0.5,
+ interpolation: str = "Bicubic",
+):
+ global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH
+ result_image = target_img
+
+ if model is not None:
+
+ if isinstance(source_img, str): # source_img is a base64 string
+ import base64, io
+ if 'base64,' in source_img: # check if the base64 string has a data URL scheme
+ # split the base64 string to get the actual base64 encoded image data
+ base64_data = source_img.split('base64,')[-1]
+ # decode base64 string to bytes
+ img_bytes = base64.b64decode(base64_data)
+ else:
+ # if no data URL scheme, just decode
+ img_bytes = base64.b64decode(source_img)
+
+ source_img = Image.open(io.BytesIO(img_bytes))
+
+ target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)
+
+ if source_img is not None:
+
+ source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
+
+ source_image_md5hash = get_image_md5hash(source_img)
+
+ if SOURCE_IMAGE_HASH is None:
+ SOURCE_IMAGE_HASH = source_image_md5hash
+ source_image_same = False
+ else:
+ source_image_same = SOURCE_IMAGE_HASH == source_image_md5hash
+ if not source_image_same:
+ SOURCE_IMAGE_HASH = source_image_md5hash
+
+ logger.info("Source Image MD5 Hash = %s", SOURCE_IMAGE_HASH)
+ logger.info("Source Image the Same? %s", source_image_same)
+
+ if SOURCE_FACES is None or not source_image_same:
+ logger.status("Analyzing Source Image...")
+ source_faces = analyze_faces(source_img)
+ SOURCE_FACES = source_faces
+ elif source_image_same:
+ logger.status("Using Hashed Source Face(s) Model...")
+ source_faces = SOURCE_FACES
+
+ elif face_model is not None:
+
+ source_faces_index = [0]
+ logger.status("Using Loaded Source Face Model...")
+ source_face_model = [face_model]
+ source_faces = source_face_model
+
+ else:
+ logger.error("Cannot detect any Source")
+ source_faces = None
+
+ if source_faces is not None:
+
+ target_image_md5hash = get_image_md5hash(target_img)
+
+ if TARGET_IMAGE_HASH is None:
+ TARGET_IMAGE_HASH = target_image_md5hash
+ target_image_same = False
+ else:
+ target_image_same = TARGET_IMAGE_HASH == target_image_md5hash
+ if not target_image_same:
+ TARGET_IMAGE_HASH = target_image_md5hash
+
+ logger.info("Target Image MD5 Hash = %s", TARGET_IMAGE_HASH)
+ logger.info("Target Image the Same? %s", target_image_same)
+
+ if TARGET_FACES is None or not target_image_same:
+ logger.status("Analyzing Target Image...")
+ target_faces = analyze_faces(target_img)
+ TARGET_FACES = target_faces
+ elif target_image_same:
+ logger.status("Using Hashed Target Face(s) Model...")
+ target_faces = TARGET_FACES
+
+ # No use in trying to swap faces if no faces are found, enhancement
+ if len(target_faces) == 0:
+ logger.status("Cannot detect any Target, skipping swapping...")
+ return result_image
+
+ if source_img is not None:
+ # separated management of wrong_gender between source and target, enhancement
+ source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[0], gender_source=gender_source, order=faces_order[1])
+ else:
+ # source_face = sorted(source_faces, key=lambda x: x.bbox[0])[source_faces_index[0]]
+ source_face = sorted(source_faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True)[source_faces_index[0]]
+ src_wrong_gender = 0
+
+ if len(source_faces_index) != 0 and len(source_faces_index) != 1 and len(source_faces_index) != len(faces_index):
+ logger.status(f'Source Faces must have no entries (default=0), one entry, or same number of entries as target faces.')
+ elif source_face is not None:
+ result = target_img
+ model_path = os.path.join(insightface_path, model)
+ face_swapper = getFaceSwapModel(model_path)
+
+ source_face_idx = 0
+
+ for face_num in faces_index:
+ # No use in trying to swap faces if no further faces are found, enhancement
+ if face_num >= len(target_faces):
+ logger.status("Checked all existing target faces, skipping swapping...")
+ break
+
+ if len(source_faces_index) > 1 and source_face_idx > 0:
+ source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[source_face_idx], gender_source=gender_source, order=faces_order[1])
+ source_face_idx += 1
+
+ if source_face is not None and src_wrong_gender == 0:
+ target_face, wrong_gender = get_face_single(target_img, target_faces, face_index=face_num, gender_target=gender_target, order=faces_order[0])
+ if target_face is not None and wrong_gender == 0:
+ logger.status(f"Swapping...")
+ if face_boost_enabled:
+ logger.status(f"Face Boost is enabled")
+ bgr_fake, M = face_swapper.get(result, target_face, source_face, paste_back=False)
+ bgr_fake, scale = restorer.get_restored_face(bgr_fake, face_restore_model, face_restore_visibility, codeformer_weight, interpolation)
+ M *= scale
+ result = swapper.in_swap(target_img, bgr_fake, M)
+ else:
+ # logger.status(f"Swapping as-is")
+ result = face_swapper.get(result, target_face, source_face)
+ elif wrong_gender == 1:
+ wrong_gender = 0
+ # Keep searching for other faces if wrong gender is detected, enhancement
+ #if source_face_idx == len(source_faces_index):
+ # result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
+ # return result_image
+ logger.status("Wrong target gender detected")
+ continue
+ else:
+ logger.status(f"No target face found for {face_num}")
+ elif src_wrong_gender == 1:
+ src_wrong_gender = 0
+ # Keep searching for other faces if wrong gender is detected, enhancement
+ #if source_face_idx == len(source_faces_index):
+ # result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
+ # return result_image
+ logger.status("Wrong source gender detected")
+ continue
+ else:
+ logger.status(f"No source face found for face number {source_face_idx}.")
+
+ result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
+
+ else:
+ logger.status("No source face(s) in the provided Index")
+ else:
+ logger.status("No source face(s) found")
+ return result_image
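+
+# Minimal usage sketch for swap_face (hypothetical file names; assumes an
+# inswapper model such as "inswapper_128.onnx" is available under models/insightface):
+#   src = Image.open("source.jpg")
+#   dst = Image.open("target.jpg")
+#   out = swap_face(src, dst, model="inswapper_128.onnx")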
+
+def swap_face_many(
+ source_img: Union[Image.Image, None],
+ target_imgs: List[Image.Image],
+ model: Union[str, None] = None,
+ source_faces_index: List[int] = [0],
+ faces_index: List[int] = [0],
+ gender_source: int = 0,
+ gender_target: int = 0,
+ face_model: Union[Face, None] = None,
+ faces_order: List = ["large-small", "large-small"],
+ face_boost_enabled: bool = False,
+ face_restore_model = None,
+ face_restore_visibility: int = 1,
+ codeformer_weight: float = 0.5,
+ interpolation: str = "Bicubic",
+):
+ global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH, TARGET_FACES_LIST, TARGET_IMAGE_LIST_HASH
+ result_images = target_imgs
+
+ if model is not None:
+
+ if isinstance(source_img, str): # source_img is a base64 string
+ import base64, io
+ if 'base64,' in source_img: # check if the base64 string has a data URL scheme
+ # split the base64 string to get the actual base64 encoded image data
+ base64_data = source_img.split('base64,')[-1]
+ # decode base64 string to bytes
+ img_bytes = base64.b64decode(base64_data)
+ else:
+ # if no data URL scheme, just decode
+ img_bytes = base64.b64decode(source_img)
+
+ source_img = Image.open(io.BytesIO(img_bytes))
+
+ target_imgs = [cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR) for target_img in target_imgs]
+
+ if source_img is not None:
+
+ source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
+
+ source_image_md5hash = get_image_md5hash(source_img)
+
+ if SOURCE_IMAGE_HASH is None:
+ SOURCE_IMAGE_HASH = source_image_md5hash
+ source_image_same = False
+ else:
+ source_image_same = SOURCE_IMAGE_HASH == source_image_md5hash
+ if not source_image_same:
+ SOURCE_IMAGE_HASH = source_image_md5hash
+
+ logger.info("Source Image MD5 Hash = %s", SOURCE_IMAGE_HASH)
+ logger.info("Source Image the Same? %s", source_image_same)
+
+ if SOURCE_FACES is None or not source_image_same:
+ logger.status("Analyzing Source Image...")
+ source_faces = analyze_faces(source_img)
+ SOURCE_FACES = source_faces
+ elif source_image_same:
+ logger.status("Using Hashed Source Face(s) Model...")
+ source_faces = SOURCE_FACES
+
+ elif face_model is not None:
+
+ source_faces_index = [0]
+ logger.status("Using Loaded Source Face Model...")
+ source_face_model = [face_model]
+ source_faces = source_face_model
+
+ else:
+ logger.error("Cannot detect any Source")
+ source_faces = None
+
+ if source_faces is not None:
+
+ target_faces = []
+ for i, target_img in enumerate(target_imgs):
+ if state.interrupted or model_management.processing_interrupted():
+ logger.status("Interrupted by User")
+ break
+
+ target_image_md5hash = get_image_md5hash(target_img)
+ if len(TARGET_IMAGE_LIST_HASH) == 0:
+ TARGET_IMAGE_LIST_HASH = [target_image_md5hash]
+ target_image_same = False
+ elif len(TARGET_IMAGE_LIST_HASH) == i:
+ TARGET_IMAGE_LIST_HASH.append(target_image_md5hash)
+ target_image_same = False
+ else:
+ target_image_same = TARGET_IMAGE_LIST_HASH[i] == target_image_md5hash
+ if not target_image_same:
+ TARGET_IMAGE_LIST_HASH[i] = target_image_md5hash
+
+ logger.info("(Image %s) Target Image MD5 Hash = %s", i, TARGET_IMAGE_LIST_HASH[i])
+ logger.info("(Image %s) Target Image the Same? %s", i, target_image_same)
+
+ if len(TARGET_FACES_LIST) == 0:
+ logger.status(f"Analyzing Target Image {i}...")
+ target_face = analyze_faces(target_img)
+ TARGET_FACES_LIST = [target_face]
+ elif len(TARGET_FACES_LIST) == i and not target_image_same:
+ logger.status(f"Analyzing Target Image {i}...")
+ target_face = analyze_faces(target_img)
+ TARGET_FACES_LIST.append(target_face)
+ elif len(TARGET_FACES_LIST) != i and not target_image_same:
+ logger.status(f"Analyzing Target Image {i}...")
+ target_face = analyze_faces(target_img)
+ TARGET_FACES_LIST[i] = target_face
+ elif target_image_same:
+ logger.status("(Image %s) Using Hashed Target Face(s) Model...", i)
+ target_face = TARGET_FACES_LIST[i]
+
+
+ # logger.status(f"Analyzing Target Image {i}...")
+ # target_face = analyze_faces(target_img)
+ if target_face is not None:
+ target_faces.append(target_face)
+
+ # No use in trying to swap faces if no faces are found, enhancement
+ if len(target_faces) == 0:
+ logger.status("Cannot detect any Target, skipping swapping...")
+ return result_images
+
+ if source_img is not None:
+ # separated management of wrong_gender between source and target, enhancement
+ source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[0], gender_source=gender_source, order=faces_order[1])
+ else:
+ # source_face = sorted(source_faces, key=lambda x: x.bbox[0])[source_faces_index[0]]
+ source_face = sorted(source_faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse = True)[source_faces_index[0]]
+ src_wrong_gender = 0
+
+ if len(source_faces_index) != 0 and len(source_faces_index) != 1 and len(source_faces_index) != len(faces_index):
+ logger.status(f'Source Faces must have no entries (default=0), one entry, or same number of entries as target faces.')
+ elif source_face is not None:
+ results = target_imgs
+ model_path = os.path.join(insightface_path, model)
+ face_swapper = getFaceSwapModel(model_path)
+
+ source_face_idx = 0
+
+ for face_num in faces_index:
+ # No use in trying to swap faces if no further faces are found, enhancement
+ if face_num >= len(target_faces):
+ logger.status("Checked all existing target faces, skipping swapping...")
+ break
+
+ if len(source_faces_index) > 1 and source_face_idx > 0:
+ source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[source_face_idx], gender_source=gender_source, order=faces_order[1])
+ source_face_idx += 1
+
+ if source_face is not None and src_wrong_gender == 0:
+ # Apply each swap on top of the running results so multiple face indices accumulate in the same output images
+ for i, (target_img, target_face) in enumerate(zip(results, target_faces)):
+ target_face_single, wrong_gender = get_face_single(target_img, target_face, face_index=face_num, gender_target=gender_target, order=faces_order[0])
+ if target_face_single is not None and wrong_gender == 0:
+ result = target_img
+ logger.status(f"Swapping {i}...")
+ if face_boost_enabled:
+ logger.status(f"Face Boost is enabled")
+ bgr_fake, M = face_swapper.get(target_img, target_face_single, source_face, paste_back=False)
+ bgr_fake, scale = restorer.get_restored_face(bgr_fake, face_restore_model, face_restore_visibility, codeformer_weight, interpolation)
+ M *= scale
+ result = swapper.in_swap(target_img, bgr_fake, M)
+ else:
+ # logger.status(f"Swapping as-is")
+ result = face_swapper.get(target_img, target_face_single, source_face)
+ results[i] = result
+ elif wrong_gender == 1:
+ wrong_gender = 0
+ logger.status("Wrong target gender detected")
+ continue
+ else:
+ logger.status(f"No target face found for {face_num}")
+ elif src_wrong_gender == 1:
+ src_wrong_gender = 0
+ logger.status("Wrong source gender detected")
+ continue
+ else:
+ logger.status(f"No source face found for face number {source_face_idx}.")
+
+ result_images = [Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) for result in results]
+
+ else:
+ logger.status("No source face(s) in the provided Index")
+ else:
+ logger.status("No source face(s) found")
+ return result_images
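+
+# swap_face_many mirrors swap_face but operates on a list of target images and
+# returns a list of results, caching per-image face analyses in
+# TARGET_FACES_LIST so repeated batches over the same frames skip re-detection.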
diff --git a/scripts/reactor_version.py b/scripts/reactor_version.py new file mode 100644 index 0000000..b4e6267 --- /dev/null +++ b/scripts/reactor_version.py @@ -0,0 +1,13 @@ +app_title = "ReActor Node for ComfyUI"
+version_flag = "v0.5.2-a2"
+
+COLORS = {
+ "CYAN": "\033[0;36m", # CYAN
+ "ORANGE": "\033[38;5;173m", # Calm ORANGE
+ "GREEN": "\033[0;32m", # GREEN
+ "YELLOW": "\033[0;33m", # YELLOW
+ "RED": "\033[0;91m", # RED
+ "0": "\033[0m", # RESET COLOR
+}
+
+print(f"{COLORS['YELLOW']}[ReActor]{COLORS['0']} - {COLORS['ORANGE']}STATUS{COLORS['0']} - {COLORS['GREEN']}Running {version_flag} in ComfyUI{COLORS['0']}")
|