ultralytics 8.0.158 add benchmarks to coverage (#4432)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
This commit is contained in:
Glenn Jocher
2023-08-20 20:52:30 +02:00
committed by GitHub
parent 495806565d
commit 87ce15d383
51 changed files with 352 additions and 482 deletions

View File

@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .predict import SegmentationPredictor, predict
from .train import SegmentationTrainer, train
from .val import SegmentationValidator, val
from .predict import SegmentationPredictor
from .train import SegmentationTrainer
from .val import SegmentationValidator
__all__ = 'SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val'
__all__ = 'SegmentationPredictor', 'SegmentationTrainer', 'SegmentationValidator'

View File

@ -4,10 +4,23 @@ import torch
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import ASSETS, DEFAULT_CFG, ops
from ultralytics.utils import DEFAULT_CFG, ops
class SegmentationPredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on a segmentation model.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.segment import SegmentationPredictor
args = dict(model='yolov8n-seg.pt', source=ASSETS)
predictor = SegmentationPredictor(overrides=args)
predictor.predict_cli()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
@ -42,21 +55,3 @@ class SegmentationPredictor(DetectionPredictor):
results.append(
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
return results
def predict(cfg=DEFAULT_CFG, use_python=False):
"""Runs YOLO object detection on an image or video source."""
model = cfg.model or 'yolov8n-seg.pt'
source = cfg.source or ASSETS
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = SegmentationPredictor(overrides=args)
predictor.predict_cli()
if __name__ == '__main__':
predict()

View File

@ -9,6 +9,18 @@ from ultralytics.utils.plotting import plot_images, plot_results
class SegmentationTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on a segmentation model.
Example:
```python
from ultralytics.models.yolo.segment import SegmentationTrainer
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
trainer = SegmentationTrainer(overrides=args)
trainer.train()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a SegmentationTrainer object with given arguments."""
@ -46,19 +58,11 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
def train(cfg=DEFAULT_CFG, use_python=False):
def train(cfg=DEFAULT_CFG):
"""Train a YOLO segmentation model based on passed arguments."""
model = cfg.model or 'yolov8n-seg.pt'
data = cfg.data or 'coco8-seg.yaml'
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = SegmentationTrainer(overrides=args)
trainer.train()
args = dict(model=cfg.model or 'yolov8n-seg.pt', data=cfg.data or 'coco8-seg.yaml')
trainer = SegmentationTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':

View File

@ -15,6 +15,18 @@ from ultralytics.utils.plotting import output_to_target, plot_images
class SegmentationValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on a segmentation model.
Example:
```python
from ultralytics.models.yolo.segment import SegmentationValidator
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml')
validator = SegmentationValidator(args=args)
validator(model=args['model'])
```
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
@ -233,18 +245,11 @@ class SegmentationValidator(DetectionValidator):
return stats
def val(cfg=DEFAULT_CFG, use_python=False):
def val(cfg=DEFAULT_CFG):
"""Validate trained YOLO model on validation data."""
model = cfg.model or 'yolov8n-seg.pt'
data = cfg.data or 'coco8-seg.yaml'
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)
else:
validator = SegmentationValidator(args=args)
validator(model=args['model'])
args = dict(model=cfg.model or 'yolov8n-seg.pt', data=cfg.data or 'coco8-seg.yaml')
validator = SegmentationValidator(args=args)
validator(model=args['model'])
if __name__ == '__main__':