From e6bd5af6a8e306a1cdef63402a77a980a04ad6e1 Mon Sep 17 00:00:00 2001
From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com>
Date: Fri, 17 Jan 2025 11:06:44 +0000
Subject: Add files via upload

---
 scripts/r_masking/__init__.py |   0
 scripts/r_masking/core.py     | 647 ++++++++++++++++++++++++++++++++++++++++++
 scripts/r_masking/segs.py     |  22 ++
 scripts/r_masking/subcore.py  | 117 ++++++++
 4 files changed, 786 insertions(+)
 create mode 100644 scripts/r_masking/__init__.py
 create mode 100644 scripts/r_masking/core.py
 create mode 100644 scripts/r_masking/segs.py
 create mode 100644 scripts/r_masking/subcore.py
(limited to 'scripts/r_masking')

diff --git a/scripts/r_masking/__init__.py b/scripts/r_masking/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/r_masking/core.py b/scripts/r_masking/core.py
new file mode 100644
index 0000000..36862e1
--- /dev/null
+++ b/scripts/r_masking/core.py
@@ -0,0 +1,647 @@
+import numpy as np
+import cv2
+import torch
+import torchvision.transforms.functional as TF
+
+import sys as _sys
+from keyword import iskeyword as _iskeyword
+from operator import itemgetter as _itemgetter
+
+from segment_anything import SamPredictor
+
+from comfy import model_management
+
+
+################################################################################
+### namedtuple
+################################################################################
+
+try:
+    from _collections import _tuplegetter
+except ImportError:
+    _tuplegetter = lambda index, doc: property(_itemgetter(index), doc=doc)
+
+def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None):
+    """Returns a new subclass of tuple with named fields.
+
+    >>> Point = namedtuple('Point', ['x', 'y'])
+    >>> Point.__doc__            # docstring for the new class
+    'Point(x, y)'
+    >>> p = Point(11, y=22)      # instantiate with positional args or keywords
+    >>> p[0] + p[1]              # indexable like a plain tuple
+    33
+    >>> x, y = p                 # unpack like a regular tuple
+    >>> x, y
+    (11, 22)
+    >>> p.x + p.y                # fields also accessible by name
+    33
+    >>> d = p._asdict()          # convert to a dictionary
+    >>> d['x']
+    11
+    >>> Point(**d)               # convert from a dictionary
+    Point(x=11, y=22)
+    >>> p._replace(x=100)        # _replace() is like str.replace() but targets named fields
+    Point(x=100, y=22)
+
+    """
+
+    # Validate the field names.  At the user's option, either generate an error
+    # message or automatically replace the field name with a valid name.
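+    # (Illustrative note, standard-library semantics: with rename=True,
+    # invalid or duplicate names are replaced positionally, e.g.
+    # namedtuple('P', ['abc', 'def', 'abc'], rename=True)._fields
+    # gives ('abc', '_1', '_2').)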
+    if isinstance(field_names, str):
+        field_names = field_names.replace(',', ' ').split()
+    field_names = list(map(str, field_names))
+    typename = _sys.intern(str(typename))
+
+    if rename:
+        seen = set()
+        for index, name in enumerate(field_names):
+            if (not name.isidentifier()
+                or _iskeyword(name)
+                or name.startswith('_')
+                or name in seen):
+                field_names[index] = f'_{index}'
+            seen.add(name)
+
+    for name in [typename] + field_names:
+        if type(name) is not str:
+            raise TypeError('Type names and field names must be strings')
+        if not name.isidentifier():
+            raise ValueError('Type names and field names must be valid '
+                             f'identifiers: {name!r}')
+        if _iskeyword(name):
+            raise ValueError('Type names and field names cannot be a '
+                             f'keyword: {name!r}')
+
+    seen = set()
+    for name in field_names:
+        if name.startswith('_') and not rename:
+            raise ValueError('Field names cannot start with an underscore: '
+                             f'{name!r}')
+        if name in seen:
+            raise ValueError(f'Encountered duplicate field name: {name!r}')
+        seen.add(name)
+
+    field_defaults = {}
+    if defaults is not None:
+        defaults = tuple(defaults)
+        if len(defaults) > len(field_names):
+            raise TypeError('Got more default values than field names')
+        field_defaults = dict(reversed(list(zip(reversed(field_names),
+                                                reversed(defaults)))))
+
+    # Variables used in the methods and docstrings
+    field_names = tuple(map(_sys.intern, field_names))
+    num_fields = len(field_names)
+    arg_list = ', '.join(field_names)
+    if num_fields == 1:
+        arg_list += ','
+    repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')'
+    tuple_new = tuple.__new__
+    _dict, _tuple, _len, _map, _zip = dict, tuple, len, map, zip
+
+    # Create all the named tuple methods to be added to the class namespace
+
+    namespace = {
+        '_tuple_new': tuple_new,
+        '__builtins__': {},
+        '__name__': f'namedtuple_{typename}',
+    }
+    code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))'
+    __new__ = eval(code, namespace)
+    __new__.__name__ = '__new__'
+    __new__.__doc__ = f'Create new instance of {typename}({arg_list})'
+    if defaults is not None:
+        __new__.__defaults__ = defaults
+
+    @classmethod
+    def _make(cls, iterable):
+        result = tuple_new(cls, iterable)
+        if _len(result) != num_fields:
+            raise TypeError(f'Expected {num_fields} arguments, got {len(result)}')
+        return result
+
+    _make.__func__.__doc__ = (f'Make a new {typename} object from a sequence '
+                              'or iterable')
+
+    def _replace(self, /, **kwds):
+        result = self._make(_map(kwds.pop, field_names, self))
+        if kwds:
+            raise ValueError(f'Got unexpected field names: {list(kwds)!r}')
+        return result
+
+    _replace.__doc__ = (f'Return a new {typename} object replacing specified '
+                        'fields with new values')
+
+    def __repr__(self):
+        'Return a nicely formatted representation string'
+        return self.__class__.__name__ + repr_fmt % self
+
+    def _asdict(self):
+        'Return a new dict which maps field names to their values.'
+        return _dict(_zip(self._fields, self))
+
+    def __getnewargs__(self):
+        'Return self as a plain tuple.  Used by copy and pickle.'
+        return _tuple(self)
+
+    # Modify function metadata to help with introspection and debugging
+    for method in (
+        __new__,
+        _make.__func__,
+        _replace,
+        __repr__,
+        _asdict,
+        __getnewargs__,
+    ):
+        method.__qualname__ = f'{typename}.{method.__name__}'
+
+    # Build-up the class namespace dictionary
+    # and use type() to build the result class
+    class_namespace = {
+        '__doc__': f'{typename}({arg_list})',
+        '__slots__': (),
+        '_fields': field_names,
+        '_field_defaults': field_defaults,
+        '__new__': __new__,
+        '_make': _make,
+        '_replace': _replace,
+        '__repr__': __repr__,
+        '_asdict': _asdict,
+        '__getnewargs__': __getnewargs__,
+        '__match_args__': field_names,
+    }
+    for index, name in enumerate(field_names):
+        doc = _sys.intern(f'Alias for field number {index}')
+        class_namespace[name] = _tuplegetter(index, doc)
+
+    result = type(typename, (tuple,), class_namespace)
+
+    # For pickling to work, the __module__ variable needs to be set to the frame
+    # where the named tuple is created.  Bypass this step in environments where
+    # sys._getframe is not defined (Jython for example) or sys._getframe is not
+    # defined for arguments greater than 0 (IronPython), or where the user has
+    # specified a particular module.
+    if module is None:
+        try:
+            module = _sys._getframe(1).f_globals.get('__name__', '__main__')
+        except (AttributeError, ValueError):
+            pass
+    if module is not None:
+        result.__module__ = module
+
+    return result
+
+
+SEG = namedtuple("SEG",
+                 ['cropped_image', 'cropped_mask', 'confidence', 'crop_region', 'bbox', 'label', 'control_net_wrapper'],
+                 defaults=[None])
+
+def crop_ndarray4(npimg, crop_region):
+    x1 = crop_region[0]
+    y1 = crop_region[1]
+    x2 = crop_region[2]
+    y2 = crop_region[3]
+
+    cropped = npimg[:, y1:y2, x1:x2, :]
+
+    return cropped
+
+crop_tensor4 = crop_ndarray4
+
+def crop_ndarray2(npimg, crop_region):
+    x1 = crop_region[0]
+    y1 = crop_region[1]
+    x2 = crop_region[2]
+    y2 = crop_region[3]
+
+    cropped = npimg[y1:y2, x1:x2]
+
+    return cropped
+
+def crop_image(image, crop_region):
+    return crop_tensor4(image, crop_region)
+
+def normalize_region(limit, startp, size):
+    if startp < 0:
+        new_endp = min(limit, size)
+        new_startp = 0
+    elif startp + size > limit:
+        new_startp = max(0, limit - size)
+        new_endp = limit
+    else:
+        new_startp = startp
+        new_endp = min(limit, startp + size)
+
+    return int(new_startp), int(new_endp)
+
+def make_crop_region(w, h, bbox, crop_factor, crop_min_size=None):
+    x1 = bbox[0]
+    y1 = bbox[1]
+    x2 = bbox[2]
+    y2 = bbox[3]
+
+    bbox_w = x2 - x1
+    bbox_h = y2 - y1
+
+    crop_w = bbox_w * crop_factor
+    crop_h = bbox_h * crop_factor
+
+    if crop_min_size is not None:
+        crop_w = max(crop_min_size, crop_w)
+        crop_h = max(crop_min_size, crop_h)
+
+    kernel_x = x1 + bbox_w / 2
+    kernel_y = y1 + bbox_h / 2
+
+    new_x1 = int(kernel_x - crop_w / 2)
+    new_y1 = int(kernel_y - crop_h / 2)
+
+    # make sure the crop region fits within (w, h)
+    new_x1, new_x2 = normalize_region(w, new_x1, crop_w)
+    new_y1, new_y2 = normalize_region(h, new_y1, crop_h)
+
+    return [new_x1, new_y1, new_x2, new_y2]
+
+def create_segmasks(results):
+    bboxs = results[1]
+    segms = results[2]
+    confidence = results[3]
+
+    results = []
+    for i in range(len(segms)):
+        item = (bboxs[i], segms[i].astype(np.float32), confidence[i])
+        results.append(item)
+    return results
+
+def dilate_masks(segmasks, dilation_factor, iter=1):
+    if dilation_factor == 0:
+        return segmasks
+
+    dilated_masks = []
+    kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)
+
+    kernel = cv2.UMat(kernel)
+
+    for i in range(len(segmasks)):
+        cv2_mask = segmasks[i][1]
+
+        cv2_mask = cv2.UMat(cv2_mask)
+
+        # NOTE: `iterations` must be passed by keyword; the third positional
+        # argument of cv2.dilate/cv2.erode is `dst`, not the iteration count.
+        if dilation_factor > 0:
+            dilated_mask = cv2.dilate(cv2_mask, kernel, iterations=iter)
+        else:
+            dilated_mask = cv2.erode(cv2_mask, kernel, iterations=iter)
+
+        dilated_mask = dilated_mask.get()
+
+        item = (segmasks[i][0], dilated_mask, segmasks[i][2])
+        dilated_masks.append(item)
+
+    return dilated_masks
+
+def is_same_device(a, b):
+    a_device = torch.device(a) if isinstance(a, str) else a
+    b_device = torch.device(b) if isinstance(b, str) else b
+    return a_device.type == b_device.type and a_device.index == b_device.index
+
+class SafeToGPU:
+    def __init__(self, size):
+        self.size = size
+
+    def to_device(self, obj, device):
+        if is_same_device(device, 'cpu'):
+            obj.to(device)
+        else:
+            if is_same_device(obj.device, 'cpu'):  # cpu to gpu
+                model_management.free_memory(self.size * 1.3, device)
+                if model_management.get_free_memory(device) > self.size * 1.3:
+                    try:
+                        obj.to(device)
+                    except Exception:
+                        print(f"WARN: The model is not moved to the '{device}' due to insufficient memory. [1]")
+                else:
+                    print(f"WARN: The model is not moved to the '{device}' due to insufficient memory. [2]")
+
+def center_of_bbox(bbox):
+    w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
+    return bbox[0] + w / 2, bbox[1] + h / 2
+
+def sam_predict(predictor, points, plabs, bbox, threshold):
+    point_coords = None if not points else np.array(points)
+    point_labels = None if not plabs else np.array(plabs)
+
+    box = np.array([bbox]) if bbox is not None else None
+
+    cur_masks, scores, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box)
+
+    total_masks = []
+
+    selected = False
+    max_score = 0
+    max_mask = None
+    for idx in range(len(scores)):
+        if scores[idx] > max_score:
+            max_score = scores[idx]
+            max_mask = cur_masks[idx]
+
+        if scores[idx] >= threshold:
+            selected = True
+            total_masks.append(cur_masks[idx])
+
+    # fall back to the best-scoring mask if none cleared the threshold
+    if not selected and max_mask is not None:
+        total_masks.append(max_mask)
+
+    return total_masks
+
+def make_2d_mask(mask):
+    if len(mask.shape) == 4:
+        return mask.squeeze(0).squeeze(0)
+
+    elif len(mask.shape) == 3:
+        return mask.squeeze(0)
+
+    return mask
+
+def gen_detection_hints_from_mask_area(x, y, mask, threshold, use_negative):
+    mask = make_2d_mask(mask)
+
+    points = []
+    plabs = []
+
+    # minimum sampling step >= 3
+    y_step = max(3, int(mask.shape[0] / 20))
+    x_step = max(3, int(mask.shape[1] / 20))
+
+    for i in range(0, len(mask), y_step):
+        for j in range(0, len(mask[i]), x_step):
+            if mask[i][j] > threshold:
+                points.append((x + j, y + i))
+                plabs.append(1)
+            elif use_negative and mask[i][j] == 0:
+                points.append((x + j, y + i))
+                plabs.append(0)
+
+    return points, plabs
+
+def gen_negative_hints(w, h, x1, y1, x2, y2):
+    npoints = []
+    nplabs = []
+
+    # minimum sampling step >= 3
+    y_step = max(3, int(w / 20))
+    x_step = max(3, int(h / 20))
+
+    for i in range(10, h - 10, y_step):
+        for j in range(10, w - 10, x_step):
+            if not (x1 - 10 <= j and j <= x2 + 10 and y1 - 10 <= i and i <= y2 + 10):
+                npoints.append((j, i))
+                nplabs.append(0)
+
+    return npoints, nplabs
+
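+# Illustrative example (not part of the upstream code) of the hint format
+# consumed by sam_predict above:
+#   points = [(320, 240)]   # pixel coordinates (x, y)
+#   plabs = [1]             # 1 = foreground, 0 = background
+#   masks = sam_predict(predictor, points, plabs, bbox=None, threshold=0.93)
+# Masks scoring >= threshold are kept; otherwise the single best-scoring
+# mask is returned as a fallback.
+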
+def generate_detection_hints(image, seg, center, detection_hint, dilated_bbox, mask_hint_threshold, use_small_negative,
+                             mask_hint_use_negative):
+    [x1, y1, x2, y2] = dilated_bbox
+
+    points = []
+    plabs = []
+    if detection_hint == "center-1":
+        points.append(center)
+        plabs = [1]  # 1 = foreground point, 0 = background point
+
+    elif detection_hint == "horizontal-2":
+        gap = (x2 - x1) / 3
+        points.append((x1 + gap, center[1]))
+        points.append((x1 + gap * 2, center[1]))
+        plabs = [1, 1]
+
+    elif detection_hint == "vertical-2":
+        gap = (y2 - y1) / 3
+        points.append((center[0], y1 + gap))
+        points.append((center[0], y1 + gap * 2))
+        plabs = [1, 1]
+
+    elif detection_hint == "rect-4":
+        x_gap = (x2 - x1) / 3
+        y_gap = (y2 - y1) / 3
+        points.append((x1 + x_gap, center[1]))
+        points.append((x1 + x_gap * 2, center[1]))
+        points.append((center[0], y1 + y_gap))
+        points.append((center[0], y1 + y_gap * 2))
+        plabs = [1, 1, 1, 1]
+
+    elif detection_hint == "diamond-4":
+        x_gap = (x2 - x1) / 3
+        y_gap = (y2 - y1) / 3
+        points.append((x1 + x_gap, y1 + y_gap))
+        points.append((x1 + x_gap * 2, y1 + y_gap))
+        points.append((x1 + x_gap, y1 + y_gap * 2))
+        points.append((x1 + x_gap * 2, y1 + y_gap * 2))
+        plabs = [1, 1, 1, 1]
+
+    elif detection_hint == "mask-point-bbox":
+        center = center_of_bbox(seg.bbox)
+        points.append(center)
+        plabs = [1]
+
+    elif detection_hint == "mask-area":
+        points, plabs = gen_detection_hints_from_mask_area(seg.crop_region[0], seg.crop_region[1],
+                                                           seg.cropped_mask,
+                                                           mask_hint_threshold, use_small_negative)
+
+    if mask_hint_use_negative == "Outter":
+        npoints, nplabs = gen_negative_hints(image.shape[0], image.shape[1],
+                                             seg.crop_region[0], seg.crop_region[1],
+                                             seg.crop_region[2], seg.crop_region[3])
+
+        points += npoints
+        plabs += nplabs
+
+    return points, plabs
+
+def combine_masks2(masks):
+    if len(masks) == 0:
+        return None
+    else:
+        initial_cv2_mask = np.array(masks[0]).astype(np.uint8)
+        combined_cv2_mask = initial_cv2_mask
+
+        for i in range(1, len(masks)):
+            cv2_mask = np.array(masks[i]).astype(np.uint8)
+
+            if combined_cv2_mask.shape == cv2_mask.shape:
+                combined_cv2_mask = cv2.bitwise_or(combined_cv2_mask, cv2_mask)
+            else:
+                # skip masks whose shape does not match - they cannot be OR-ed
+                pass
+
+        mask = torch.from_numpy(combined_cv2_mask)
+        return mask
+
+def dilate_mask(mask, dilation_factor, iter=1):
+    if dilation_factor == 0:
+        return make_2d_mask(mask)
+
+    mask = make_2d_mask(mask)
+
+    kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)
+
+    mask = cv2.UMat(mask)
+    kernel = cv2.UMat(kernel)
+
+    # as in dilate_masks, `iterations` must be a keyword argument
+    if dilation_factor > 0:
+        result = cv2.dilate(mask, kernel, iterations=iter)
+    else:
+        result = cv2.erode(mask, kernel, iterations=iter)
+
+    return result.get()
+
+def convert_and_stack_masks(masks):
+    if len(masks) == 0:
+        return None
+
+    mask_tensors = []
+    for mask in masks:
+        mask_array = np.array(mask, dtype=np.uint8)
+        mask_tensor = torch.from_numpy(mask_array)
+        mask_tensors.append(mask_tensor)
+
+    stacked_masks = torch.stack(mask_tensors, dim=0)
+    stacked_masks = stacked_masks.unsqueeze(1)
+
+    return stacked_masks
+
+def merge_and_stack_masks(stacked_masks, group_size):
+    if stacked_masks is None:
+        return None
+
+    num_masks = stacked_masks.size(0)
+    merged_masks = []
+
+    for i in range(0, num_masks, group_size):
+        subset_masks = stacked_masks[i:i + group_size]
+        merged_mask = torch.any(subset_masks, dim=0)
+        merged_masks.append(merged_mask)
+
+    if len(merged_masks) > 0:
+        merged_masks = torch.stack(merged_masks, dim=0)
+
+    return merged_masks
+
+def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation,
+                            threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative):
+    if sam_model.is_auto_mode:
+        device = model_management.get_torch_device()
+        sam_model.safe_to.to_device(sam_model, device=device)
+
+    try:
+        predictor = SamPredictor(sam_model)
+        image = np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
+        predictor.set_image(image, "RGB")
+
+        total_masks = []
+
+        use_small_negative = mask_hint_use_negative == "Small"
+
+        # seg_shape = segs[0]
+        segs = segs[1]
+        if detection_hint == "mask-points":
+            points = []
+            plabs = []
+
+            for i in range(len(segs)):
+                bbox = segs[i].bbox
+                center = center_of_bbox(bbox)
+                points.append(center)
+
+                # small point is background, big point is foreground
+                if use_small_negative and bbox[2] - bbox[0] < 10:
+                    plabs.append(0)
+                else:
+                    plabs.append(1)
+
+            detected_masks = sam_predict(predictor, points, plabs, None, threshold)
+            total_masks += detected_masks
+
+        else:
+            for i in range(len(segs)):
+                bbox = segs[i].bbox
+                center = center_of_bbox(bbox)
+                x1 = max(bbox[0] - bbox_expansion, 0)
+                y1 = max(bbox[1] - bbox_expansion, 0)
+                x2 = min(bbox[2] + bbox_expansion, image.shape[1])
+                y2 = min(bbox[3] + bbox_expansion, image.shape[0])
+
+                dilated_bbox = [x1, y1, x2, y2]
+
+                points, plabs = generate_detection_hints(image, segs[i], center, detection_hint, dilated_bbox,
+                                                         mask_hint_threshold, use_small_negative,
+                                                         mask_hint_use_negative)
+
+                detected_masks = sam_predict(predictor, points, plabs, dilated_bbox, threshold)
+
+                total_masks += detected_masks
+
+        # merge all collected masks
+        mask = combine_masks2(total_masks)
+
+    finally:
+        if sam_model.is_auto_mode:
+            sam_model.cpu()
+
+    mask_working_device = torch.device("cpu")
+
+    if mask is not None:
+        mask = mask.float()
+        mask = dilate_mask(mask.cpu().numpy(), dilation)
+        mask = torch.from_numpy(mask)
+        mask = mask.to(device=mask_working_device)
+    else:
+        # no mask detected: return an empty mask of the image's height/width
+        height, width, _ = image.shape
+        mask = torch.zeros(
+            (height, width), dtype=torch.float32, device=mask_working_device
+        )  # empty mask
+
+    stacked_masks = convert_and_stack_masks(total_masks)
+
+    return (mask, merge_and_stack_masks(stacked_masks, group_size=3))
+
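+# The helpers below assume ComfyUI's image layout: tensors of shape
+# (batch, height, width, channels) with float values in [0, 1]. Illustrative
+# example: tensor2rgba(torch.rand(1, 64, 64, 3)).shape == (1, 64, 64, 4).
+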
+def tensor2mask(t: torch.Tensor) -> torch.Tensor:
+    size = t.size()
+    if len(size) < 4:
+        return t
+    if size[3] == 1:
+        return t[:, :, :, 0]
+    elif size[3] == 4:
+        # Use the alpha channel unless every alpha value is 1, in which case
+        # fall back to the RGB-to-grayscale behavior below.
+        if torch.min(t[:, :, :, 3]).item() != 1.:
+            return t[:, :, :, 3]
+    return TF.rgb_to_grayscale(tensor2rgb(t).permute(0, 3, 1, 2), num_output_channels=1)[:, 0, :, :]
+
+def tensor2rgb(t: torch.Tensor) -> torch.Tensor:
+    size = t.size()
+    if len(size) < 4:
+        return t.unsqueeze(3).repeat(1, 1, 1, 3)
+    if size[3] == 1:
+        return t.repeat(1, 1, 1, 3)
+    elif size[3] == 4:
+        return t[:, :, :, :3]
+    else:
+        return t
+
+def tensor2rgba(t: torch.Tensor) -> torch.Tensor:
+    size = t.size()
+    if len(size) < 4:
+        return t.unsqueeze(3).repeat(1, 1, 1, 4)
+    elif size[3] == 1:
+        return t.repeat(1, 1, 1, 4)
+    elif size[3] == 3:
+        # create the alpha channel on the same device/dtype as the input
+        alpha_tensor = torch.ones((size[0], size[1], size[2], 1), dtype=t.dtype, device=t.device)
+        return torch.cat((t, alpha_tensor), dim=3)
+    else:
+        return t
diff --git a/scripts/r_masking/segs.py b/scripts/r_masking/segs.py
new file mode 100644
index 0000000..60c22d7
--- /dev/null
+++ b/scripts/r_masking/segs.py
@@ -0,0 +1,22 @@
+def filter(segs, labels):
+    labels = set([label.strip() for label in labels])
+
+    if 'all' in labels:
+        return (segs, (segs[0], []), )
+    else:
+        res_segs = []
+        remained_segs = []
+
+        for x in segs[1]:
+            if x.label in labels:
+                res_segs.append(x)
+            elif 'eyes' in labels and x.label in ['left_eye', 'right_eye']:
+                res_segs.append(x)
+            elif 'eyebrows' in labels and x.label in ['left_eyebrow', 'right_eyebrow']:
+                res_segs.append(x)
+            elif 'pupils' in labels and x.label in ['left_pupil', 'right_pupil']:
+                res_segs.append(x)
+            else:
+                remained_segs.append(x)
+
+        return ((segs[0], res_segs), (segs[0], remained_segs), )
diff --git a/scripts/r_masking/subcore.py b/scripts/r_masking/subcore.py
new file mode 100644
index 0000000..cf7bf7d
--- /dev/null
+++ b/scripts/r_masking/subcore.py
@@ -0,0 +1,117 @@
+import numpy as np
+import cv2
+from PIL import Image
+
+import scripts.r_masking.core as core
+from reactor_utils import tensor_to_pil
+
+try:
+    from ultralytics import YOLO
+except Exception as e:
+    print(e)
+
+
+def load_yolo(model_path: str):
+    try:
+        return YOLO(model_path)
+    except ModuleNotFoundError:
+        # https://github.com/ultralytics/ultralytics/issues/3856
+        YOLO("yolov8n.pt")
+        return YOLO(model_path)
+
+def inference_bbox(
+    model,
+    image: Image.Image,
+    confidence: float = 0.3,
+    device: str = "",
+):
+    pred = model(image, conf=confidence, device=device)
+
+    bboxes = pred[0].boxes.xyxy.cpu().numpy()
+    cv2_image = np.array(image)
+    if len(cv2_image.shape) == 3:
+        cv2_image = cv2_image[:, :, ::-1].copy()  # Convert RGB to BGR for cv2 processing
+    else:
+        # Grayscale input: convert to a 3-channel image for consistency
+        cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_GRAY2BGR)
+    cv2_gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
+
+    # no detections: return four empty parallel lists
+    if bboxes.shape[0] == 0:
+        return [[], [], [], []]
+
+    segms = []
+    for x0, y0, x1, y1 in bboxes:
+        cv2_mask = np.zeros(cv2_gray.shape, np.uint8)
+        cv2.rectangle(cv2_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1)
+        cv2_mask_bool = cv2_mask.astype(bool)
+        segms.append(cv2_mask_bool)
+
+    results = [[], [], [], []]
+    for i in range(len(bboxes)):
+        results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())])
+        results[1].append(bboxes[i])
+        results[2].append(segms[i])
+        results[3].append(pred[0].boxes[i].conf.cpu().numpy())
+
+    return results
+
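+# Note (illustrative, not upstream): inference_bbox returns four parallel
+# lists -- [class labels, xyxy bboxes, boolean rectangle masks, confidences] --
+# so results[0][i] is the class name of detection i.
+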
+class UltraBBoxDetector:
+    bbox_model = None
+
+    def __init__(self, bbox_model):
+        self.bbox_model = bbox_model
+
+    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
+        drop_size = max(drop_size, 1)
+        detected_results = inference_bbox(self.bbox_model, tensor_to_pil(image), threshold)
+        segmasks = core.create_segmasks(detected_results)
+
+        if dilation > 0:
+            segmasks = core.dilate_masks(segmasks, dilation)
+
+        items = []
+        h = image.shape[1]
+        w = image.shape[2]
+
+        for x, label in zip(segmasks, detected_results[0]):
+            item_bbox = x[0]
+            item_mask = x[1]
+
+            x1, y1, x2, y2 = item_bbox
+
+            if x2 - x1 > drop_size and y2 - y1 > drop_size:  # minimum dimension must be (2,2) to avoid squeeze issue
+                crop_region = core.make_crop_region(w, h, item_bbox, crop_factor)
+
+                if detailer_hook is not None:
+                    crop_region = detailer_hook.post_crop_region(w, h, item_bbox, crop_region)
+
+                cropped_image = core.crop_image(image, crop_region)
+                cropped_mask = core.crop_ndarray2(item_mask, crop_region)
+                confidence = x[2]
+                # bbox_size = (item_bbox[2] - item_bbox[0], item_bbox[3] - item_bbox[1])  # (w,h)
+
+                item = core.SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, label, None)
+
+                items.append(item)
+
+        shape = image.shape[1], image.shape[2]
+        segs = shape, items
+
+        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
+            segs = detailer_hook.post_detection(segs)
+
+        return segs
+
+    def detect_combined(self, image, threshold, dilation):
+        # core.py defines neither tensor2pil nor combine_masks; use the
+        # imported tensor_to_pil helper and combine_masks2, which expects
+        # bare masks rather than (bbox, mask, confidence) tuples.
+        detected_results = inference_bbox(self.bbox_model, tensor_to_pil(image), threshold)
+        segmasks = core.create_segmasks(detected_results)
+        if dilation > 0:
+            segmasks = core.dilate_masks(segmasks, dilation)
+
+        return core.combine_masks2([x[1] for x in segmasks])
+
+    def setAux(self, x):
+        pass
--
cgit v1.2.3
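
Reviewer note (not part of the patch): a minimal sketch of how these modules
compose, assuming ultralytics and segment_anything are installed, a YOLO
detection weight such as face_yolov8m.pt and the official SAM checkpoint
sam_vit_b_01ec64.pth are available locally (both file names are illustrative,
not pinned by this patch), and a ComfyUI-style image tensor of shape
(batch, height, width, channels) with float values in [0, 1]:

    import torch
    from segment_anything import sam_model_registry

    import scripts.r_masking.core as core
    import scripts.r_masking.subcore as subcore

    # Hypothetical input -- any (1, H, W, 3) float tensor in [0, 1].
    image = torch.rand(1, 512, 512, 3)

    # Detect bounding boxes and build SEGs (threshold/dilation/crop_factor
    # values here are arbitrary).
    detector = subcore.UltraBBoxDetector(subcore.load_yolo("face_yolov8m.pt"))
    segs = detector.detect(image, threshold=0.5, dilation=4, crop_factor=1.5)

    # make_sam_mask_segmented reads .is_auto_mode off the SAM model; leaving
    # it False keeps the model on its current device.
    sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
    sam.is_auto_mode = False

    combined_mask, batched_masks = core.make_sam_mask_segmented(
        sam, segs, image,
        detection_hint="center-1", dilation=0,
        threshold=0.93, bbox_expansion=0,
        mask_hint_threshold=0.7, mask_hint_use_negative="False",
    )

detection_hint accepts the values handled in generate_detection_hints plus
"mask-points"; mask_hint_use_negative is compared against the strings "Small"
and "Outter" (the upstream spelling), and any other value disables negative
hints.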