ImageNet names, classify inference, resume fixes (#712)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-30 22:34:28 +01:00
committed by GitHub
parent aecd17d455
commit 522f1937ed
16 changed files with 1121 additions and 115 deletions

View File

@ -2,6 +2,7 @@
from pathlib import Path
import sys
from ultralytics import yolo # noqa
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
guess_model_task)
@ -142,7 +143,8 @@ class YOLO:
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source=source, stream=stream)
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
@smart_inference_mode()
def val(self, data=None, **kwargs):

View File

@ -113,9 +113,9 @@ class BasePredictor:
else:
return list(self.stream_inference(source, model)) # merge list of Result into one
def predict_cli(self):
def predict_cli(self, source=None, model=None):
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
gen = self.stream_inference()
gen = self.stream_inference(source, model)
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
pass

View File

@ -28,7 +28,7 @@ class Results:
def __init__(self, boxes=None, masks=None, probs=None, orig_shape=None) -> None:
self.boxes = Boxes(boxes, orig_shape) if boxes is not None else None # native size boxes
self.masks = Masks(masks, orig_shape) if masks is not None else None # native size or imgsz masks
self.probs = probs.softmax(0) if probs is not None else None
self.probs = probs if probs is not None else None
self.orig_shape = orig_shape
self.comp = ["boxes", "masks", "probs"]

View File

@ -21,7 +21,7 @@ from torch.optim import lr_scheduler
from tqdm import tqdm
from ultralytics import __version__
from ultralytics.nn.tasks import attempt_load_one_weight
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis,
@ -515,14 +515,15 @@ class BaseTrainer:
def check_resume(self):
resume = self.args.resume
if resume:
last = Path(check_file(resume) if isinstance(resume, (str, Path)) else get_latest_run())
args_yaml = last.parent.parent / 'args.yaml' # train options yaml
assert args_yaml.is_file(), \
FileNotFoundError(f'Resume checkpoint {last} not found. '
'Please pass a valid checkpoint to resume from, i.e. yolo resume=path/to/last.pt')
args = get_cfg(args_yaml) # replace
args.model, resume = str(last), True # reinstate
self.args = args
try:
last = Path(
check_file(resume) if isinstance(resume, (str,
Path)) and Path(resume).exists() else get_latest_run())
self.args = get_cfg(attempt_load_weights(last).args)
self.args.model, resume = str(last), True # reinstate
except Exception as e:
raise FileNotFoundError("Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
"i.e. 'yolo train resume model=path/to/last.pt'") from e
self.resume = resume
def resume_training(self, ckpt):
@ -541,7 +542,7 @@ class BaseTrainer:
f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
LOGGER.info(
f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
if self.epochs < start_epoch:
LOGGER.info(
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")