From e6bd5af6a8e306a1cdef63402a77a980a04ad6e1 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:06:44 +0000 Subject: Add files via upload --- r_chainner/model_loading.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 r_chainner/model_loading.py (limited to 'r_chainner/model_loading.py') diff --git a/r_chainner/model_loading.py b/r_chainner/model_loading.py new file mode 100644 index 0000000..21fd51d --- /dev/null +++ b/r_chainner/model_loading.py @@ -0,0 +1,28 @@ +from r_chainner.archs.face.gfpganv1_clean_arch import GFPGANv1Clean +from r_chainner.types import PyTorchModel + + +class UnsupportedModel(Exception): + pass + + +def load_state_dict(state_dict) -> PyTorchModel: + + state_dict_keys = list(state_dict.keys()) + + if "params_ema" in state_dict_keys: + state_dict = state_dict["params_ema"] + elif "params-ema" in state_dict_keys: + state_dict = state_dict["params-ema"] + elif "params" in state_dict_keys: + state_dict = state_dict["params"] + + state_dict_keys = list(state_dict.keys()) + + # GFPGAN + if ( + "toRGB.0.weight" in state_dict_keys + and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys + ): + model = GFPGANv1Clean(state_dict) + return model -- cgit v1.2.3