author     Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com>   2025-01-17 11:06:44 +0000
committer  GitHub <noreply@github.com>                                           2025-01-17 11:06:44 +0000
commit     e6bd5af6a8e306a1cdef63402a77a980a04ad6e1 (patch)
tree       d0732226bbc22feedad9e834b2218d7d0b0eff54 /scripts/r_masking
parent     495ffc4777522e40941753e3b1b79c02f84b25b4 (diff)
download   Comfyui-reactor-node-e6bd5af6a8e306a1cdef63402a77a980a04ad6e1.tar.gz
Diffstat (limited to 'scripts/r_masking')
-rw-r--r--  scripts/r_masking/__init__.py |   0
-rw-r--r--  scripts/r_masking/core.py     | 647
-rw-r--r--  scripts/r_masking/segs.py     |  22
-rw-r--r--  scripts/r_masking/subcore.py  | 117
4 files changed, 786 insertions, 0 deletions
diff --git a/scripts/r_masking/__init__.py b/scripts/r_masking/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/scripts/r_masking/__init__.py
diff --git a/scripts/r_masking/core.py b/scripts/r_masking/core.py
new file mode 100644
index 0000000..36862e1
--- /dev/null
+++ b/scripts/r_masking/core.py
@@ -0,0 +1,647 @@
+import numpy as np
+import cv2
+import torch
+import torchvision.transforms.functional as TF
+
+import sys as _sys
+from keyword import iskeyword as _iskeyword
+from operator import itemgetter as _itemgetter
+
+from segment_anything import SamPredictor
+
+from comfy import model_management
+
+
+################################################################################
+### namedtuple
+################################################################################
+
+try:
+ from _collections import _tuplegetter
+except ImportError:
+ _tuplegetter = lambda index, doc: property(_itemgetter(index), doc=doc)
+
+def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None):
+ """Returns a new subclass of tuple with named fields.
+
+ >>> Point = namedtuple('Point', ['x', 'y'])
+ >>> Point.__doc__ # docstring for the new class
+ 'Point(x, y)'
+ >>> p = Point(11, y=22) # instantiate with positional args or keywords
+ >>> p[0] + p[1] # indexable like a plain tuple
+ 33
+ >>> x, y = p # unpack like a regular tuple
+ >>> x, y
+ (11, 22)
+ >>> p.x + p.y # fields also accessible by name
+ 33
+ >>> d = p._asdict() # convert to a dictionary
+ >>> d['x']
+ 11
+ >>> Point(**d) # convert from a dictionary
+ Point(x=11, y=22)
+ >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields
+ Point(x=100, y=22)
+
+ """
+
+ # Validate the field names. At the user's option, either generate an error
+ # message or automatically replace the field name with a valid name.
+ if isinstance(field_names, str):
+ field_names = field_names.replace(',', ' ').split()
+ field_names = list(map(str, field_names))
+ typename = _sys.intern(str(typename))
+
+ if rename:
+ seen = set()
+ for index, name in enumerate(field_names):
+ if (not name.isidentifier()
+ or _iskeyword(name)
+ or name.startswith('_')
+ or name in seen):
+ field_names[index] = f'_{index}'
+ seen.add(name)
+
+ for name in [typename] + field_names:
+ if type(name) is not str:
+ raise TypeError('Type names and field names must be strings')
+ if not name.isidentifier():
+ raise ValueError('Type names and field names must be valid '
+ f'identifiers: {name!r}')
+ if _iskeyword(name):
+ raise ValueError('Type names and field names cannot be a '
+ f'keyword: {name!r}')
+
+ seen = set()
+ for name in field_names:
+ if name.startswith('_') and not rename:
+ raise ValueError('Field names cannot start with an underscore: '
+ f'{name!r}')
+ if name in seen:
+ raise ValueError(f'Encountered duplicate field name: {name!r}')
+ seen.add(name)
+
+ field_defaults = {}
+ if defaults is not None:
+ defaults = tuple(defaults)
+ if len(defaults) > len(field_names):
+ raise TypeError('Got more default values than field names')
+ field_defaults = dict(reversed(list(zip(reversed(field_names),
+ reversed(defaults)))))
+
+ # Variables used in the methods and docstrings
+ field_names = tuple(map(_sys.intern, field_names))
+ num_fields = len(field_names)
+ arg_list = ', '.join(field_names)
+ if num_fields == 1:
+ arg_list += ','
+ repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')'
+ tuple_new = tuple.__new__
+ _dict, _tuple, _len, _map, _zip = dict, tuple, len, map, zip
+
+ # Create all the named tuple methods to be added to the class namespace
+
+ namespace = {
+ '_tuple_new': tuple_new,
+ '__builtins__': {},
+ '__name__': f'namedtuple_{typename}',
+ }
+ code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))'
+ __new__ = eval(code, namespace)
+ __new__.__name__ = '__new__'
+ __new__.__doc__ = f'Create new instance of {typename}({arg_list})'
+ if defaults is not None:
+ __new__.__defaults__ = defaults
+
+ @classmethod
+ def _make(cls, iterable):
+ result = tuple_new(cls, iterable)
+ if _len(result) != num_fields:
+ raise TypeError(f'Expected {num_fields} arguments, got {len(result)}')
+ return result
+
+ _make.__func__.__doc__ = (f'Make a new {typename} object from a sequence '
+ 'or iterable')
+
+ def _replace(self, /, **kwds):
+ result = self._make(_map(kwds.pop, field_names, self))
+ if kwds:
+ raise ValueError(f'Got unexpected field names: {list(kwds)!r}')
+ return result
+
+ _replace.__doc__ = (f'Return a new {typename} object replacing specified '
+ 'fields with new values')
+
+ def __repr__(self):
+ 'Return a nicely formatted representation string'
+ return self.__class__.__name__ + repr_fmt % self
+
+ def _asdict(self):
+ 'Return a new dict which maps field names to their values.'
+ return _dict(_zip(self._fields, self))
+
+ def __getnewargs__(self):
+ 'Return self as a plain tuple. Used by copy and pickle.'
+ return _tuple(self)
+
+ # Modify function metadata to help with introspection and debugging
+ for method in (
+ __new__,
+ _make.__func__,
+ _replace,
+ __repr__,
+ _asdict,
+ __getnewargs__,
+ ):
+ method.__qualname__ = f'{typename}.{method.__name__}'
+
+ # Build-up the class namespace dictionary
+ # and use type() to build the result class
+ class_namespace = {
+ '__doc__': f'{typename}({arg_list})',
+ '__slots__': (),
+ '_fields': field_names,
+ '_field_defaults': field_defaults,
+ '__new__': __new__,
+ '_make': _make,
+ '_replace': _replace,
+ '__repr__': __repr__,
+ '_asdict': _asdict,
+ '__getnewargs__': __getnewargs__,
+ '__match_args__': field_names,
+ }
+ for index, name in enumerate(field_names):
+ doc = _sys.intern(f'Alias for field number {index}')
+ class_namespace[name] = _tuplegetter(index, doc)
+
+ result = type(typename, (tuple,), class_namespace)
+
+ # For pickling to work, the __module__ variable needs to be set to the frame
+ # where the named tuple is created. Bypass this step in environments where
+ # sys._getframe is not defined (Jython for example) or sys._getframe is not
+ # defined for arguments greater than 0 (IronPython), or where the user has
+ # specified a particular module.
+ if module is None:
+ try:
+ module = _sys._getframe(1).f_globals.get('__name__', '__main__')
+ except (AttributeError, ValueError):
+ pass
+ if module is not None:
+ result.__module__ = module
+
+ return result
+
+
+SEG = namedtuple("SEG",
+ ['cropped_image', 'cropped_mask', 'confidence', 'crop_region', 'bbox', 'label', 'control_net_wrapper'],
+ defaults=[None])
+
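As a quick orientation (an illustrative sketch, not part of this commit): a SEG record bundles one detection's crop, mask, and geometry, with control_net_wrapper defaulting to None. The arrays below are hypothetical stand-ins for real detector output.

    import numpy as np

    seg = SEG(cropped_image=np.zeros((1, 64, 64, 3), dtype=np.float32),
              cropped_mask=np.ones((64, 64), dtype=np.float32),
              confidence=0.92,
              crop_region=[10, 20, 74, 84],   # x1, y1, x2, y2 in image coordinates
              bbox=[16, 26, 68, 78],
              label="face")
    print(seg.label, seg.confidence, seg.bbox)  # face 0.92 [16, 26, 68, 78]
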
+def crop_ndarray4(npimg, crop_region):
+ x1 = crop_region[0]
+ y1 = crop_region[1]
+ x2 = crop_region[2]
+ y2 = crop_region[3]
+
+ cropped = npimg[:, y1:y2, x1:x2, :]
+
+ return cropped
+
+crop_tensor4 = crop_ndarray4
+
+def crop_ndarray2(npimg, crop_region):
+ x1 = crop_region[0]
+ y1 = crop_region[1]
+ x2 = crop_region[2]
+ y2 = crop_region[3]
+
+ cropped = npimg[y1:y2, x1:x2]
+
+ return cropped
+
+def crop_image(image, crop_region):
+ return crop_tensor4(image, crop_region)
+
+def normalize_region(limit, startp, size):
+ if startp < 0:
+ new_endp = min(limit, size)
+ new_startp = 0
+ elif startp + size > limit:
+ new_startp = max(0, limit - size)
+ new_endp = limit
+ else:
+ new_startp = startp
+ new_endp = min(limit, startp+size)
+
+ return int(new_startp), int(new_endp)
+
+def make_crop_region(w, h, bbox, crop_factor, crop_min_size=None):
+ x1 = bbox[0]
+ y1 = bbox[1]
+ x2 = bbox[2]
+ y2 = bbox[3]
+
+ bbox_w = x2 - x1
+ bbox_h = y2 - y1
+
+ crop_w = bbox_w * crop_factor
+ crop_h = bbox_h * crop_factor
+
+ if crop_min_size is not None:
+ crop_w = max(crop_min_size, crop_w)
+ crop_h = max(crop_min_size, crop_h)
+
+ kernel_x = x1 + bbox_w / 2
+ kernel_y = y1 + bbox_h / 2
+
+ new_x1 = int(kernel_x - crop_w / 2)
+ new_y1 = int(kernel_y - crop_h / 2)
+
+    # clamp the crop window so it stays inside the (w, h) image bounds
+ new_x1, new_x2 = normalize_region(w, new_x1, crop_w)
+ new_y1, new_y2 = normalize_region(h, new_y1, crop_h)
+
+ return [new_x1, new_y1, new_x2, new_y2]
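
A worked example (illustrative numbers, not part of the commit): with a 640x480 image, a 100x100 bbox, and crop_factor 2.0, the crop window doubles each side around the bbox center and normalize_region keeps it inside the image.

    region = make_crop_region(w=640, h=480, bbox=[300, 200, 400, 300], crop_factor=2.0)
    print(region)  # [250, 150, 450, 350]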
+
+def create_segmasks(results):
+ bboxs = results[1]
+ segms = results[2]
+ confidence = results[3]
+
+ results = []
+ for i in range(len(segms)):
+ item = (bboxs[i], segms[i].astype(np.float32), confidence[i])
+ results.append(item)
+ return results
+
+def dilate_masks(segmasks, dilation_factor, iter=1):
+ if dilation_factor == 0:
+ return segmasks
+
+ dilated_masks = []
+ kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)
+
+ kernel = cv2.UMat(kernel)
+
+ for i in range(len(segmasks)):
+ cv2_mask = segmasks[i][1]
+
+ cv2_mask = cv2.UMat(cv2_mask)
+
+        # note: iterations must be passed by keyword; the third positional
+        # argument of cv2.dilate/erode is dst, not the iteration count
+        if dilation_factor > 0:
+            dilated_mask = cv2.dilate(cv2_mask, kernel, iterations=iter)
+        else:
+            dilated_mask = cv2.erode(cv2_mask, kernel, iterations=iter)
+
+ dilated_mask = dilated_mask.get()
+
+ item = (segmasks[i][0], dilated_mask, segmasks[i][2])
+ dilated_masks.append(item)
+
+ return dilated_masks
+
+def is_same_device(a, b):
+ a_device = torch.device(a) if isinstance(a, str) else a
+ b_device = torch.device(b) if isinstance(b, str) else b
+ return a_device.type == b_device.type and a_device.index == b_device.index
+
+class SafeToGPU:
+ def __init__(self, size):
+ self.size = size
+
+ def to_device(self, obj, device):
+ if is_same_device(device, 'cpu'):
+ obj.to(device)
+ else:
+ if is_same_device(obj.device, 'cpu'): # cpu to gpu
+ model_management.free_memory(self.size * 1.3, device)
+ if model_management.get_free_memory(device) > self.size * 1.3:
+ try:
+ obj.to(device)
+                    except Exception:
+                        print(f"WARN: the model was not moved to '{device}' due to insufficient memory. [1]")
+                else:
+                    print(f"WARN: the model was not moved to '{device}' due to insufficient memory. [2]")
+
+def center_of_bbox(bbox):
+ w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
+ return bbox[0] + w/2, bbox[1] + h/2
+
+def sam_predict(predictor, points, plabs, bbox, threshold):
+ point_coords = None if not points else np.array(points)
+ point_labels = None if not plabs else np.array(plabs)
+
+ box = np.array([bbox]) if bbox is not None else None
+
+ cur_masks, scores, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box)
+
+ total_masks = []
+
+ selected = False
+ max_score = 0
+ max_mask = None
+ for idx in range(len(scores)):
+ if scores[idx] > max_score:
+ max_score = scores[idx]
+ max_mask = cur_masks[idx]
+
+ if scores[idx] >= threshold:
+ selected = True
+ total_masks.append(cur_masks[idx])
+
+ if not selected and max_mask is not None:
+ total_masks.append(max_mask)
+
+ return total_masks
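
The selection rule above keeps every mask whose score clears the threshold and falls back to the single best-scoring mask when none does. A toy illustration of just that logic (hypothetical scores, not part of the commit):

    scores = [0.42, 0.78, 0.63]
    threshold = 0.8
    kept = [i for i, s in enumerate(scores) if s >= threshold]
    if not kept:
        kept = [scores.index(max(scores))]  # fall back to the best mask
    print(kept)  # [1] -- nothing cleared 0.8, so the 0.78 mask is kept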
+
+def make_2d_mask(mask):
+ if len(mask.shape) == 4:
+ return mask.squeeze(0).squeeze(0)
+
+ elif len(mask.shape) == 3:
+ return mask.squeeze(0)
+
+ return mask
+
+def gen_detection_hints_from_mask_area(x, y, mask, threshold, use_negative):
+ mask = make_2d_mask(mask)
+
+ points = []
+ plabs = []
+
+ # minimum sampling step >= 3
+ y_step = max(3, int(mask.shape[0] / 20))
+ x_step = max(3, int(mask.shape[1] / 20))
+
+ for i in range(0, len(mask), y_step):
+ for j in range(0, len(mask[i]), x_step):
+ if mask[i][j] > threshold:
+ points.append((x + j, y + i))
+ plabs.append(1)
+ elif use_negative and mask[i][j] == 0:
+ points.append((x + j, y + i))
+ plabs.append(0)
+
+ return points, plabs
+
+def gen_negative_hints(w, h, x1, y1, x2, y2):
+ npoints = []
+ nplabs = []
+
+    # minimum sampling step >= 3; y steps over the height, x over the width
+    y_step = max(3, int(h / 20))
+    x_step = max(3, int(w / 20))
+
+    for i in range(10, h - 10, y_step):
+        for j in range(10, w - 10, x_step):
+            if not (x1 - 10 <= j <= x2 + 10 and y1 - 10 <= i <= y2 + 10):
+ npoints.append((j, i))
+ nplabs.append(0)
+
+ return npoints, nplabs
+
+def generate_detection_hints(image, seg, center, detection_hint, dilated_bbox, mask_hint_threshold, use_small_negative,
+ mask_hint_use_negative):
+ [x1, y1, x2, y2] = dilated_bbox
+
+ points = []
+ plabs = []
+ if detection_hint == "center-1":
+ points.append(center)
+ plabs = [1] # 1 = foreground point, 0 = background point
+
+ elif detection_hint == "horizontal-2":
+ gap = (x2 - x1) / 3
+ points.append((x1 + gap, center[1]))
+ points.append((x1 + gap * 2, center[1]))
+ plabs = [1, 1]
+
+ elif detection_hint == "vertical-2":
+ gap = (y2 - y1) / 3
+ points.append((center[0], y1 + gap))
+ points.append((center[0], y1 + gap * 2))
+ plabs = [1, 1]
+
+ elif detection_hint == "rect-4":
+ x_gap = (x2 - x1) / 3
+ y_gap = (y2 - y1) / 3
+ points.append((x1 + x_gap, center[1]))
+ points.append((x1 + x_gap * 2, center[1]))
+ points.append((center[0], y1 + y_gap))
+ points.append((center[0], y1 + y_gap * 2))
+ plabs = [1, 1, 1, 1]
+
+ elif detection_hint == "diamond-4":
+ x_gap = (x2 - x1) / 3
+ y_gap = (y2 - y1) / 3
+ points.append((x1 + x_gap, y1 + y_gap))
+ points.append((x1 + x_gap * 2, y1 + y_gap))
+ points.append((x1 + x_gap, y1 + y_gap * 2))
+ points.append((x1 + x_gap * 2, y1 + y_gap * 2))
+ plabs = [1, 1, 1, 1]
+
+ elif detection_hint == "mask-point-bbox":
+ center = center_of_bbox(seg.bbox)
+ points.append(center)
+ plabs = [1]
+
+ elif detection_hint == "mask-area":
+ points, plabs = gen_detection_hints_from_mask_area(seg.crop_region[0], seg.crop_region[1],
+ seg.cropped_mask,
+ mask_hint_threshold, use_small_negative)
+
+    if mask_hint_use_negative == "Outter":  # "Outter" (sic) is the option string expected from callers
+        # image is HWC here, so pass (width, height) to match the (w, h) signature
+        npoints, nplabs = gen_negative_hints(image.shape[1], image.shape[0],
+                                             seg.crop_region[0], seg.crop_region[1],
+                                             seg.crop_region[2], seg.crop_region[3])
+
+ points += npoints
+ plabs += nplabs
+
+ return points, plabs
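
For example (illustrative numbers): with detection_hint "horizontal-2" and a dilated bbox of [100, 100, 400, 300], the two positive points sit at thirds of the box width along the vertical center.

    x1, x2 = 100, 400
    center = (250, 200)
    gap = (x2 - x1) / 3
    points = [(x1 + gap, center[1]), (x1 + gap * 2, center[1])]
    print(points)  # [(200.0, 200), (300.0, 200)]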
+
+def combine_masks2(masks):
+ if len(masks) == 0:
+ return None
+ else:
+ initial_cv2_mask = np.array(masks[0]).astype(np.uint8)
+ combined_cv2_mask = initial_cv2_mask
+
+ for i in range(1, len(masks)):
+ cv2_mask = np.array(masks[i]).astype(np.uint8)
+
+ if combined_cv2_mask.shape == cv2_mask.shape:
+ combined_cv2_mask = cv2.bitwise_or(combined_cv2_mask, cv2_mask)
+ else:
+ # do nothing - incompatible mask
+ pass
+
+ mask = torch.from_numpy(combined_cv2_mask)
+ return mask
+
+def dilate_mask(mask, dilation_factor, iter=1):
+ if dilation_factor == 0:
+ return make_2d_mask(mask)
+
+ mask = make_2d_mask(mask)
+
+ kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)
+
+ mask = cv2.UMat(mask)
+ kernel = cv2.UMat(kernel)
+
+    if dilation_factor > 0:
+        result = cv2.dilate(mask, kernel, iterations=iter)
+    else:
+        result = cv2.erode(mask, kernel, iterations=iter)
+
+ return result.get()
+
+def convert_and_stack_masks(masks):
+ if len(masks) == 0:
+ return None
+
+ mask_tensors = []
+ for mask in masks:
+ mask_array = np.array(mask, dtype=np.uint8)
+ mask_tensor = torch.from_numpy(mask_array)
+ mask_tensors.append(mask_tensor)
+
+ stacked_masks = torch.stack(mask_tensors, dim=0)
+ stacked_masks = stacked_masks.unsqueeze(1)
+
+ return stacked_masks
+
+def merge_and_stack_masks(stacked_masks, group_size):
+ if stacked_masks is None:
+ return None
+
+ num_masks = stacked_masks.size(0)
+ merged_masks = []
+
+ for i in range(0, num_masks, group_size):
+ subset_masks = stacked_masks[i:i + group_size]
+ merged_mask = torch.any(subset_masks, dim=0)
+ merged_masks.append(merged_mask)
+
+ if len(merged_masks) > 0:
+ merged_masks = torch.stack(merged_masks, dim=0)
+
+ return merged_masks
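
Illustrative behavior (not part of the commit): seven masks with group_size=3 are OR-ed down to groups of 3, 3, and 1.

    import torch

    masks = torch.rand(7, 1, 64, 64) > 0.5  # seven hypothetical boolean masks
    merged = merge_and_stack_masks(masks, group_size=3)
    print(merged.shape)  # torch.Size([3, 1, 64, 64])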
+
+def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation,
+ threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative):
+ if sam_model.is_auto_mode:
+ device = model_management.get_torch_device()
+ sam_model.safe_to.to_device(sam_model, device=device)
+
+ try:
+ predictor = SamPredictor(sam_model)
+ image = np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
+ predictor.set_image(image, "RGB")
+
+ total_masks = []
+
+ use_small_negative = mask_hint_use_negative == "Small"
+
+ # seg_shape = segs[0]
+ segs = segs[1]
+ if detection_hint == "mask-points":
+ points = []
+ plabs = []
+
+ for i in range(len(segs)):
+ bbox = segs[i].bbox
+ center = center_of_bbox(bbox)
+ points.append(center)
+
+                # boxes narrower than 10px count as background hints; larger ones as foreground
+ if use_small_negative and bbox[2] - bbox[0] < 10:
+ plabs.append(0)
+ else:
+ plabs.append(1)
+
+ detected_masks = sam_predict(predictor, points, plabs, None, threshold)
+ total_masks += detected_masks
+
+ else:
+ for i in range(len(segs)):
+ bbox = segs[i].bbox
+ center = center_of_bbox(bbox)
+ x1 = max(bbox[0] - bbox_expansion, 0)
+ y1 = max(bbox[1] - bbox_expansion, 0)
+ x2 = min(bbox[2] + bbox_expansion, image.shape[1])
+ y2 = min(bbox[3] + bbox_expansion, image.shape[0])
+
+ dilated_bbox = [x1, y1, x2, y2]
+
+ points, plabs = generate_detection_hints(image, segs[i], center, detection_hint, dilated_bbox,
+ mask_hint_threshold, use_small_negative,
+ mask_hint_use_negative)
+
+ detected_masks = sam_predict(predictor, points, plabs, dilated_bbox, threshold)
+
+ total_masks += detected_masks
+
+        # merge all collected masks into one
+ mask = combine_masks2(total_masks)
+
+ finally:
+ if sam_model.is_auto_mode:
+ sam_model.cpu()
+
+
+ mask_working_device = torch.device("cpu")
+
+ if mask is not None:
+ mask = mask.float()
+ mask = dilate_mask(mask.cpu().numpy(), dilation)
+ mask = torch.from_numpy(mask)
+ mask = mask.to(device=mask_working_device)
+ else:
+        # image is HWC at this point; take its height and width
+ height, width, _ = image.shape
+ mask = torch.zeros(
+ (height, width), dtype=torch.float32, device=mask_working_device
+ ) # empty mask
+
+ stacked_masks = convert_and_stack_masks(total_masks)
+
+ return (mask, merge_and_stack_masks(stacked_masks, group_size=3))
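
For orientation, a hedged call sketch only; the inputs and the "False" option value are assumptions, not verified against the rest of the repo.

    # sam_model: a SAM wrapper exposing .is_auto_mode and .safe_to (prepared by the
    # caller); segs: the (shape, [SEG, ...]) pair from a detector; image: a 1xHxWx3
    # float tensor in [0, 1]
    combined, grouped = make_sam_mask_segmented(
        sam_model, segs, image,
        detection_hint="center-1", dilation=10, threshold=0.93,
        bbox_expansion=0, mask_hint_threshold=0.7, mask_hint_use_negative="False")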
+
+def tensor2mask(t: torch.Tensor) -> torch.Tensor:
+ size = t.size()
+ if (len(size) < 4):
+ return t
+ if size[3] == 1:
+ return t[:,:,:,0]
+ elif size[3] == 4:
+        # Use the alpha channel as the mask unless alpha is uniformly 1,
+        # in which case fall back to RGB-to-grayscale behavior
+ if torch.min(t[:, :, :, 3]).item() != 1.:
+ return t[:,:,:,3]
+ return TF.rgb_to_grayscale(tensor2rgb(t).permute(0,3,1,2), num_output_channels=1)[:,0,:,:]
+
+def tensor2rgb(t: torch.Tensor) -> torch.Tensor:
+ size = t.size()
+ if (len(size) < 4):
+ return t.unsqueeze(3).repeat(1, 1, 1, 3)
+ if size[3] == 1:
+ return t.repeat(1, 1, 1, 3)
+ elif size[3] == 4:
+ return t[:, :, :, :3]
+ else:
+ return t
+
+def tensor2rgba(t: torch.Tensor) -> torch.Tensor:
+ size = t.size()
+ if (len(size) < 4):
+ return t.unsqueeze(3).repeat(1, 1, 1, 4)
+ elif size[3] == 1:
+ return t.repeat(1, 1, 1, 4)
+ elif size[3] == 3:
+        alpha_tensor = torch.ones((size[0], size[1], size[2], 1), dtype=t.dtype, device=t.device)
+ return torch.cat((t, alpha_tensor), dim=3)
+ else:
+ return t
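
These three helpers normalize BHWC tensors to mask, RGB, or RGBA channel layouts; a quick shape check (illustrative, not part of the commit):

    import torch

    gray = torch.rand(1, 64, 64, 1)
    print(tensor2rgb(gray).shape)                        # torch.Size([1, 64, 64, 3])
    print(tensor2rgba(gray).shape)                       # torch.Size([1, 64, 64, 4])
    print(tensor2mask(torch.rand(1, 64, 64, 4)).shape)   # torch.Size([1, 64, 64])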
diff --git a/scripts/r_masking/segs.py b/scripts/r_masking/segs.py
new file mode 100644
index 0000000..60c22d7
--- /dev/null
+++ b/scripts/r_masking/segs.py
@@ -0,0 +1,22 @@
+def filter(segs, labels):
+ labels = set([label.strip() for label in labels])
+
+ if 'all' in labels:
+ return (segs, (segs[0], []), )
+ else:
+ res_segs = []
+ remained_segs = []
+
+ for x in segs[1]:
+ if x.label in labels:
+ res_segs.append(x)
+ elif 'eyes' in labels and x.label in ['left_eye', 'right_eye']:
+ res_segs.append(x)
+ elif 'eyebrows' in labels and x.label in ['left_eyebrow', 'right_eyebrow']:
+ res_segs.append(x)
+ elif 'pupils' in labels and x.label in ['left_pupil', 'right_pupil']:
+ res_segs.append(x)
+ else:
+ remained_segs.append(x)
+
+ return ((segs[0], res_segs), (segs[0], remained_segs), )
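
Usage sketch (illustrative): filter splits a SEGS pair into matching and remaining segments, with 'eyes', 'eyebrows', and 'pupils' expanding to their left/right variants. The stand-in records below carry only a label.

    from collections import namedtuple

    S = namedtuple('S', ['label'])  # minimal stand-in for a SEG record
    segs = ((512, 512), [S('face'), S('left_eye'), S('hair')])
    matched, remainder = filter(segs, ['face', 'eyes'])
    print([s.label for s in matched[1]])    # ['face', 'left_eye']
    print([s.label for s in remainder[1]])  # ['hair']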
diff --git a/scripts/r_masking/subcore.py b/scripts/r_masking/subcore.py
new file mode 100644
index 0000000..cf7bf7d
--- /dev/null
+++ b/scripts/r_masking/subcore.py
@@ -0,0 +1,117 @@
+import numpy as np
+import cv2
+from PIL import Image
+
+import scripts.r_masking.core as core
+from reactor_utils import tensor_to_pil
+
+try:
+ from ultralytics import YOLO
+except Exception as e:
+ print(e)
+
+
+def load_yolo(model_path: str):
+ try:
+ return YOLO(model_path)
+ except ModuleNotFoundError:
+ # https://github.com/ultralytics/ultralytics/issues/3856
+ YOLO("yolov8n.pt")
+ return YOLO(model_path)
+
+def inference_bbox(
+ model,
+ image: Image.Image,
+ confidence: float = 0.3,
+ device: str = "",
+):
+ pred = model(image, conf=confidence, device=device)
+
+ bboxes = pred[0].boxes.xyxy.cpu().numpy()
+ cv2_image = np.array(image)
+ if len(cv2_image.shape) == 3:
+ cv2_image = cv2_image[:, :, ::-1].copy() # Convert RGB to BGR for cv2 processing
+ else:
+        # grayscale input: convert to 3-channel BGR for consistency
+ cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_GRAY2BGR)
+ cv2_gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
+
+ segms = []
+ for x0, y0, x1, y1 in bboxes:
+ cv2_mask = np.zeros(cv2_gray.shape, np.uint8)
+ cv2.rectangle(cv2_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1)
+ cv2_mask_bool = cv2_mask.astype(bool)
+ segms.append(cv2_mask_bool)
+
+    n = len(bboxes)
+    if n == 0:
+        return [[], [], [], []]
+
+ results = [[], [], [], []]
+ for i in range(len(bboxes)):
+ results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())])
+ results[1].append(bboxes[i])
+ results[2].append(segms[i])
+ results[3].append(pred[0].boxes[i].conf.cpu().numpy())
+
+ return results
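
The return value is four parallel lists: labels, xyxy boxes, boolean rectangle masks, and confidences. A hedged usage sketch (the model and image file names are hypothetical):

    from PIL import Image

    model = load_yolo("face_yolov8m.pt")  # hypothetical detection model file
    labels, boxes, masks, confs = inference_bbox(model, Image.open("photo.jpg"), confidence=0.5)
    print(labels, confs)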
+
+
+class UltraBBoxDetector:
+ bbox_model = None
+
+ def __init__(self, bbox_model):
+ self.bbox_model = bbox_model
+
+ def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
+ drop_size = max(drop_size, 1)
+ detected_results = inference_bbox(self.bbox_model, tensor_to_pil(image), threshold)
+ segmasks = core.create_segmasks(detected_results)
+
+ if dilation > 0:
+ segmasks = core.dilate_masks(segmasks, dilation)
+
+ items = []
+ h = image.shape[1]
+ w = image.shape[2]
+
+ for x, label in zip(segmasks, detected_results[0]):
+ item_bbox = x[0]
+ item_mask = x[1]
+
+            x1, y1, x2, y2 = item_bbox  # ultralytics boxes are in xyxy order
+
+ if x2 - x1 > drop_size and y2 - y1 > drop_size: # minimum dimension must be (2,2) to avoid squeeze issue
+ crop_region = core.make_crop_region(w, h, item_bbox, crop_factor)
+
+ if detailer_hook is not None:
+ crop_region = detailer_hook.post_crop_region(w, h, item_bbox, crop_region)
+
+ cropped_image = core.crop_image(image, crop_region)
+ cropped_mask = core.crop_ndarray2(item_mask, crop_region)
+ confidence = x[2]
+ # bbox_size = (item_bbox[2]-item_bbox[0],item_bbox[3]-item_bbox[1]) # (w,h)
+
+ item = core.SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, label, None)
+
+ items.append(item)
+
+ shape = image.shape[1], image.shape[2]
+ segs = shape, items
+
+ if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
+ segs = detailer_hook.post_detection(segs)
+
+ return segs
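
A minimal end-to-end sketch (hypothetical model path; the random tensor stands in for a real BHWC image in [0, 1]):

    import torch

    detector = UltraBBoxDetector(load_yolo("face_yolov8m.pt"))  # hypothetical path
    image = torch.rand(1, 512, 512, 3)
    shape, items = detector.detect(image, threshold=0.5, dilation=10, crop_factor=3.0)
    print(len(items), "segments; image shape:", shape)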
+
+ def detect_combined(self, image, threshold, dilation):
+        # core has no tensor2pil; use the tensor_to_pil helper imported above
+        detected_results = inference_bbox(self.bbox_model, tensor_to_pil(image), threshold)
+ segmasks = core.create_segmasks(detected_results)
+ if dilation > 0:
+ segmasks = core.dilate_masks(segmasks, dilation)
+
+        # core defines combine_masks2 (not combine_masks); it expects bare masks
+        return core.combine_masks2([mask for _, mask, _ in segmasks])
+
+ def setAux(self, x):
+ pass