ultralytics 8.0.158 add benchmarks to coverage (#4432)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
This commit is contained in:
Glenn Jocher
2023-08-20 20:52:30 +02:00
committed by GitHub
parent 495806565d
commit 87ce15d383
51 changed files with 352 additions and 482 deletions

View File

@ -9,6 +9,19 @@ from ultralytics.utils import ops
class RTDETRPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on an RT-DETR detection model.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.rtdetr import RTDETRPredictor
args = dict(model='rtdetr-l.pt', source=ASSETS)
predictor = RTDETRPredictor(overrides=args)
predictor.predict_cli()
```
"""
def postprocess(self, preds, img, orig_imgs):
"""Postprocess predictions and returns a list of Results objects."""
@ -38,7 +51,9 @@ class RTDETRPredictor(BasePredictor):
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
Return: A list of transformed imgs.
Notes: The size must be square(640) and scaleFilled.
Returns:
(list): A list of transformed imgs.
"""
# The size must be square(640) and scaleFilled.
return [LetterBox(self.imgsz, auto=False, scaleFill=True)(image=x) for x in im]

View File

@ -6,12 +6,28 @@ import torch
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import DEFAULT_CFG, RANK, colorstr
from ultralytics.utils import RANK, colorstr
from .val import RTDETRDataset, RTDETRValidator
class RTDETRTrainer(DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on an RT-DETR detection model.
Notes:
- F.grid_sample used in rt-detr does not support the `deterministic=True` argument.
- AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
Example:
```python
from ultralytics.models.rtdetr.train import RTDETRTrainer
args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
```
"""
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
@ -54,27 +70,3 @@ class RTDETRTrainer(DetectionTrainer):
gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
return batch
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train and optimize RTDETR model given training data and device."""
model = 'rtdetr-l.yaml'
data = cfg.data or 'coco8.yaml' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
# NOTE: F.grid_sample which is in rt-detr does not support deterministic=True
# NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching
args = dict(model=model,
data=data,
device=device,
imgsz=640,
exist_ok=True,
batch=4,
deterministic=False,
amp=False)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()

View File

@ -67,6 +67,18 @@ class RTDETRDataset(YOLODataset):
class RTDETRValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on an RT-DETR detection model.
Example:
```python
from ultralytics.models.rtdetr import RTDETRValidator
args = dict(model='rtdetr-l.pt', data='coco8.yaml')
validator = RTDETRValidator(args=args)
validator(model=args['model'])
```
"""
def build_dataset(self, img_path, mode='val', batch=None):
"""Build YOLO Dataset