From e6bd5af6a8e306a1cdef63402a77a980a04ad6e1 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:06:44 +0000 Subject: Add files via upload --- r_facelib/parsing/__init__.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 r_facelib/parsing/__init__.py (limited to 'r_facelib/parsing/__init__.py') diff --git a/r_facelib/parsing/__init__.py b/r_facelib/parsing/__init__.py new file mode 100644 index 0000000..e5aaa28 --- /dev/null +++ b/r_facelib/parsing/__init__.py @@ -0,0 +1,23 @@ +import torch + +from r_facelib.utils import load_file_from_url +from .bisenet import BiSeNet +from .parsenet import ParseNet + + +def init_parsing_model(model_name='bisenet', half=False, device='cuda'): + if model_name == 'bisenet': + model = BiSeNet(num_class=19) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth' + elif model_name == 'parsenet': + model = ParseNet(in_size=512, out_size=512, parsing_ch=19) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='../../models/facedetection', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + return model -- cgit v1.2.3