8.0.60
new HUB training syntax (#1753)
Co-authored-by: Rafael Pierre <97888102+rafaelvp-db@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Semih Demirel <85176438+semihhdemirel@users.noreply.github.com>
This commit is contained in:
@ -68,12 +68,14 @@ class YOLO:
|
||||
list(ultralytics.yolo.engine.results.Results): The prediction results.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None, session=None) -> None:
|
||||
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
|
||||
"""
|
||||
Initializes the YOLO model.
|
||||
|
||||
Args:
|
||||
model (str, Path): model to load or create
|
||||
model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
|
||||
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
||||
|
||||
"""
|
||||
self._reset_callbacks()
|
||||
self.predictor = None # reuse predictor
|
||||
@ -85,10 +87,16 @@ class YOLO:
|
||||
self.ckpt_path = None
|
||||
self.overrides = {} # overrides for trainer object
|
||||
self.metrics = None # validation/training metrics
|
||||
self.session = session # HUB session
|
||||
self.session = None # HUB session
|
||||
model = str(model).strip() # strip spaces
|
||||
|
||||
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
||||
if model.startswith('https://hub.ultralytics.com/models/'):
|
||||
from ultralytics.hub import HUBTrainingSession
|
||||
self.session = HUBTrainingSession(model)
|
||||
model = self.session.model_file
|
||||
|
||||
# Load or create new YOLO model
|
||||
model = str(model).strip() # strip spaces
|
||||
suffix = Path(model).suffix
|
||||
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
|
||||
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
@ -280,6 +288,7 @@ class YOLO:
|
||||
from ultralytics.yolo.utils.benchmarks import benchmark
|
||||
overrides = self.model.args.copy()
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'benchmark'
|
||||
overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
|
||||
return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
|
||||
|
||||
@ -293,6 +302,7 @@ class YOLO:
|
||||
self._check_is_pytorch_model()
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'export'
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.task = self.task
|
||||
if args.imgsz == DEFAULT_CFG.imgsz:
|
||||
@ -309,6 +319,11 @@ class YOLO:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
if self.session: # Ultralytics HUB session
|
||||
if any(kwargs):
|
||||
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
|
||||
kwargs = self.session.train_args
|
||||
self.session.check_disk_space()
|
||||
check_pip_update_available()
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
|
@ -277,6 +277,8 @@ class Masks(SimpleClass):
|
||||
self.masks = masks # N, h, w
|
||||
self.orig_shape = orig_shape
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def segments(self):
|
||||
# Segments-deprecated (normalized)
|
||||
LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and "
|
||||
|
Reference in New Issue
Block a user