summaryrefslogtreecommitdiffstats
path: root/r_chainner/model_loading.py
diff options
context:
space:
mode:
Diffstat (limited to 'r_chainner/model_loading.py')
-rw-r--r--r_chainner/model_loading.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/r_chainner/model_loading.py b/r_chainner/model_loading.py
new file mode 100644
index 0000000..21fd51d
--- /dev/null
+++ b/r_chainner/model_loading.py
@@ -0,0 +1,28 @@
+from r_chainner.archs.face.gfpganv1_clean_arch import GFPGANv1Clean
+from r_chainner.types import PyTorchModel
+
+
+class UnsupportedModel(Exception):
+ pass
+
+
+def load_state_dict(state_dict) -> PyTorchModel:
+
+ state_dict_keys = list(state_dict.keys())
+
+ if "params_ema" in state_dict_keys:
+ state_dict = state_dict["params_ema"]
+ elif "params-ema" in state_dict_keys:
+ state_dict = state_dict["params-ema"]
+ elif "params" in state_dict_keys:
+ state_dict = state_dict["params"]
+
+ state_dict_keys = list(state_dict.keys())
+
+ # GFPGAN
+ if (
+ "toRGB.0.weight" in state_dict_keys
+ and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys
+ ):
+ model = GFPGANv1Clean(state_dict)
+ return model