@ -1,82 +1,13 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import math
from copy import deepcopy
from itertools import product
from typing import Any , Dict, Generator, ItemsView , List , Tuple
from typing import Any , Generator, List , Tuple
import numpy as np
import torch
class MaskData :
"""
A structure for storing masks and their related data in batched format .
Implements basic filtering and concatenation .
"""
def __init__ ( self , * * kwargs ) - > None :
""" Initialize a MaskData object, ensuring all values are supported types. """
for v in kwargs . values ( ) :
assert isinstance (
v , ( list , np . ndarray , torch . Tensor ) ) , ' MaskData only supports list, numpy arrays, and torch tensors. '
self . _stats = dict ( * * kwargs )
def __setitem__ ( self , key : str , item : Any ) - > None :
""" Set an item in the MaskData object, ensuring it is a supported type. """
assert isinstance (
item , ( list , np . ndarray , torch . Tensor ) ) , ' MaskData only supports list, numpy arrays, and torch tensors. '
self . _stats [ key ] = item
def __delitem__ ( self , key : str ) - > None :
""" Delete an item from the MaskData object. """
del self . _stats [ key ]
def __getitem__ ( self , key : str ) - > Any :
""" Get an item from the MaskData object. """
return self . _stats [ key ]
def items ( self ) - > ItemsView [ str , Any ] :
""" Return an ItemsView of the MaskData object. """
return self . _stats . items ( )
def filter ( self , keep : torch . Tensor ) - > None :
""" Filter the MaskData object based on the given boolean tensor. """
for k , v in self . _stats . items ( ) :
if v is None :
self . _stats [ k ] = None
elif isinstance ( v , torch . Tensor ) :
self . _stats [ k ] = v [ torch . as_tensor ( keep , device = v . device ) ]
elif isinstance ( v , np . ndarray ) :
self . _stats [ k ] = v [ keep . detach ( ) . cpu ( ) . numpy ( ) ]
elif isinstance ( v , list ) and keep . dtype == torch . bool :
self . _stats [ k ] = [ a for i , a in enumerate ( v ) if keep [ i ] ]
elif isinstance ( v , list ) :
self . _stats [ k ] = [ v [ i ] for i in keep ]
else :
raise TypeError ( f ' MaskData key { k } has an unsupported type { type ( v ) } . ' )
def cat ( self , new_stats : ' MaskData ' ) - > None :
""" Concatenate a new MaskData object to the current one. """
for k , v in new_stats . items ( ) :
if k not in self . _stats or self . _stats [ k ] is None :
self . _stats [ k ] = deepcopy ( v )
elif isinstance ( v , torch . Tensor ) :
self . _stats [ k ] = torch . cat ( [ self . _stats [ k ] , v ] , dim = 0 )
elif isinstance ( v , np . ndarray ) :
self . _stats [ k ] = np . concatenate ( [ self . _stats [ k ] , v ] , axis = 0 )
elif isinstance ( v , list ) :
self . _stats [ k ] = self . _stats [ k ] + deepcopy ( v )
else :
raise TypeError ( f ' MaskData key { k } has an unsupported type { type ( v ) } . ' )
def to_numpy ( self ) - > None :
""" Convert all torch tensors in the MaskData object to numpy arrays. """
for k , v in self . _stats . items ( ) :
if isinstance ( v , torch . Tensor ) :
self . _stats [ k ] = v . detach ( ) . cpu ( ) . numpy ( )
def is_box_near_crop_edge ( boxes : torch . Tensor ,
crop_box : List [ int ] ,
orig_box : List [ int ] ,
@ -91,14 +22,6 @@ def is_box_near_crop_edge(boxes: torch.Tensor,
return torch . any ( near_crop_edge , dim = 1 )
def box_xyxy_to_xywh ( box_xyxy : torch . Tensor ) - > torch . Tensor :
""" Convert bounding boxes from XYXY format to XYWH format. """
box_xywh = deepcopy ( box_xyxy )
box_xywh [ 2 ] = box_xywh [ 2 ] - box_xywh [ 0 ]
box_xywh [ 3 ] = box_xywh [ 3 ] - box_xywh [ 1 ]
return box_xywh
def batch_iterator ( batch_size : int , * args ) - > Generator [ List [ Any ] , None , None ] :
""" Yield batches of data from the input arguments. """
assert args and all ( len ( a ) == len ( args [ 0 ] ) for a in args ) , ' Batched iteration must have same-size inputs. '
@ -107,50 +30,6 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
yield [ arg [ b * batch_size : ( b + 1 ) * batch_size ] for arg in args ]
def mask_to_rle_pytorch ( tensor : torch . Tensor ) - > List [ Dict [ str , Any ] ] :
""" Encode masks as uncompressed RLEs in the format expected by pycocotools. """
# Put in fortran order and flatten h,w
b , h , w = tensor . shape
tensor = tensor . permute ( 0 , 2 , 1 ) . flatten ( 1 )
# Compute change indices
diff = tensor [ : , 1 : ] ^ tensor [ : , : - 1 ]
change_indices = diff . nonzero ( )
# Encode run length
out = [ ]
for i in range ( b ) :
cur_idxs = change_indices [ change_indices [ : , 0 ] == i , 1 ]
cur_idxs = torch . cat ( [
torch . tensor ( [ 0 ] , dtype = cur_idxs . dtype , device = cur_idxs . device ) ,
cur_idxs + 1 ,
torch . tensor ( [ h * w ] , dtype = cur_idxs . dtype , device = cur_idxs . device ) , ] )
btw_idxs = cur_idxs [ 1 : ] - cur_idxs [ : - 1 ]
counts = [ ] if tensor [ i , 0 ] == 0 else [ 0 ]
counts . extend ( btw_idxs . detach ( ) . cpu ( ) . tolist ( ) )
out . append ( { ' size ' : [ h , w ] , ' counts ' : counts } )
return out
def rle_to_mask ( rle : Dict [ str , Any ] ) - > np . ndarray :
""" Compute a binary mask from an uncompressed RLE. """
h , w = rle [ ' size ' ]
mask = np . empty ( h * w , dtype = bool )
idx = 0
parity = False
for count in rle [ ' counts ' ] :
mask [ idx : idx + count ] = parity
idx + = count
parity ^ = True
mask = mask . reshape ( w , h )
return mask . transpose ( ) # Put in C order
def area_from_rle ( rle : Dict [ str , Any ] ) - > int :
""" Calculate the area of a mask from its uncompressed RLE. """
return sum ( rle [ ' counts ' ] [ 1 : : 2 ] )
def calculate_stability_score ( masks : torch . Tensor , mask_threshold : float , threshold_offset : float ) - > torch . Tensor :
"""
Computes the stability score for a batch of masks . The stability
@ -264,16 +143,6 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
return mask , True
def coco_encode_rle ( uncompressed_rle : Dict [ str , Any ] ) - > Dict [ str , Any ] :
""" Encode uncompressed RLE (run-length encoding) to COCO RLE format. """
from pycocotools import mask as mask_utils # type: ignore
h , w = uncompressed_rle [ ' size ' ]
rle = mask_utils . frPyObjects ( uncompressed_rle , h , w )
rle [ ' counts ' ] = rle [ ' counts ' ] . decode ( ' utf-8 ' ) # Necessary to serialize with json
return rle
def batched_mask_to_box ( masks : torch . Tensor ) - > torch . Tensor :
"""
Calculates boxes in XYXY format around masks . Return [ 0 , 0 , 0 , 0 ] for