`ultralytics 8.0.29` DDP-cls and default arg fixes (#813)

single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 21ae321bc2
commit 7a7c8dc7b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = "8.0.28" __version__ = "8.0.29"
from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import ops from ultralytics.yolo.utils import ops

@ -262,8 +262,8 @@ def entrypoint(debug=''):
LOGGER.warning(f"WARNING ⚠️ 'format=' is missing. Using default 'format={overrides['format']}'.") LOGGER.warning(f"WARNING ⚠️ 'format=' is missing. Using default 'format={overrides['format']}'.")
# Run command in python # Run command in python
cfg = get_cfg(overrides=overrides) # getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml
getattr(model, mode)(**vars(cfg)) getattr(model, mode)(**overrides) # default args from model
# Special modes -------------------------------------------------------------------------------------------------------- # Special modes --------------------------------------------------------------------------------------------------------

@ -184,9 +184,6 @@ class Exporter:
y = model(im) # dry runs y = model(im) # dry runs
if self.args.half and not coreml and not xml: if self.args.half and not coreml and not xml:
im, model = im.half(), model.half() # to FP16 im, model = im.half(), model.half() # to FP16
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} and "
f"output shape {shape} ({file_size(file):.1f} MB)")
# Warnings # Warnings
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
@ -207,6 +204,9 @@ class Exporter:
'stride': int(max(model.stride)), 'stride': int(max(model.stride)),
'names': model.names} # model metadata 'names': model.names} # model metadata
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} and "
f"output shape {self.output_shape} ({file_size(file):.1f} MB)")
# Exports # Exports
f = [''] * len(fmts) # exported filenames f = [''] * len(fmts) # exported filenames
if jit: # TorchScript if jit: # TorchScript
@ -220,9 +220,8 @@ class Exporter:
if coreml: # CoreML if coreml: # CoreML
f[4], _ = self._export_coreml() f[4], _ = self._export_coreml()
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
raise NotImplementedError('YOLOv8 TensorFlow export support is still under development. ' LOGGER.warning('WARNING ⚠️ YOLOv8 TensorFlow export support is still under development. '
'Please consider contributing to the effort if you have TF expertise. Thank you!') 'Please consider contributing to the effort if you have TF expertise. Thank you!')
assert not isinstance(model, ClassificationModel), 'ClassificationModel TF exports not yet supported.'
nms = False nms = False
f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs, f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs,
agnostic_nms=self.args.agnostic_nms or tfjs) agnostic_nms=self.args.agnostic_nms or tfjs)
@ -236,7 +235,7 @@ class Exporter:
agnostic_nms=self.args.agnostic_nms) agnostic_nms=self.args.agnostic_nms)
if edgetpu: if edgetpu:
f[8], _ = self._export_edgetpu() f[8], _ = self._export_edgetpu()
self._add_tflite_metadata(f[8] or f[7], num_outputs=len(s_model.outputs)) self._add_tflite_metadata(f[8] or f[7], num_outputs=len(self.output_shape))
if tfjs: if tfjs:
f[9], _ = self._export_tfjs() f[9], _ = self._export_tfjs()
if paddle: # PaddlePaddle if paddle: # PaddlePaddle
@ -552,13 +551,13 @@ class Exporter:
return f, keras_model return f, keras_model
@try_export @try_export
def _export_pb(self, keras_model, file, prefix=colorstr('TensorFlow GraphDef:')): def _export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
# YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow # YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
import tensorflow as tf # noqa import tensorflow as tf # noqa
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = file.with_suffix('.pb') f = self.file.with_suffix('.pb')
m = tf.function(lambda x: keras_model(x)) # full model m = tf.function(lambda x: keras_model(x)) # full model
m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)) m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))

@ -119,7 +119,6 @@ class YOLO:
def fuse(self): def fuse(self):
self.model.fuse() self.model.fuse()
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs): def predict(self, source=None, stream=False, **kwargs):
""" """
Perform prediction using the YOLO model. Perform prediction using the YOLO model.
@ -258,8 +257,6 @@ class YOLO:
@staticmethod @staticmethod
def _reset_ckpt_args(args): def _reset_ckpt_args(args):
for arg in 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', 'save_json', \ for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
'half', 'v5loader': 'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots':
args.pop(arg, None) args.pop(arg, None)
args["device"] = '' # set device to '' to prevent auto-DDP usage

@ -457,7 +457,7 @@ class BaseTrainer:
def get_validator(self): def get_validator(self):
raise NotImplementedError("get_validator function not implemented in trainer") raise NotImplementedError("get_validator function not implemented in trainer")
def get_dataloader(self, dataset_path, batch_size=16, rank=0): def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
""" """
Returns dataloader derived from torch.data.Dataloader. Returns dataloader derived from torch.data.Dataloader.
""" """

@ -485,18 +485,20 @@ def set_sentry():
if SETTINGS['sync'] and \ if SETTINGS['sync'] and \
RANK in {-1, 0} and \ RANK in {-1, 0} and \
sys.argv[0].endswith('yolo') and \
not is_pytest_running() and \ not is_pytest_running() and \
not is_github_actions_ci() and \ not is_github_actions_ci() and \
((is_pip_package() and not is_git_dir()) or ((is_pip_package() and not is_git_dir()) or
(get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")): (get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")):
import sentry_sdk # noqa import sentry_sdk # noqa
from ultralytics import __version__
import ultralytics
sentry_sdk.init( sentry_sdk.init(
dsn="https://1f331c322109416595df20a91f4005d3@o4504521589325824.ingest.sentry.io/4504521592406016", dsn="https://f805855f03bb4363bc1e16cb7d87b654@o4504521589325824.ingest.sentry.io/4504521592406016",
debug=False, debug=False,
traces_sample_rate=1.0, traces_sample_rate=1.0,
release=ultralytics.__version__, release=__version__,
environment='production', # 'dev' or 'production' environment='production', # 'dev' or 'production'
before_send=before_send, before_send=before_send,
ignore_errors=[KeyboardInterrupt, FileNotFoundError]) ignore_errors=[KeyboardInterrupt, FileNotFoundError])

@ -1,5 +1,5 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
import contextlib
import glob import glob
import inspect import inspect
import math import math
@ -7,9 +7,9 @@ import os
import platform import platform
import re import re
import shutil import shutil
import subprocess
import urllib import urllib
from pathlib import Path from pathlib import Path
from subprocess import check_output
from typing import Optional from typing import Optional
import cv2 import cv2
@ -155,12 +155,11 @@ def check_online() -> bool:
bool: True if connection is successful, False otherwise. bool: True if connection is successful, False otherwise.
""" """
import socket import socket
try: with contextlib.suppress(subprocess.CalledProcessError):
# Check host accessibility by attempting to establish a connection host = socket.gethostbyname("www.github.com")
socket.create_connection(("1.1.1.1", 443), timeout=5) socket.create_connection((host, 80), timeout=2)
return True return True
except OSError: return False
return False
def check_python(minimum: str = '3.7.0') -> bool: def check_python(minimum: str = '3.7.0') -> bool:
@ -181,6 +180,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages or single package str) # Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages or single package str)
prefix = colorstr('red', 'bold', 'requirements:') prefix = colorstr('red', 'bold', 'requirements:')
check_python() # check python version check_python() # check python version
file = None
if isinstance(requirements, Path): # requirements.txt file if isinstance(requirements, Path): # requirements.txt file
file = requirements.resolve() file = requirements.resolve()
assert file.exists(), f"{prefix} {file} not found, check failed." assert file.exists(), f"{prefix} {file} not found, check failed."
@ -202,9 +202,8 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...") LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
try: try:
assert check_online(), "AutoUpdate skipped (offline)" assert check_online(), "AutoUpdate skipped (offline)"
LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode()) LOGGER.info(subprocess.check_output(f'pip install {s} {cmds}', shell=True).decode())
source = file if 'file' in locals() else requirements s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n" f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
LOGGER.info(s) LOGGER.info(s)
except Exception as e: except Exception as e:
@ -306,7 +305,7 @@ def git_describe(path=ROOT): # path must be a directory
# Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
try: try:
assert (Path(path) / '.git').is_dir() assert (Path(path) / '.git').is_dir()
return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1] return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
except AssertionError: except AssertionError:
return '' return ''

@ -246,7 +246,7 @@ def intersect_dicts(da, db, exclude=()):
def is_parallel(model): def is_parallel(model):
# Returns True if model is of type DP or DDP # Returns True if model is of type DP or DDP
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
def de_parallel(model): def de_parallel(model):

@ -1,5 +1,4 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
import sys
import torch import torch
import torchvision import torchvision
@ -9,7 +8,7 @@ from ultralytics.yolo import v8
from ultralytics.yolo.data import build_classification_dataloader from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.trainer import BaseTrainer from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import DEFAULT_CFG from ultralytics.yolo.utils import DEFAULT_CFG
from ultralytics.yolo.utils.torch_utils import strip_optimizer from ultralytics.yolo.utils.torch_utils import strip_optimizer, is_parallel
class ClassificationTrainer(BaseTrainer): class ClassificationTrainer(BaseTrainer):
@ -56,7 +55,7 @@ class ClassificationTrainer(BaseTrainer):
# Load a YOLO model locally, from torchvision, or from Ultralytics assets # Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith(".pt"): if model.endswith(".pt"):
self.model, _ = attempt_load_one_weight(model, device='cpu') self.model, _ = attempt_load_one_weight(model, device='cpu')
for p in model.parameters(): for p in self.model.parameters():
p.requires_grad = True # for training p.requires_grad = True # for training
elif model.endswith(".yaml"): elif model.endswith(".yaml"):
self.model = self.get_model(cfg=model) self.model = self.get_model(cfg=model)
@ -75,8 +74,12 @@ class ClassificationTrainer(BaseTrainer):
augment=mode == "train", augment=mode == "train",
rank=rank, rank=rank,
workers=self.args.workers) workers=self.args.workers)
# Attach inference transforms
if mode != "train": if mode != "train":
self.model.transforms = loader.dataset.torch_transforms # attach inference transforms if is_parallel(self.model):
self.model.module.transforms = loader.dataset.torch_transforms
else:
self.model.transforms = loader.dataset.torch_transforms
return loader return loader
def preprocess_batch(self, batch): def preprocess_batch(self, batch):

Loading…
Cancel
Save