ultralytics 8.0.99 HUB resume fix and Docs updates (#2567)

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Author: Glenn Jocher (committed by GitHub)
Date: 2023-05-12 18:33:32 +02:00
Parent: 229119c376
Commit: db1c5885d5
19 changed files with 486 additions and 52 deletions

ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = '8.0.98'
+__version__ = '8.0.99'
 from ultralytics.hub import start
 from ultralytics.vit.rtdetr import RTDETR

ultralytics/vit/rtdetr/model.py

@@ -10,6 +10,7 @@ from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.engine.exporter import Exporter
 from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT
 from ultralytics.yolo.utils.checks import check_imgsz
+from ultralytics.yolo.utils.torch_utils import model_info
 from ...yolo.utils.torch_utils import smart_inference_mode
 from .predict import RTDETRPredictor

@@ -84,6 +85,10 @@ class RTDETR:
         self.metrics = validator.metrics
         return validator.metrics
 
+    def info(self, verbose=True):
+        """Get model info"""
+        return model_info(self.model, verbose=verbose)
+
     @smart_inference_mode()
     def export(self, **kwargs):
         """

ultralytics/yolo/engine/trainer.py

@@ -190,6 +190,17 @@ class BaseTrainer:
         else:
             self._do_train(world_size)
 
+    def _pre_caching_dataset(self):
+        """
+        Caching dataset before training to avoid NCCL timeout.
+        Must be done before DDP initialization.
+        See https://github.com/ultralytics/ultralytics/pull/2549 for details.
+        """
+        if RANK in (-1, 0):
+            LOGGER.info('Pre-caching dataset to avoid NCCL timeout')
+            self.get_dataloader(self.trainset, batch_size=1, rank=RANK, mode='train')
+            self.get_dataloader(self.testset, batch_size=1, rank=-1, mode='val')
+
     def _setup_ddp(self, world_size):
         """Initializes and sets the DistributedDataParallel parameters for training."""
         torch.cuda.set_device(RANK)
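Why the pre-caching step exists: once the process group is initialized, NCCL collectives run under a watchdog timeout (30 minutes by default in PyTorch), so if rank 0 spends longer than that downloading and caching the dataset while the other ranks wait, DDP kills the run (see PR #2549 linked in the docstring). A minimal sketch of the ordering this enforces, with a hypothetical cache_dataset() helper standing in for the trainer's real dataloader calls:

    import os
    import torch.distributed as dist

    def train(world_size: int):
        rank = int(os.environ.get('RANK', -1))
        if rank in (-1, 0):
            cache_dataset()  # hypothetical helper: download/cache the dataset once
        if world_size > 1:
            # The NCCL watchdog only starts once the process group exists, so the
            # slow single-process caching above cannot trip its timeout.
            dist.init_process_group('nccl', rank=rank, world_size=world_size)
        # ... wrap the model in DistributedDataParallel and train ...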
@@ -263,6 +274,7 @@ class BaseTrainer:
     def _do_train(self, world_size=1):
         """Train completed, evaluate and plot if specified by arguments."""
         if world_size > 1:
+            self._pre_caching_dataset()
             self._setup_ddp(world_size)
         self._setup_train(world_size)
@@ -549,10 +561,15 @@ class BaseTrainer:
         resume = self.args.resume
         if resume:
             try:
-                last = Path(
-                    check_file(resume) if isinstance(resume, (str,
-                                                               Path)) and Path(resume).exists() else get_latest_run())
-                self.args = get_cfg(attempt_load_weights(last).args)
+                exists = isinstance(resume, (str, Path)) and Path(resume).exists()
+                last = Path(check_file(resume) if exists else get_latest_run())
+
+                # Check that resume data YAML exists, otherwise strip to force re-download of dataset
+                ckpt_args = attempt_load_weights(last).args
+                if not Path(ckpt_args['data']).exists():
+                    ckpt_args['data'] = self.args.data
+
+                self.args = get_cfg(ckpt_args)
                 self.args.model, resume = str(last), True  # reinstate
             except Exception as e:
                 raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
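This block is the HUB resume fix from the commit title: a checkpoint trained elsewhere (e.g. on Ultralytics HUB) records a dataset YAML path that may not exist on the resuming machine, and previously get_cfg(attempt_load_weights(last).args) kept that stale path, so the resumed run failed when loading the dataset. Falling back to self.args.data forces a fresh download instead. A sketch of the user-facing flow, with an illustrative run path:

    from ultralytics import YOLO

    # Resume a run whose checkpoint recorded a 'data' YAML that is absent here;
    # with this fix the trainer substitutes the current args.data instead of
    # failing on the missing file.
    model = YOLO('runs/detect/train/weights/last.pt')  # illustrative path
    model.train(resume=True)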