summaryrefslogtreecommitdiffstats
path: root/r_basicsr/metrics/metric_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'r_basicsr/metrics/metric_util.py')
-rw-r--r--r_basicsr/metrics/metric_util.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/r_basicsr/metrics/metric_util.py b/r_basicsr/metrics/metric_util.py
new file mode 100644
index 0000000..0b45354
--- /dev/null
+++ b/r_basicsr/metrics/metric_util.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from r_basicsr.utils import bgr2ycbcr
+
+
+def reorder_image(img, input_order='HWC'):
+ """Reorder images to 'HWC' order.
+
+ If the input_order is (h, w), return (h, w, 1);
+ If the input_order is (c, h, w), return (h, w, c);
+ If the input_order is (h, w, c), return as it is.
+
+ Args:
+ img (ndarray): Input image.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ If the input image shape is (h, w), input_order will not have
+ effects. Default: 'HWC'.
+
+ Returns:
+ ndarray: reordered image.
+ """
+
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
+ if len(img.shape) == 2:
+ img = img[..., None]
+ if input_order == 'CHW':
+ img = img.transpose(1, 2, 0)
+ return img
+
+
+def to_y_channel(img):
+ """Change to Y channel of YCbCr.
+
+ Args:
+ img (ndarray): Images with range [0, 255].
+
+ Returns:
+ (ndarray): Images with range [0, 255] (float type) without round.
+ """
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 3 and img.shape[2] == 3:
+ img = bgr2ycbcr(img, y_only=True)
+ img = img[..., None]
+ return img * 255.