ultralytics 8.0.158 add benchmarks to coverage (#4432)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Glenn Jocher
2023-08-20 20:52:30 +02:00
committed by GitHub
parent 495806565d
commit 87ce15d383
51 changed files with 352 additions and 482 deletions

View File

@ -26,7 +26,7 @@ class FastSAMPrompt:
import clip # for linear_assignment
except ImportError:
from ultralytics.utils.checks import check_requirements
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
check_requirements('git+https://github.com/openai/CLIP.git')
import clip
self.clip = clip
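The try/except above is the lazy-dependency pattern: attempt the import, and only on failure pip-install via `check_requirements` and retry. A minimal sketch of the general shape, with a hypothetical package name:

```python
from ultralytics.utils.checks import check_requirements

try:
    import some_extra  # hypothetical optional dependency
except ImportError:
    check_requirements('some-extra')  # pip-installs on first use
    import some_extra  # retry now that the package is available
```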
@ -91,8 +91,6 @@ class FastSAMPrompt:
y1 = min(y1, y_t)
x2 = max(x2, x_t + w_t)
y2 = max(y2, y_t + h_t)
h = y2 - y1
w = x2 - x1
return [x1, y1, x2, y2]
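The loop above folds every mask's bounding rect into one union box, which is why the dead `h`/`w` assignments could be dropped. A toy NumPy check of the same min/max reduction, with hypothetical boxes:

```python
import numpy as np

boxes = np.array([[10, 20, 50, 60], [30, 10, 80, 40]])  # hypothetical xyxy boxes
x1, y1 = boxes[:, 0].min(), boxes[:, 1].min()  # smallest top-left corner
x2, y2 = boxes[:, 2].max(), boxes[:, 3].max()  # largest bottom-right corner
print(x1, y1, x2, y2)  # 10 10 80 60
```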
def plot(self,
@ -104,9 +102,11 @@ class FastSAMPrompt:
mask_random_color=True,
better_quality=True,
retina=False,
withContours=True):
with_contours=True):
if isinstance(annotations[0], dict):
annotations = [annotation['segmentation'] for annotation in annotations]
if isinstance(annotations, torch.Tensor):
annotations = annotations.cpu().numpy()
result_name = os.path.basename(self.img_path)
image = self.ori_img
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@ -123,41 +123,22 @@ class FastSAMPrompt:
plt.imshow(image)
if better_quality:
if isinstance(annotations[0], torch.Tensor):
annotations = np.array(annotations.cpu())
for i, mask in enumerate(annotations):
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
if self.device == 'cpu':
annotations = np.array(annotations)
self.fast_show_mask(
annotations,
plt.gca(),
random_color=mask_random_color,
bbox=bbox,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
else:
if isinstance(annotations[0], np.ndarray):
annotations = torch.from_numpy(annotations)
self.fast_show_mask_gpu(
annotations,
plt.gca(),
random_color=mask_random_color,
bbox=bbox,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
if isinstance(annotations, torch.Tensor):
annotations = annotations.cpu().numpy()
if withContours:
self.fast_show_mask(
annotations,
plt.gca(),
random_color=mask_random_color,
bbox=bbox,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
if with_contours:
contour_all = []
temp = np.zeros((original_h, original_w, 1))
for i, mask in enumerate(annotations):
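A note on the `better_quality` branch earlier in this hunk: the close-then-open morphology pass fills pinholes and then strips small specks before display. A standalone sketch on a hypothetical binary mask:

```python
import cv2
import numpy as np

mask = (np.random.rand(64, 64) > 0.7).astype(np.uint8)  # hypothetical noisy binary mask
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))  # fill small holes
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))  # remove small specks
```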
@ -184,8 +165,8 @@ class FastSAMPrompt:
LOGGER.info(f'Saved to {save_path.absolute()}')
# CPU post process
@staticmethod
def fast_show_mask(
self,
annotation,
ax,
random_color=False,
@ -196,32 +177,29 @@ class FastSAMPrompt:
target_height=960,
target_width=960,
):
msak_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
# Sort annotation by area
n, h, w = annotation.shape # batch, height, width
areas = np.sum(annotation, axis=(1, 2))
sorted_indices = np.argsort(areas)
annotation = annotation[sorted_indices]
annotation = annotation[np.argsort(areas)]
index = (annotation != 0).argmax(axis=0)
if random_color:
color = np.random.random((msak_sum, 1, 1, 3))
color = np.random.random((n, 1, 1, 3))
else:
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
color = np.ones((n, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
transparency = np.ones((n, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual
show = np.zeros((height, weight, 4))
h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
show = np.zeros((h, w, 4))
h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
# Update the values of show with vectorized indexing
show[h_indices, w_indices, :] = mask_image[indices]
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
# draw point
# Draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
@ -240,63 +218,6 @@ class FastSAMPrompt:
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
ax.imshow(show)
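The retained CPU path relies on a compositing trick worth spelling out: after the ascending area sort, `(annotation != 0).argmax(axis=0)` finds, per pixel, the first (smallest) mask covering it, so one fancy-indexed gather layers small masks over large ones. A self-contained NumPy sketch with assumed toy shapes:

```python
import numpy as np

n, h, w = 3, 4, 4  # assumed toy sizes: 3 masks on a 4x4 canvas
annotation = (np.random.rand(n, h, w) > 0.5).astype(np.uint8)

areas = np.sum(annotation, axis=(1, 2))
annotation = annotation[np.argsort(areas)]  # ascending area: small masks win ties
index = (annotation != 0).argmax(axis=0)  # per pixel, first mask that covers it

color = np.random.random((n, 1, 1, 3))
transparency = np.ones((n, 1, 1, 1)) * 0.6
mask_image = annotation[..., None] * np.concatenate([color, transparency], axis=-1)  # (n, h, w, 4)

hi, wi = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
show = mask_image[index[hi, wi], hi, wi]  # single vectorized gather -> (h, w, 4) RGBA
print(show.shape)  # (4, 4, 4)
```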
def fast_show_mask_gpu(
self,
annotation,
ax,
random_color=False,
bbox=None,
points=None,
pointlabel=None,
retinamask=True,
target_height=960,
target_width=960,
):
msak_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
areas = torch.sum(annotation, dim=(1, 2))
sorted_indices = torch.argsort(areas, descending=False)
annotation = annotation[sorted_indices]
# Find the index of the first non-zero value at each position
index = (annotation != 0).to(torch.long).argmax(dim=0)
if random_color:
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
else:
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([30 / 255, 144 / 255, 1.0]).to(
annotation.device)
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
visual = torch.cat([color, transparency], dim=-1)
mask_image = torch.unsqueeze(annotation, -1) * visual
# Gather by index: index selects which batch entry to take at each position, collapsing mask_image to a single batch
show = torch.zeros((height, weight, 4)).to(annotation.device)
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij')
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
# Update the values of show with vectorized indexing
show[h_indices, w_indices, :] = mask_image[indices]
show_cpu = show.cpu().numpy()
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
# draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
s=20,
c='y',
)
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
s=20,
c='m',
)
if not retinamask:
show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
ax.imshow(show_cpu)
# clip
@torch.no_grad()
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
preprocessed_images = [preprocess(image).to(device) for image in elements]

View File

@ -5,7 +5,6 @@ import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops
from ultralytics.utils.ops import xyxy2xywh
class NASPredictor(BasePredictor):
@ -14,7 +13,7 @@ class NASPredictor(BasePredictor):
"""Postprocess predictions and returns a list of Results objects."""
# Cat boxes and class scores
boxes = xyxy2xywh(preds_in[0][0])
boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
preds = ops.non_max_suppression(preds,

View File

@ -4,7 +4,6 @@ import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops
from ultralytics.utils.ops import xyxy2xywh
__all__ = ['NASValidator']
@ -13,7 +12,7 @@ class NASValidator(DetectionValidator):
def postprocess(self, preds_in):
"""Apply Non-maximum suppression to prediction outputs."""
boxes = xyxy2xywh(preds_in[0][0])
boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
return ops.non_max_suppression(preds,
self.args.conf,
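Both NAS hunks only swap the direct `xyxy2xywh` import for the `ops` namespace; behavior is unchanged. For reference, the helper converts corner boxes to center format, e.g.:

```python
import torch
from ultralytics.utils import ops

boxes = torch.tensor([[10., 20., 50., 60.]])  # x1, y1, x2, y2
print(ops.xyxy2xywh(boxes))  # tensor([[30., 40., 40., 40.]]) -> cx, cy, w, h
```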

View File

@ -9,6 +9,19 @@ from ultralytics.utils import ops
class RTDETRPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on an RT-DETR detection model.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.rtdetr import RTDETRPredictor
args = dict(model='rtdetr-l.pt', source=ASSETS)
predictor = RTDETRPredictor(overrides=args)
predictor.predict_cli()
```
"""
def postprocess(self, preds, img, orig_imgs):
"""Postprocess predictions and returns a list of Results objects."""
@ -38,7 +51,9 @@ class RTDETRPredictor(BasePredictor):
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
Return: A list of transformed imgs.
Notes: The size must be square(640) and scaleFilled.
Returns:
(list): A list of transformed imgs.
"""
# The size must be square(640) and scaleFilled.
return [LetterBox(self.imgsz, auto=False, scaleFill=True)(image=x) for x in im]
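`scaleFill=True` stretches each frame to the square target with no letterbox padding, trading aspect ratio for a fixed input shape. A rough cv2 equivalent for intuition, assuming a single hypothetical frame:

```python
import cv2
import numpy as np

img = np.zeros((480, 640, 3), dtype=np.uint8)  # hypothetical HxWx3 frame
out = cv2.resize(img, (640, 640))  # stretch to square, no padding added
print(out.shape)  # (640, 640, 3)
```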

View File

@ -6,12 +6,28 @@ import torch
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import DEFAULT_CFG, RANK, colorstr
from ultralytics.utils import RANK, colorstr
from .val import RTDETRDataset, RTDETRValidator
class RTDETRTrainer(DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on an RT-DETR detection model.
Notes:
- F.grid_sample used in rt-detr does not support the `deterministic=True` argument.
- AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
Example:
```python
from ultralytics.models.rtdetr.train import RTDETRTrainer
args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
```
"""
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
@ -54,27 +70,3 @@ class RTDETRTrainer(DetectionTrainer):
gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
return batch
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train and optimize RTDETR model given training data and device."""
model = 'rtdetr-l.yaml'
data = cfg.data or 'coco8.yaml' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
# NOTE: F.grid_sample which is in rt-detr does not support deterministic=True
# NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching
args = dict(model=model,
data=data,
device=device,
imgsz=640,
exist_ok=True,
batch=4,
deterministic=False,
amp=False)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()
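This is the first of several hunks that delete the module-level `train()`/`predict()`/`val()` shims; the docstring Examples above become the supported entry points. A sketch of the equivalent high-level call, passing the removed shim's RT-DETR workarounds explicitly:

```python
from ultralytics import RTDETR

model = RTDETR('rtdetr-l.yaml')
# deterministic=False and amp=False mirror the removed shim's RT-DETR notes
model.train(data='coco8.yaml', imgsz=640, epochs=3, deterministic=False, amp=False)
```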

View File

@ -67,6 +67,18 @@ class RTDETRDataset(YOLODataset):
class RTDETRValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on an RT-DETR detection model.
Example:
```python
from ultralytics.models.rtdetr import RTDETRValidator
args = dict(model='rtdetr-l.pt', data='coco8.yaml')
validator = RTDETRValidator(args=args)
validator(model=args['model'])
```
"""
def build_dataset(self, img_path, mode='val', batch=None):
"""Build YOLO Dataset

View File

@ -55,12 +55,14 @@ class Predictor(BasePredictor):
return img
def pre_transform(self, im):
"""Pre-transform input image before inference.
"""
Pre-transform input image before inference.
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
Return: A list of transformed imgs.
Returns:
(list): A list of transformed images.
"""
assert len(im) == 1, 'SAM model does not support batch inference yet!'
return [LetterBox(self.args.imgsz, auto=False, center=False)(image=x) for x in im]

View File

@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models.yolo.classify.predict import ClassificationPredictor, predict
from ultralytics.models.yolo.classify.train import ClassificationTrainer, train
from ultralytics.models.yolo.classify.val import ClassificationValidator, val
from ultralytics.models.yolo.classify.predict import ClassificationPredictor
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.models.yolo.classify.val import ClassificationValidator
__all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val'
__all__ = 'ClassificationPredictor', 'ClassificationTrainer', 'ClassificationValidator'
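With the function shims gone from `__all__`, classification workflows route through the model facade instead. A sketch of the post-refactor equivalents (the demo image URL is the usual Ultralytics asset, assumed here):

```python
from ultralytics import YOLO

model = YOLO('yolov8n-cls.pt')
model.train(data='imagenet10', epochs=3)  # replaces the removed classify.train()
model.val(data='imagenet10')  # replaces the removed classify.val()
model('https://ultralytics.com/images/bus.jpg')  # replaces the removed classify.predict()
```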

View File

@ -4,10 +4,26 @@ import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ASSETS, DEFAULT_CFG
from ultralytics.utils import DEFAULT_CFG
class ClassificationPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on a classification model.
Notes:
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.classify import ClassificationPredictor
args = dict(model='yolov8n-cls.pt', source=ASSETS)
predictor = ClassificationPredictor(overrides=args)
predictor.predict_cli()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
@ -30,21 +46,3 @@ class ClassificationPredictor(BasePredictor):
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))
return results
def predict(cfg=DEFAULT_CFG, use_python=False):
"""Run YOLO model predictions on input images/videos."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
source = cfg.source or ASSETS
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = ClassificationPredictor(overrides=args)
predictor.predict_cli()
if __name__ == '__main__':
predict()

View File

@ -13,6 +13,21 @@ from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_di
class ClassificationTrainer(BaseTrainer):
"""
A class extending the BaseTrainer class for training based on a classification model.
Notes:
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
Example:
```python
from ultralytics.models.yolo.classify import ClassificationTrainer
args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
trainer = ClassificationTrainer(overrides=args)
trainer.train()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
@ -137,22 +152,3 @@ class ClassificationTrainer(BaseTrainer):
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train a YOLO classification model."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = ClassificationTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()
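One retained detail above is worth a gloss: `cls=batch['cls'].view(-1)` carries a warning against `.squeeze()`, which would collapse a single-image batch to a 0-d tensor. A toy repro:

```python
import torch

cls = torch.tensor([[5]])  # labels for a batch of one image, shape (1, 1)
print(cls.squeeze().shape)  # torch.Size([]) -- 0-d, breaks per-image indexing
print(cls.view(-1).shape)  # torch.Size([1]) -- always one entry per image
```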

View File

@ -4,12 +4,27 @@ import torch
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.utils.plotting import plot_images
class ClassificationValidator(BaseValidator):
"""
A class extending the BaseValidator class for validation based on a classification model.
Notes:
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
Example:
```python
from ultralytics.models.yolo.classify import ClassificationValidator
args = dict(model='yolov8n-cls.pt', data='imagenet10')
validator = ClassificationValidator(args=args)
validator(model=args['model'])
```
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
@ -92,21 +107,3 @@ class ClassificationValidator(BaseValidator):
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
def val(cfg=DEFAULT_CFG, use_python=False):
"""Validate YOLO model using custom data."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
data = cfg.data or 'mnist160'
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)
else:
validator = ClassificationValidator(args=args)
validator(model=args['model'])
if __name__ == '__main__':
val()

View File

@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .predict import DetectionPredictor, predict
from .train import DetectionTrainer, train
from .val import DetectionValidator, val
from .predict import DetectionPredictor
from .train import DetectionTrainer
from .val import DetectionValidator
__all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val'
__all__ = 'DetectionPredictor', 'DetectionTrainer', 'DetectionValidator'

View File

@ -4,10 +4,23 @@ import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ASSETS, DEFAULT_CFG, ops
from ultralytics.utils import ops
class DetectionPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on a detection model.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.detect import DetectionPredictor
args = dict(model='yolov8n.pt', source=ASSETS)
predictor = DetectionPredictor(overrides=args)
predictor.predict_cli()
```
"""
def postprocess(self, preds, img, orig_imgs):
"""Post-processes predictions and returns a list of Results objects."""
@ -27,21 +40,3 @@ class DetectionPredictor(BasePredictor):
img_path = path[i] if isinstance(path, list) else path
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
return results
def predict(cfg=DEFAULT_CFG, use_python=False):
"""Runs YOLO model inference on input image(s)."""
model = cfg.model or 'yolov8n.pt'
source = cfg.source or ASSETS
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = DetectionPredictor(overrides=args)
predictor.predict_cli()
if __name__ == '__main__':
predict()

View File

@ -8,12 +8,24 @@ from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils import LOGGER, RANK
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
class DetectionTrainer(BaseTrainer):
"""
A class extending the BaseTrainer class for training based on a detection model.
Example:
```python
from ultralytics.models.yolo.detect import DetectionTrainer
args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
trainer = DetectionTrainer(overrides=args)
trainer.train()
```
"""
def build_dataset(self, img_path, mode='train', batch=None):
"""
@ -102,22 +114,3 @@ class DetectionTrainer(BaseTrainer):
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train and optimize YOLO model given training data and device."""
model = cfg.model or 'yolov8n.pt'
data = cfg.data or 'coco8.yaml' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = DetectionTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()

View File

@ -8,7 +8,7 @@ import torch
from ultralytics.data import build_dataloader, build_yolo_dataset, converter
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.utils.plotting import output_to_target, plot_images
@ -16,6 +16,18 @@ from ultralytics.utils.torch_utils import de_parallel
class DetectionValidator(BaseValidator):
"""
A class extending the BaseValidator class for validation based on a detection model.
Example:
```python
from ultralytics.models.yolo.detect import DetectionValidator
args = dict(model='yolov8n.pt', data='coco8.yaml')
validator = DetectionValidator(args=args)
validator(model=args['model'])
```
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize detection model with necessary variables and settings."""
@ -254,21 +266,3 @@ class DetectionValidator(BaseValidator):
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats
def val(cfg=DEFAULT_CFG, use_python=False):
"""Validate trained YOLO model on validation dataset."""
model = cfg.model or 'yolov8n.pt'
data = cfg.data or 'coco8.yaml'
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)
else:
validator = DetectionValidator(args=args)
validator(model=args['model'])
if __name__ == '__main__':
val()

View File

@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .predict import PosePredictor, predict
from .train import PoseTrainer, train
from .val import PoseValidator, val
from .predict import PosePredictor
from .train import PoseTrainer
from .val import PoseValidator
__all__ = 'PoseTrainer', 'train', 'PoseValidator', 'val', 'PosePredictor', 'predict'
__all__ = 'PoseTrainer', 'PoseValidator', 'PosePredictor'

View File

@ -2,10 +2,23 @@
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import ASSETS, DEFAULT_CFG, LOGGER, ops
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
class PosePredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on a pose model.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.pose import PosePredictor
args = dict(model='yolov8n-pose.pt', source=ASSETS)
predictor = PosePredictor(overrides=args)
predictor.predict_cli()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
@ -40,21 +53,3 @@ class PosePredictor(DetectionPredictor):
boxes=pred[:, :6],
keypoints=pred_kpts))
return results
def predict(cfg=DEFAULT_CFG, use_python=False):
"""Runs YOLO to predict objects in an image or video."""
model = cfg.model or 'yolov8n-pose.pt'
source = cfg.source or ASSETS
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = PosePredictor(overrides=args)
predictor.predict_cli()
if __name__ == '__main__':
predict()

View File

@ -9,6 +9,18 @@ from ultralytics.utils.plotting import plot_images, plot_results
class PoseTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on a pose model.
Example:
```python
from ultralytics.models.yolo.pose import PoseTrainer
args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
trainer = PoseTrainer(overrides=args)
trainer.train()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a PoseTrainer object with specified configurations and overrides."""
@ -59,22 +71,3 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train the YOLO model on the given data and device."""
model = cfg.model or 'yolov8n-pose.yaml'
data = cfg.data or 'coco8-pose.yaml'
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = PoseTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()

View File

@ -6,13 +6,25 @@ import numpy as np
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
from ultralytics.utils.plotting import output_to_target, plot_images
class PoseValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on a pose model.
Example:
```python
from ultralytics.models.yolo.pose import PoseValidator
args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
validator = PoseValidator(args=args)
validator(model=args['model'])
```
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
@ -201,21 +213,3 @@ class PoseValidator(DetectionValidator):
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats
def val(cfg=DEFAULT_CFG, use_python=False):
"""Performs validation on YOLO model using given data."""
model = cfg.model or 'yolov8n-pose.pt'
data = cfg.data or 'coco8-pose.yaml'
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)
else:
validator = PoseValidator(args=args)
validator(model=args['model'])
if __name__ == '__main__':
val()

View File

@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .predict import SegmentationPredictor, predict
from .train import SegmentationTrainer, train
from .val import SegmentationValidator, val
from .predict import SegmentationPredictor
from .train import SegmentationTrainer
from .val import SegmentationValidator
__all__ = 'SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val'
__all__ = 'SegmentationPredictor', 'SegmentationTrainer', 'SegmentationValidator'

View File

@ -4,10 +4,23 @@ import torch
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import ASSETS, DEFAULT_CFG, ops
from ultralytics.utils import DEFAULT_CFG, ops
class SegmentationPredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on a segmentation model.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.segment import SegmentationPredictor
args = dict(model='yolov8n-seg.pt', source=ASSETS)
predictor = SegmentationPredictor(overrides=args)
predictor.predict_cli()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
@ -42,21 +55,3 @@ class SegmentationPredictor(DetectionPredictor):
results.append(
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
return results
def predict(cfg=DEFAULT_CFG, use_python=False):
"""Runs YOLO object detection on an image or video source."""
model = cfg.model or 'yolov8n-seg.pt'
source = cfg.source or ASSETS
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = SegmentationPredictor(overrides=args)
predictor.predict_cli()
if __name__ == '__main__':
predict()

View File

@ -9,6 +9,18 @@ from ultralytics.utils.plotting import plot_images, plot_results
class SegmentationTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on a segmentation model.
Example:
```python
from ultralytics.models.yolo.segment import SegmentationTrainer
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
trainer = SegmentationTrainer(overrides=args)
trainer.train()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a SegmentationTrainer object with given arguments."""
@ -46,19 +58,11 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
def train(cfg=DEFAULT_CFG, use_python=False):
def train(cfg=DEFAULT_CFG):
"""Train a YOLO segmentation model based on passed arguments."""
model = cfg.model or 'yolov8n-seg.pt'
data = cfg.data or 'coco8-seg.yaml'
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = SegmentationTrainer(overrides=args)
trainer.train()
args = dict(model=cfg.model or 'yolov8n-seg.pt', data=cfg.data or 'coco8-seg.yaml')
trainer = SegmentationTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':

View File

@ -15,6 +15,18 @@ from ultralytics.utils.plotting import output_to_target, plot_images
class SegmentationValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on a segmentation model.
Example:
```python
from ultralytics.models.yolo.segment import SegmentationValidator
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml')
validator = SegmentationValidator(args=args)
validator(model=args['model'])
```
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
@ -233,18 +245,11 @@ class SegmentationValidator(DetectionValidator):
return stats
def val(cfg=DEFAULT_CFG, use_python=False):
def val(cfg=DEFAULT_CFG):
"""Validate trained YOLO model on validation data."""
model = cfg.model or 'yolov8n-seg.pt'
data = cfg.data or 'coco8-seg.yaml'
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)
else:
validator = SegmentationValidator(args=args)
validator(model=args['model'])
args = dict(model=cfg.model or 'yolov8n-seg.pt', data=cfg.data or 'coco8-seg.yaml')
validator = SegmentationValidator(args=args)
validator(model=args['model'])
if __name__ == '__main__':