ultralytics 8.0.151
add DOTAv2.yaml
for OBB training (#4258)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
@ -357,14 +357,15 @@ def scale_image(masks, im0_shape, ratio_pad=None):
|
||||
|
||||
def xyxy2xywh(x):
|
||||
"""
|
||||
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format.
|
||||
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
|
||||
top-left corner and (x2, y2) is the bottom-right corner.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
|
||||
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
|
||||
y[..., 2] = x[..., 2] - x[..., 0] # width
|
||||
@ -382,11 +383,13 @@ def xywh2xyxy(x):
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
|
||||
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
|
||||
y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
|
||||
y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
dw = x[..., 2] / 2 # half-width
|
||||
dh = x[..., 3] / 2 # half-height
|
||||
y[..., 0] = x[..., 0] - dw # top left x
|
||||
y[..., 1] = x[..., 1] - dh # top left y
|
||||
y[..., 2] = x[..., 0] + dw # bottom right x
|
||||
y[..., 3] = x[..., 1] + dh # bottom right y
|
||||
return y
|
||||
|
||||
|
||||
@ -404,7 +407,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
|
||||
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
|
||||
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
|
||||
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
|
||||
y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
|
||||
@ -428,7 +431,7 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
||||
"""
|
||||
if clip:
|
||||
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
|
||||
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
|
||||
y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
|
||||
@ -449,7 +452,7 @@ def xyn2xy(x, w=640, h=640, padw=0, padh=0):
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The x and y coordinates of the top left corner of the bounding box
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[..., 0] = w * x[..., 0] + padw # top left x
|
||||
y[..., 1] = h * x[..., 1] + padh # top left y
|
||||
return y
|
||||
@ -464,7 +467,7 @@ def xywh2ltwh(x):
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
||||
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
||||
return y
|
||||
@ -479,7 +482,7 @@ def xyxy2ltwh(x):
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[:, 2] = x[:, 2] - x[:, 0] # width
|
||||
y[:, 3] = x[:, 3] - x[:, 1] # height
|
||||
return y
|
||||
@ -492,12 +495,91 @@ def ltwh2xywh(x):
|
||||
Args:
|
||||
x (torch.Tensor): the input tensor
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[:, 0] = x[:, 0] + x[:, 2] / 2 # center x
|
||||
y[:, 1] = x[:, 1] + x[:, 3] / 2 # center y
|
||||
return y
|
||||
|
||||
|
||||
def xyxyxyxy2xywhr(corners):
|
||||
"""
|
||||
Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation].
|
||||
|
||||
Args:
|
||||
corners (numpy.ndarray | torch.Tensor): Input corners of shape (n, 8).
|
||||
|
||||
Returns:
|
||||
(numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
|
||||
"""
|
||||
if isinstance(corners, torch.Tensor):
|
||||
is_numpy = False
|
||||
atan2 = torch.atan2
|
||||
sqrt = torch.sqrt
|
||||
else:
|
||||
is_numpy = True
|
||||
atan2 = np.arctan2
|
||||
sqrt = np.sqrt
|
||||
|
||||
x1, y1, x2, y2, x3, y3, x4, y4 = corners.T
|
||||
cx = (x1 + x3) / 2
|
||||
cy = (y1 + y3) / 2
|
||||
dx21 = x2 - x1
|
||||
dy21 = y2 - y1
|
||||
|
||||
w = sqrt(dx21 ** 2 + dy21 ** 2)
|
||||
h = sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
|
||||
|
||||
rotation = atan2(-dy21, dx21)
|
||||
rotation *= 180.0 / math.pi # radians to degrees
|
||||
|
||||
return np.vstack((cx, cy, w, h, rotation)).T if is_numpy else torch.stack((cx, cy, w, h, rotation), dim=1)
|
||||
|
||||
|
||||
def xywhr2xyxyxyxy(center):
|
||||
"""
|
||||
Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4].
|
||||
|
||||
Args:
|
||||
center (numpy.ndarray | torch.Tensor): Input data in [cx, cy, w, h, rotation] format of shape (n, 5).
|
||||
|
||||
Returns:
|
||||
(numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 8).
|
||||
"""
|
||||
if isinstance(center, torch.Tensor):
|
||||
is_numpy = False
|
||||
cos = torch.cos
|
||||
sin = torch.sin
|
||||
else:
|
||||
is_numpy = True
|
||||
cos = np.cos
|
||||
sin = np.sin
|
||||
|
||||
cx, cy, w, h, rotation = center.T
|
||||
rotation *= math.pi / 180.0 # degrees to radians
|
||||
|
||||
dx = w / 2
|
||||
dy = h / 2
|
||||
|
||||
cos_rot = cos(rotation)
|
||||
sin_rot = sin(rotation)
|
||||
dx_cos_rot = dx * cos_rot
|
||||
dx_sin_rot = dx * sin_rot
|
||||
dy_cos_rot = dy * cos_rot
|
||||
dy_sin_rot = dy * sin_rot
|
||||
|
||||
x1 = cx - dx_cos_rot - dy_sin_rot
|
||||
y1 = cy + dx_sin_rot - dy_cos_rot
|
||||
x2 = cx + dx_cos_rot - dy_sin_rot
|
||||
y2 = cy - dx_sin_rot - dy_cos_rot
|
||||
x3 = cx + dx_cos_rot + dy_sin_rot
|
||||
y3 = cy - dx_sin_rot + dy_cos_rot
|
||||
x4 = cx - dx_cos_rot + dy_sin_rot
|
||||
y4 = cy + dx_sin_rot + dy_cos_rot
|
||||
|
||||
return np.vstack((x1, y1, x2, y2, x3, y3, x4, y4)).T if is_numpy else torch.stack(
|
||||
(x1, y1, x2, y2, x3, y3, x4, y4), dim=1)
|
||||
|
||||
|
||||
def ltwh2xyxy(x):
|
||||
"""
|
||||
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
||||
@ -508,7 +590,7 @@ def ltwh2xyxy(x):
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[:, 2] = x[:, 2] + x[:, 0] # width
|
||||
y[:, 3] = x[:, 3] + x[:, 1] # height
|
||||
return y
|
||||
|
Reference in New Issue
Block a user