ImageNet names, classify inference, resume fixes (#712)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-30 22:34:28 +01:00
committed by GitHub
parent aecd17d455
commit 522f1937ed
16 changed files with 1121 additions and 115 deletions

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = "8.0.24"
__version__ = "8.0.25"
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import ops

View File

@ -226,10 +226,13 @@ class AutoBackend(nn.Module):
f"https://docs.ultralytics.com/reference/nn/")
# class names
if 'names' not in locals():
names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}
if names[0] == 'n01440764' and len(names) == 1000: # ImageNet
names = yaml_load(ROOT / 'yolo/data/datasets/ImageNet.yaml')['names'] # human-readable names
if 'names' not in locals(): # names missing
names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)} # assign default
elif isinstance(names, list): # names is a list
names = dict(enumerate(names)) # convert to dict
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
map = yaml_load(ROOT / 'yolo/data/datasets/ImageNet.yaml')['map'] # human-readable names
names = {k: map[v] for k, v in names.items()}
self.__dict__.update(locals()) # assign all variables to self

View File

@ -162,7 +162,7 @@ def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1,
"""
# source
source, webcam, screenshot, from_img, in_memory = check_source(source)
source_type = SourceTypes(webcam, screenshot, from_img) if not in_memory else source.source_type
source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img)
# Dataloader
if in_memory:

View File

@ -29,7 +29,7 @@ class SourceTypes:
class LoadStreams:
# YOLOv8 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
# YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.mode = 'stream'
@ -49,9 +49,9 @@ class LoadStreams:
import pafy # noqa
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
if s == 0:
assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
if s == 0 and (is_colab() or is_kaggle()):
raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks."
"Try running 'source=0' in a local environment.")
cap = cv2.VideoCapture(s)
if not cap.isOpened():
raise ConnectionError(f'{st}Failed to open {s}')
@ -118,7 +118,7 @@ class LoadStreams:
class LoadScreenshots:
# YOLOv8 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
# YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`
def __init__(self, source, imgsz=640, stride=32, auto=True, transforms=None):
# source = [screen_number left top width height] (pixels)
check_requirements('mss')
@ -168,7 +168,7 @@ class LoadScreenshots:
class LoadImages:
# YOLOv8 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
# YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
path = Path(path).read_text().rsplit()

File diff suppressed because it is too large Load Diff

View File

@ -2,6 +2,7 @@
from pathlib import Path
import sys
from ultralytics import yolo # noqa
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
guess_model_task)
@ -142,7 +143,8 @@ class YOLO:
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source=source, stream=stream)
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
@smart_inference_mode()
def val(self, data=None, **kwargs):

View File

@ -113,9 +113,9 @@ class BasePredictor:
else:
return list(self.stream_inference(source, model)) # merge list of Result into one
def predict_cli(self):
def predict_cli(self, source=None, model=None):
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
gen = self.stream_inference()
gen = self.stream_inference(source, model)
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
pass

View File

@ -28,7 +28,7 @@ class Results:
def __init__(self, boxes=None, masks=None, probs=None, orig_shape=None) -> None:
self.boxes = Boxes(boxes, orig_shape) if boxes is not None else None # native size boxes
self.masks = Masks(masks, orig_shape) if masks is not None else None # native size or imgsz masks
self.probs = probs.softmax(0) if probs is not None else None
self.probs = probs if probs is not None else None
self.orig_shape = orig_shape
self.comp = ["boxes", "masks", "probs"]

View File

@ -21,7 +21,7 @@ from torch.optim import lr_scheduler
from tqdm import tqdm
from ultralytics import __version__
from ultralytics.nn.tasks import attempt_load_one_weight
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis,
@ -515,14 +515,15 @@ class BaseTrainer:
def check_resume(self):
resume = self.args.resume
if resume:
last = Path(check_file(resume) if isinstance(resume, (str, Path)) else get_latest_run())
args_yaml = last.parent.parent / 'args.yaml' # train options yaml
assert args_yaml.is_file(), \
FileNotFoundError(f'Resume checkpoint {last} not found. '
'Please pass a valid checkpoint to resume from, i.e. yolo resume=path/to/last.pt')
args = get_cfg(args_yaml) # replace
args.model, resume = str(last), True # reinstate
self.args = args
try:
last = Path(
check_file(resume) if isinstance(resume, (str,
Path)) and Path(resume).exists() else get_latest_run())
self.args = get_cfg(attempt_load_weights(last).args)
self.args.model, resume = str(last), True # reinstate
except Exception as e:
raise FileNotFoundError("Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
"i.e. 'yolo train resume model=path/to/last.pt'") from e
self.resume = resume
def resume_training(self, ckpt):
@ -541,7 +542,7 @@ class BaseTrainer:
f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
LOGGER.info(
f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
if self.epochs < start_epoch:
LOGGER.info(
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")

View File

@ -479,7 +479,7 @@ def set_sentry():
if SETTINGS['sync'] and \
not is_pytest_running() and \
not is_github_actions_ci() and \
(is_pip_package() or
((is_pip_package() and not is_git_dir()) or
(get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")):
import sentry_sdk # noqa
@ -493,6 +493,10 @@ def set_sentry():
before_send=before_send,
ignore_errors=[KeyboardInterrupt])
# Disable all sentry logging
for logger in "sentry_sdk", "sentry_sdk.errors":
logging.getLogger(logger).setLevel(logging.CRITICAL)
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
"""

View File

@ -52,21 +52,22 @@ def autobatch(model, imgsz=640, fraction=0.7, batch_size=16):
try:
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
results = profile(img, model, n=3, device=device)
# Fit a solution
y = [x[2] for x in results if x] # memory [2]
p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
if None in results: # some sizes failed
i = results.index(None) # first fail index
if b >= batch_sizes[i]: # y intercept above failure point
b = batch_sizes[max(i - 1, 0)] # select prior safe point
if b < 1 or b > 1024: # b outside of safe range
b = batch_size
LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.')
fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
return b
except Exception as e:
LOGGER.warning(f'{prefix}{e}')
# Fit a solution
y = [x[2] for x in results if x] # memory [2]
p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
if None in results: # some sizes failed
i = results.index(None) # first fail index
if b >= batch_sizes[i]: # y intercept above failure point
b = batch_sizes[max(i - 1, 0)] # select prior safe point
if b < 1 or b > 1024: # b outside of safe range
b = batch_size
LOGGER.warning(f'{prefix}WARNING ⚠️ CUDA anomaly detected, recommend restart environment and retry command.')
fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
return b
LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.')
return batch_size

View File

@ -41,7 +41,7 @@ class DetectionPredictor(BasePredictor):
if len(im.shape) == 3:
im = im[None] # expand for batch dim
self.seen += 1
im0 = im0.copy()
imc = im0.copy() if self.args.save_crop else im0
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
log_string += f'{idx}: '
frame = self.dataset.count
@ -73,7 +73,6 @@ class DetectionPredictor(BasePredictor):
self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
if self.args.save_crop:
imc = im0.copy()
save_one_box(d.xyxy,
imc,
file=self.save_dir / 'crops' / self.model.model.names[c] / f'{self.data_path.stem}.jpg',

View File

@ -43,6 +43,7 @@ class SegmentationPredictor(DetectionPredictor):
if len(im.shape) == 3:
im = im[None] # expand for batch dim
self.seen += 1
imc = im0.copy() if self.args.save_crop else im0
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
log_string += f'{idx}: '
frame = self.dataset.count
@ -91,7 +92,6 @@ class SegmentationPredictor(DetectionPredictor):
self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.boxes else None
if self.args.save_crop:
imc = im0.copy()
save_one_box(d.xyxy,
imc,
file=self.save_dir / 'crops' / self.model.model.names[c] / f'{self.data_path.stem}.jpg',