ultralytics 8.0.151
add DOTAv2.yaml
for OBB training (#4258)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = '8.0.150'
|
||||
__version__ = '8.0.151'
|
||||
|
||||
from ultralytics.hub import start
|
||||
from ultralytics.models import RTDETR, SAM, YOLO
|
||||
|
37
ultralytics/cfg/datasets/DOTAv2.yaml
Normal file
37
ultralytics/cfg/datasets/DOTAv2.yaml
Normal file
@ -0,0 +1,37 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
# DOTA 2.0 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University
|
||||
# Example usage: yolo train model=yolov8n-obb.pt data=DOTAv2.yaml
|
||||
# parent
|
||||
# ├── ultralytics
|
||||
# └── datasets
|
||||
# └── dota2 ← downloads here (2GB)
|
||||
|
||||
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
||||
path: ../datasets/DOTAv2 # dataset root dir
|
||||
train: images/train # train images (relative to 'path') 1411 images
|
||||
val: images/val # val images (relative to 'path') 458 images
|
||||
test: images/test # test images (optional) 937 images
|
||||
|
||||
# Classes for DOTA 2.0
|
||||
names:
|
||||
0: plane
|
||||
1: ship
|
||||
2: storage tank
|
||||
3: baseball diamond
|
||||
4: tennis court
|
||||
5: basketball court
|
||||
6: ground track field
|
||||
7: harbor
|
||||
8: bridge
|
||||
9: large vehicle
|
||||
10: small vehicle
|
||||
11: helicopter
|
||||
12: roundabout
|
||||
13: soccer ball field
|
||||
14: swimming pool
|
||||
15: container crane
|
||||
16: airport
|
||||
17: helipad
|
||||
|
||||
# Download script/URL (optional)
|
||||
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/DOTAv2.zip
|
@ -117,6 +117,97 @@ def convert_coco(labels_dir='../coco/annotations/', use_segments=False, use_keyp
|
||||
file.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||
|
||||
|
||||
def convert_dota_to_yolo_obb(dota_root_path: str):
|
||||
"""
|
||||
Converts DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format.
|
||||
|
||||
The function processes images in the 'train' and 'val' folders of the DOTA dataset. For each image, it reads the
|
||||
associated label from the original labels directory and writes new labels in YOLO OBB format to a new directory.
|
||||
|
||||
Args:
|
||||
dota_root_path (str): The root directory path of the DOTA dataset.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.data.converter import convert_dota_to_yolo_obb
|
||||
|
||||
convert_dota_to_yolo_obb('path/to/DOTA')
|
||||
```
|
||||
|
||||
Notes:
|
||||
The directory structure assumed for the DOTA dataset:
|
||||
- DOTA
|
||||
- images
|
||||
- train
|
||||
- val
|
||||
- labels
|
||||
- train_original
|
||||
- val_original
|
||||
|
||||
After the function execution, the new labels will be saved in:
|
||||
- DOTA
|
||||
- labels
|
||||
- train
|
||||
- val
|
||||
"""
|
||||
dota_root_path = Path(dota_root_path)
|
||||
|
||||
# Class names to indices mapping
|
||||
class_mapping = {
|
||||
'plane': 0,
|
||||
'ship': 1,
|
||||
'storage-tank': 2,
|
||||
'baseball-diamond': 3,
|
||||
'tennis-court': 4,
|
||||
'basketball-court': 5,
|
||||
'ground-track-field': 6,
|
||||
'harbor': 7,
|
||||
'bridge': 8,
|
||||
'large-vehicle': 9,
|
||||
'small-vehicle': 10,
|
||||
'helicopter': 11,
|
||||
'roundabout': 12,
|
||||
'soccer ball-field': 13,
|
||||
'swimming-pool': 14,
|
||||
'container-crane': 15,
|
||||
'airport': 16,
|
||||
'helipad': 17}
|
||||
|
||||
def convert_label(image_name, image_width, image_height, orig_label_dir, save_dir):
|
||||
orig_label_path = orig_label_dir / f'{image_name}.txt'
|
||||
save_path = save_dir / f'{image_name}.txt'
|
||||
|
||||
with orig_label_path.open('r') as f, save_path.open('w') as g:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
parts = line.strip().split()
|
||||
if len(parts) < 9:
|
||||
continue
|
||||
class_name = parts[8]
|
||||
class_idx = class_mapping[class_name]
|
||||
coords = [float(p) for p in parts[:8]]
|
||||
normalized_coords = [
|
||||
coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)]
|
||||
formatted_coords = ['{:.6g}'.format(coord) for coord in normalized_coords]
|
||||
g.write(f"{class_idx} {' '.join(formatted_coords)}\n")
|
||||
|
||||
for phase in ['train', 'val']:
|
||||
image_dir = dota_root_path / 'images' / phase
|
||||
orig_label_dir = dota_root_path / 'labels' / f'{phase}_original'
|
||||
save_dir = dota_root_path / 'labels' / phase
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
image_paths = list(image_dir.iterdir())
|
||||
for image_path in tqdm(image_paths, desc=f'Processing {phase} images'):
|
||||
if image_path.suffix != '.png':
|
||||
continue
|
||||
image_name_without_ext = image_path.stem
|
||||
img = cv2.imread(str(image_path))
|
||||
h, w = img.shape[:2]
|
||||
convert_label(image_name_without_ext, w, h, orig_label_dir, save_dir)
|
||||
|
||||
|
||||
def rle2polygon(segmentation):
|
||||
"""
|
||||
Convert Run-Length Encoding (RLE) mask to polygon coordinates.
|
||||
@ -209,24 +300,3 @@ def merge_multi_segment(segments):
|
||||
nidx = abs(idx[1] - idx[0])
|
||||
s.append(segments[i][nidx:])
|
||||
return s
|
||||
|
||||
|
||||
def delete_dsstore(path='../datasets'):
|
||||
"""Delete Apple .DS_Store files in the specified directory and its subdirectories."""
|
||||
from pathlib import Path
|
||||
|
||||
files = list(Path(path).rglob('.DS_store'))
|
||||
print(files)
|
||||
for f in files:
|
||||
f.unlink()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
source = 'COCO'
|
||||
|
||||
if source == 'COCO':
|
||||
convert_coco(
|
||||
'../datasets/coco/annotations', # directory with *.json
|
||||
use_segments=False,
|
||||
use_keypoints=True,
|
||||
cls91to80=False)
|
||||
|
@ -24,7 +24,7 @@ from ultralytics.utils.checks import check_file, check_font, is_ascii
|
||||
from ultralytics.utils.downloads import download, safe_download, unzip_file
|
||||
from ultralytics.utils.ops import segments2boxes
|
||||
|
||||
HELP_URL = 'See https://docs.ultralytics.com/datasets/detect for YOLO dataset format help.'
|
||||
HELP_URL = 'See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance.'
|
||||
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
|
||||
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes
|
||||
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
|
||||
@ -289,9 +289,6 @@ def check_cls_dataset(dataset: str, split=''):
|
||||
- 'test' (Path): The directory path containing the test set of the dataset.
|
||||
- 'nc' (int): The number of classes in the dataset.
|
||||
- 'names' (dict): A dictionary of class names in the dataset.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the specified dataset is not found and cannot be downloaded.
|
||||
"""
|
||||
|
||||
dataset = Path(dataset)
|
||||
@ -329,13 +326,16 @@ class HUBDatasetStats():
|
||||
task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
|
||||
autodownload (bool): Attempt to download dataset if not found locally. Default is False.
|
||||
|
||||
Usage
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.data.utils import HUBDatasetStats
|
||||
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8.zip', task='detect') # detect dataset
|
||||
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-seg.zip', task='segment') # segment dataset
|
||||
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-pose.zip', task='pose') # pose dataset
|
||||
|
||||
stats = HUBDatasetStats('path/to/coco8.zip', task='detect') # detect dataset
|
||||
stats = HUBDatasetStats('path/to/coco8-seg.zip', task='segment') # segment dataset
|
||||
stats = HUBDatasetStats('path/to/coco8-pose.zip', task='pose') # pose dataset
|
||||
stats.get_json(save=False)
|
||||
stats.process_images()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, path='coco128.yaml', task='detect', autodownload=False):
|
||||
@ -459,11 +459,14 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
|
||||
max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
|
||||
quality (int, optional): The image compression quality as a percentage. Default is 50%.
|
||||
|
||||
Usage:
|
||||
Example:
|
||||
```python
|
||||
from pathlib import Path
|
||||
from ultralytics.data.utils import compress_one_image
|
||||
for f in Path('/Users/glennjocher/Downloads/dataset').rglob('*.jpg'):
|
||||
|
||||
for f in Path('path/to/dataset').rglob('*.jpg'):
|
||||
compress_one_image(f)
|
||||
```
|
||||
"""
|
||||
try: # use PIL
|
||||
im = Image.open(f)
|
||||
@ -488,9 +491,12 @@ def delete_dsstore(path):
|
||||
Args:
|
||||
path (str, optional): The directory path where the ".DS_store" files should be deleted.
|
||||
|
||||
Usage:
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.data.utils import delete_dsstore
|
||||
delete_dsstore('/Users/glennjocher/Downloads/dataset')
|
||||
|
||||
delete_dsstore('path/to/dir')
|
||||
```
|
||||
|
||||
Note:
|
||||
".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
|
||||
@ -505,17 +511,18 @@ def delete_dsstore(path):
|
||||
|
||||
def zip_directory(dir, use_zipfile_library=True):
|
||||
"""
|
||||
Zips a directory and saves the archive to the specified output path.
|
||||
Zips a directory and saves the archive to the specified output path. Equivalent to 'zip -r coco8.zip coco8/'
|
||||
|
||||
Args:
|
||||
dir (str): The path to the directory to be zipped.
|
||||
use_zipfile_library (bool): Whether to use zipfile library or shutil for zipping.
|
||||
|
||||
Usage:
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.data.utils import zip_directory
|
||||
zip_directory('/Users/glennjocher/Downloads/playground')
|
||||
|
||||
zip -r coco8-pose.zip coco8-pose
|
||||
zip_directory('/path/to/dir')
|
||||
```
|
||||
"""
|
||||
delete_dsstore(dir)
|
||||
if use_zipfile_library:
|
||||
@ -538,9 +545,12 @@ def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), ann
|
||||
weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
|
||||
annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.
|
||||
|
||||
Usage:
|
||||
from utils.dataloaders import autosplit
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.utils.dataloaders import autosplit
|
||||
|
||||
autosplit()
|
||||
```
|
||||
"""
|
||||
|
||||
path = Path(path) # images dir
|
||||
|
@ -357,14 +357,15 @@ def scale_image(masks, im0_shape, ratio_pad=None):
|
||||
|
||||
def xyxy2xywh(x):
|
||||
"""
|
||||
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format.
|
||||
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
|
||||
top-left corner and (x2, y2) is the bottom-right corner.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
|
||||
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
|
||||
y[..., 2] = x[..., 2] - x[..., 0] # width
|
||||
@ -382,11 +383,13 @@ def xywh2xyxy(x):
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
|
||||
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
|
||||
y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
|
||||
y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
dw = x[..., 2] / 2 # half-width
|
||||
dh = x[..., 3] / 2 # half-height
|
||||
y[..., 0] = x[..., 0] - dw # top left x
|
||||
y[..., 1] = x[..., 1] - dh # top left y
|
||||
y[..., 2] = x[..., 0] + dw # bottom right x
|
||||
y[..., 3] = x[..., 1] + dh # bottom right y
|
||||
return y
|
||||
|
||||
|
||||
@ -404,7 +407,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
|
||||
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
|
||||
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
|
||||
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
|
||||
y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
|
||||
@ -428,7 +431,7 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
||||
"""
|
||||
if clip:
|
||||
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
|
||||
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
|
||||
y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
|
||||
@ -449,7 +452,7 @@ def xyn2xy(x, w=640, h=640, padw=0, padh=0):
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The x and y coordinates of the top left corner of the bounding box
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[..., 0] = w * x[..., 0] + padw # top left x
|
||||
y[..., 1] = h * x[..., 1] + padh # top left y
|
||||
return y
|
||||
@ -464,7 +467,7 @@ def xywh2ltwh(x):
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
||||
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
||||
return y
|
||||
@ -479,7 +482,7 @@ def xyxy2ltwh(x):
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[:, 2] = x[:, 2] - x[:, 0] # width
|
||||
y[:, 3] = x[:, 3] - x[:, 1] # height
|
||||
return y
|
||||
@ -492,12 +495,91 @@ def ltwh2xywh(x):
|
||||
Args:
|
||||
x (torch.Tensor): the input tensor
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[:, 0] = x[:, 0] + x[:, 2] / 2 # center x
|
||||
y[:, 1] = x[:, 1] + x[:, 3] / 2 # center y
|
||||
return y
|
||||
|
||||
|
||||
def xyxyxyxy2xywhr(corners):
|
||||
"""
|
||||
Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation].
|
||||
|
||||
Args:
|
||||
corners (numpy.ndarray | torch.Tensor): Input corners of shape (n, 8).
|
||||
|
||||
Returns:
|
||||
(numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
|
||||
"""
|
||||
if isinstance(corners, torch.Tensor):
|
||||
is_numpy = False
|
||||
atan2 = torch.atan2
|
||||
sqrt = torch.sqrt
|
||||
else:
|
||||
is_numpy = True
|
||||
atan2 = np.arctan2
|
||||
sqrt = np.sqrt
|
||||
|
||||
x1, y1, x2, y2, x3, y3, x4, y4 = corners.T
|
||||
cx = (x1 + x3) / 2
|
||||
cy = (y1 + y3) / 2
|
||||
dx21 = x2 - x1
|
||||
dy21 = y2 - y1
|
||||
|
||||
w = sqrt(dx21 ** 2 + dy21 ** 2)
|
||||
h = sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
|
||||
|
||||
rotation = atan2(-dy21, dx21)
|
||||
rotation *= 180.0 / math.pi # radians to degrees
|
||||
|
||||
return np.vstack((cx, cy, w, h, rotation)).T if is_numpy else torch.stack((cx, cy, w, h, rotation), dim=1)
|
||||
|
||||
|
||||
def xywhr2xyxyxyxy(center):
|
||||
"""
|
||||
Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4].
|
||||
|
||||
Args:
|
||||
center (numpy.ndarray | torch.Tensor): Input data in [cx, cy, w, h, rotation] format of shape (n, 5).
|
||||
|
||||
Returns:
|
||||
(numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 8).
|
||||
"""
|
||||
if isinstance(center, torch.Tensor):
|
||||
is_numpy = False
|
||||
cos = torch.cos
|
||||
sin = torch.sin
|
||||
else:
|
||||
is_numpy = True
|
||||
cos = np.cos
|
||||
sin = np.sin
|
||||
|
||||
cx, cy, w, h, rotation = center.T
|
||||
rotation *= math.pi / 180.0 # degrees to radians
|
||||
|
||||
dx = w / 2
|
||||
dy = h / 2
|
||||
|
||||
cos_rot = cos(rotation)
|
||||
sin_rot = sin(rotation)
|
||||
dx_cos_rot = dx * cos_rot
|
||||
dx_sin_rot = dx * sin_rot
|
||||
dy_cos_rot = dy * cos_rot
|
||||
dy_sin_rot = dy * sin_rot
|
||||
|
||||
x1 = cx - dx_cos_rot - dy_sin_rot
|
||||
y1 = cy + dx_sin_rot - dy_cos_rot
|
||||
x2 = cx + dx_cos_rot - dy_sin_rot
|
||||
y2 = cy - dx_sin_rot - dy_cos_rot
|
||||
x3 = cx + dx_cos_rot + dy_sin_rot
|
||||
y3 = cy - dx_sin_rot + dy_cos_rot
|
||||
x4 = cx - dx_cos_rot + dy_sin_rot
|
||||
y4 = cy + dx_sin_rot + dy_cos_rot
|
||||
|
||||
return np.vstack((x1, y1, x2, y2, x3, y3, x4, y4)).T if is_numpy else torch.stack(
|
||||
(x1, y1, x2, y2, x3, y3, x4, y4), dim=1)
|
||||
|
||||
|
||||
def ltwh2xyxy(x):
|
||||
"""
|
||||
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
||||
@ -508,7 +590,7 @@ def ltwh2xyxy(x):
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
||||
y[:, 2] = x[:, 2] + x[:, 0] # width
|
||||
y[:, 3] = x[:, 3] + x[:, 1] # height
|
||||
return y
|
||||
|
Reference in New Issue
Block a user