ultralytics 8.0.158
add benchmarks to coverage (#4432)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
This commit is contained in:
@ -9,6 +9,19 @@ from ultralytics.utils import ops
|
||||
|
||||
|
||||
class RTDETRPredictor(BasePredictor):
|
||||
"""
|
||||
A class extending the BasePredictor class for prediction based on an RT-DETR detection model.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.utils import ASSETS
|
||||
from ultralytics.models.rtdetr import RTDETRPredictor
|
||||
|
||||
args = dict(model='rtdetr-l.pt', source=ASSETS)
|
||||
predictor = RTDETRPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
```
|
||||
"""
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Postprocess predictions and returns a list of Results objects."""
|
||||
@ -38,7 +51,9 @@ class RTDETRPredictor(BasePredictor):
|
||||
Args:
|
||||
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
|
||||
|
||||
Return: A list of transformed imgs.
|
||||
Notes: The size must be square(640) and scaleFilled.
|
||||
|
||||
Returns:
|
||||
(list): A list of transformed imgs.
|
||||
"""
|
||||
# The size must be square(640) and scaleFilled.
|
||||
return [LetterBox(self.imgsz, auto=False, scaleFill=True)(image=x) for x in im]
|
||||
|
@ -6,12 +6,28 @@ import torch
|
||||
|
||||
from ultralytics.models.yolo.detect import DetectionTrainer
|
||||
from ultralytics.nn.tasks import RTDETRDetectionModel
|
||||
from ultralytics.utils import DEFAULT_CFG, RANK, colorstr
|
||||
from ultralytics.utils import RANK, colorstr
|
||||
|
||||
from .val import RTDETRDataset, RTDETRValidator
|
||||
|
||||
|
||||
class RTDETRTrainer(DetectionTrainer):
|
||||
"""
|
||||
A class extending the DetectionTrainer class for training based on an RT-DETR detection model.
|
||||
|
||||
Notes:
|
||||
- F.grid_sample used in rt-detr does not support the `deterministic=True` argument.
|
||||
- AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.models.rtdetr.train import RTDETRTrainer
|
||||
|
||||
args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
|
||||
trainer = RTDETRTrainer(overrides=args)
|
||||
trainer.train()
|
||||
```
|
||||
"""
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Return a YOLO detection model."""
|
||||
@ -54,27 +70,3 @@ class RTDETRTrainer(DetectionTrainer):
|
||||
gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
|
||||
gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
|
||||
return batch
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Train and optimize RTDETR model given training data and device."""
|
||||
model = 'rtdetr-l.yaml'
|
||||
data = cfg.data or 'coco8.yaml' # or yolo.ClassificationDataset("mnist")
|
||||
device = cfg.device if cfg.device is not None else ''
|
||||
|
||||
# NOTE: F.grid_sample which is in rt-detr does not support deterministic=True
|
||||
# NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching
|
||||
args = dict(model=model,
|
||||
data=data,
|
||||
device=device,
|
||||
imgsz=640,
|
||||
exist_ok=True,
|
||||
batch=4,
|
||||
deterministic=False,
|
||||
amp=False)
|
||||
trainer = RTDETRTrainer(overrides=args)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
|
@ -67,6 +67,18 @@ class RTDETRDataset(YOLODataset):
|
||||
|
||||
|
||||
class RTDETRValidator(DetectionValidator):
|
||||
"""
|
||||
A class extending the DetectionValidator class for validation based on an RT-DETR detection model.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.models.rtdetr import RTDETRValidator
|
||||
|
||||
args = dict(model='rtdetr-l.pt', data='coco8.yaml')
|
||||
validator = RTDETRValidator(args=args)
|
||||
validator(model=args['model'])
|
||||
```
|
||||
"""
|
||||
|
||||
def build_dataset(self, img_path, mode='val', batch=None):
|
||||
"""Build YOLO Dataset
|
||||
|
Reference in New Issue
Block a user