`ultralytics 8.0.31` updates and fixes (#857)

Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kalen Michael <kalenmike@gmail.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 2e7a533ac3
commit f5d003d05a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = "8.0.30" __version__ = "8.0.31"
from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import ops from ultralytics.yolo.utils import ops

@ -4,67 +4,62 @@ import requests
from ultralytics.hub.auth import Auth from ultralytics.hub.auth import Auth
from ultralytics.hub.session import HubTrainingSession from ultralytics.hub.session import HubTrainingSession
from ultralytics.hub.utils import PREFIX, split_key from ultralytics.hub.utils import split_key
from ultralytics.yolo.utils import LOGGER, emojis from ultralytics.yolo.engine.exporter import export_formats
from ultralytics.yolo.v8.detect import DetectionTrainer from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import LOGGER, emojis, PREFIX
# Define all export formats
EXPORT_FORMATS = list(export_formats()['Argument'][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
def start(key=''):
# Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
def request_api_key(attempts=0):
"""Prompt the user to input their API key"""
import getpass
max_attempts = 3
tries = f"Attempt {str(attempts + 1)} of {max_attempts}" if attempts > 0 else ""
LOGGER.info(f"{PREFIX}Login. {tries}")
input_key = getpass.getpass("Enter your Ultralytics HUB API key:\n")
auth.api_key, model_id = split_key(input_key)
if not auth.authenticate():
attempts += 1
LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
if attempts < max_attempts:
return request_api_key(attempts)
raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
else:
return model_id
def start(key=""):
"""
Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
"""
auth = Auth(key)
try: try:
api_key, model_id = split_key(key)
auth = Auth(api_key) # attempts cookie login if no api key is present
attempts = 1 if len(key) else 0
if not auth.get_state(): if not auth.get_state():
if len(key): model_id = request_api_key(auth)
LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n") else:
model_id = request_api_key(attempts) _, model_id = split_key(key)
LOGGER.info(f"{PREFIX}Authenticated ✅")
if not model_id: if not model_id:
raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌')) raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))
session = HubTrainingSession(model_id=model_id, auth=auth) session = HubTrainingSession(model_id=model_id, auth=auth)
session.check_disk_space() session.check_disk_space()
# TODO: refactor, hardcoded for v8 trainer = YOLO(session.input_file)
args = session.model.copy()
args.pop("id")
args.pop("status")
args.pop("weights")
args["data"] = "coco128.yaml"
args["model"] = "yolov8n.yaml"
args["batch_size"] = 16
args["imgsz"] = 64
trainer = DetectionTrainer(overrides=args)
session.register_callbacks(trainer) session.register_callbacks(trainer)
setattr(trainer, 'hub_session', session) trainer.train(**session.train_args)
trainer.train()
except Exception as e: except Exception as e:
LOGGER.warning(f"{PREFIX}{e}") LOGGER.warning(f"{PREFIX}{e}")
def reset_model(key=''): def request_api_key(auth, max_attempts=3):
"""
Prompt the user to input their API key. Returns the model ID.
"""
import getpass
for attempts in range(max_attempts):
LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
input_key = getpass.getpass("Enter your Ultralytics HUB API key:\n")
auth.api_key, model_id = split_key(input_key)
if auth.authenticate():
LOGGER.info(f"{PREFIX}Authenticated ✅")
return model_id
LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
def reset_model(key=""):
# Reset a trained model to an untrained state # Reset a trained model to an untrained state
api_key, model_id = split_key(key) api_key, model_id = split_key(key)
r = requests.post('https://api.ultralytics.com/model-reset', json={"apiKey": api_key, "modelId": model_id}) r = requests.post("https://api.ultralytics.com/model-reset", json={"apiKey": api_key, "modelId": model_id})
if r.status_code == 200: if r.status_code == 200:
LOGGER.info(f"{PREFIX}model reset successfully") LOGGER.info(f"{PREFIX}model reset successfully")
@ -72,38 +67,32 @@ def reset_model(key=''):
LOGGER.warning(f"{PREFIX}model reset failure {r.status_code} {r.reason}") LOGGER.warning(f"{PREFIX}model reset failure {r.status_code} {r.reason}")
def export_model(key='', format='torchscript'): def export_model(key="", format="torchscript"):
# Export a model to all formats # Export a model to all formats
assert format in EXPORT_FORMATS, f"Unsupported export format '{format}' passed, valid formats are {EXPORT_FORMATS}"
api_key, model_id = split_key(key) api_key, model_id = split_key(key)
formats = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs', r = requests.post("https://api.ultralytics.com/export",
'ultralytics_tflite', 'ultralytics_coreml')
assert format in formats, f"ERROR: Unsupported export format '{format}' passed, valid formats are {formats}"
r = requests.post('https://api.ultralytics.com/export',
json={ json={
"apiKey": api_key, "apiKey": api_key,
"modelId": model_id, "modelId": model_id,
"format": format}) "format": format})
assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}" assert (r.status_code == 200), f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
LOGGER.info(f"{PREFIX}{format} export started ✅") LOGGER.info(f"{PREFIX}{format} export started ✅")
def get_export(key='', format='torchscript'): def get_export(key="", format="torchscript"):
# Get an exported model dictionary with download URL # Get an exported model dictionary with download URL
assert format in EXPORT_FORMATS, f"Unsupported export format '{format}' passed, valid formats are {EXPORT_FORMATS}"
api_key, model_id = split_key(key) api_key, model_id = split_key(key)
formats = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs', r = requests.post("https://api.ultralytics.com/get-export",
'ultralytics_tflite', 'ultralytics_coreml')
assert format in formats, f"ERROR: Unsupported export format '{format}' passed, valid formats are {formats}"
r = requests.post('https://api.ultralytics.com/get-export',
json={ json={
"apiKey": api_key, "apiKey": api_key,
"modelId": model_id, "modelId": model_id,
"format": format}) "format": format})
assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}" assert (r.status_code == 200), f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
return r.json() return r.json()
# temp. For checking # temp. For checking
if __name__ == "__main__": if __name__ == "__main__":
start(key="b3fba421be84a20dbe68644e14436d1cce1b0a0aaa_HeMfHgvHsseMPhdq7Ylz") start()

@ -1,16 +1,18 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
import json
import signal import signal
import sys
from pathlib import Path from pathlib import Path
from time import sleep from time import sleep, time
import requests import requests
from ultralytics import __version__ from ultralytics import __version__
from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request
from ultralytics.yolo.utils import is_colab, threaded from ultralytics.yolo.utils import is_colab, threaded, LOGGER, emojis, PREFIX
from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
AGENT_NAME = (f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local")
session = None session = None
@ -19,23 +21,37 @@ class HubTrainingSession:
def __init__(self, model_id, auth): def __init__(self, model_id, auth):
self.agent_id = None # identifies which instance is communicating with server self.agent_id = None # identifies which instance is communicating with server
self.model_id = model_id self.model_id = model_id
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}' self.api_url = f"{HUB_API_ROOT}/v1/models/{model_id}"
self.auth_header = auth.get_auth_header() self.auth_header = auth.get_auth_header()
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds) self._rate_limits = {"metrics": 3.0, "ckpt": 900.0, "heartbeat": 300.0} # rate limits (seconds)
self.t = {} # rate limit timers (seconds) self._timers = {} # rate limit timers (seconds)
self.metrics_queue = {} # metrics queue self._metrics_queue = {} # metrics queue
self.alive = True # for heartbeats
self.model = self._get_model() self.model = self._get_model()
self._heartbeats() # start heartbeats self._start_heartbeat() # start heartbeats
signal.signal(signal.SIGTERM, self.shutdown) # register the shutdown function to be called on exit self._register_signal_handlers()
signal.signal(signal.SIGINT, self.shutdown)
def shutdown(self, *args): # noqa def _register_signal_handlers(self):
self.alive = False # stop heartbeats signal.signal(signal.SIGTERM, self._handle_signal)
signal.signal(signal.SIGINT, self._handle_signal)
def _handle_signal(self, signum, frame):
"""
Prevent heartbeats from being sent on Colab after kill.
This method does not use frame, it is included as it is
passed by signal.
"""
if self.alive is True:
LOGGER.info(f"{PREFIX}Kill signal received! ❌")
self._stop_heartbeat()
sys.exit(signum)
def _stop_heartbeat(self):
"""End the heartbeat loop"""
self.alive = False
def upload_metrics(self): def upload_metrics(self):
payload = {"metrics": self.metrics_queue.copy(), "type": "metrics"} payload = {"metrics": self._metrics_queue.copy(), "type": "metrics"}
smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2) smart_request(f"{self.api_url}", json=payload, headers=self.auth_header, code=2)
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False): def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
# Upload a model to HUB # Upload a model to HUB
@ -44,7 +60,8 @@ class HubTrainingSession:
with open(weights, "rb") as f: with open(weights, "rb") as f:
file = f.read() file = f.read()
if final: if final:
smart_request(f'{self.api_url}/upload', smart_request(
f"{self.api_url}/upload",
data={ data={
"epoch": epoch, "epoch": epoch,
"type": "final", "type": "final",
@ -53,16 +70,19 @@ class HubTrainingSession:
headers=self.auth_header, headers=self.auth_header,
retry=10, retry=10,
timeout=3600, timeout=3600,
code=4) code=4,
)
else: else:
smart_request(f'{self.api_url}/upload', smart_request(
f"{self.api_url}/upload",
data={ data={
"epoch": epoch, "epoch": epoch,
"type": "epoch", "type": "epoch",
"isBest": bool(is_best)}, "isBest": bool(is_best)},
headers=self.auth_header, headers=self.auth_header,
files={"last.pt": file}, files={"last.pt": file},
code=3) code=3,
)
def _get_model(self): def _get_model(self):
# Returns model from database by id # Returns model from database by id
@ -70,31 +90,131 @@ class HubTrainingSession:
headers = self.auth_header headers = self.auth_header
try: try:
r = smart_request(api_url, method="get", headers=headers, thread=False, code=0) response = smart_request(api_url, method="get", headers=headers, thread=False, code=0)
data = r.json().get("data", None) data = response.json().get("data", None)
if not data:
return if data.get("status", None) == "trained":
assert data['data'], 'ERROR: Dataset may still be processing. Please wait a minute and try again.' # RF fix raise ValueError(
emojis(f"Model trained. View model at https://hub.ultralytics.com/models/{self.model_id} 🚀"))
if not data.get("data", None):
raise ValueError("Dataset may still be processing. Please wait a minute and try again.") # RF fix
self.model_id = data["id"] self.model_id = data["id"]
# TODO: restore when server keys when dataset URL and GPU train is working
self.train_args = {
"batch": data["batch_size"],
"epochs": data["epochs"],
"imgsz": data["imgsz"],
"patience": data["patience"],
"device": data["device"],
"cache": data["cache"],
"data": data["data"]}
self.input_file = data.get("cfg", data["weights"])
# hack for yolov5 cfg adds u
if "cfg" in data and "yolov5" in data["cfg"]:
self.input_file = data["cfg"].replace(".yaml", "u.yaml")
return data return data
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError as e:
raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e raise ConnectionRefusedError("ERROR: The HUB server is not online. Please try again later.") from e
except Exception:
raise
def check_disk_space(self): def check_disk_space(self):
if not check_dataset_disk_space(self.model['data']): if not check_dataset_disk_space(self.model["data"]):
raise MemoryError("Not enough disk space") raise MemoryError("Not enough disk space")
def register_callbacks(self, trainer):
trainer.add_callback("on_pretrain_routine_end", self.on_pretrain_routine_end)
trainer.add_callback("on_fit_epoch_end", self.on_fit_epoch_end)
trainer.add_callback("on_model_save", self.on_model_save)
trainer.add_callback("on_train_end", self.on_train_end)
def on_pretrain_routine_end(self, trainer):
"""
Start timer for upload rate limit.
This method does not use trainer. It is passed to all callbacks by default.
"""
# Start timer for upload rate limit
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀")
self._timers = {"metrics": time(), "ckpt": time()} # start timer on self.rate_limit
def on_fit_epoch_end(self, trainer):
# Upload metrics after val end
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics}
if trainer.epoch == 0:
model_info = {
"model/parameters": get_num_params(trainer.model),
"model/GFLOPs": round(get_flops(trainer.model), 3),
"model/speed(ms)": round(trainer.validator.speed[1], 3)}
all_plots = {**all_plots, **model_info}
self._metrics_queue[trainer.epoch] = json.dumps(all_plots)
if time() - self._timers["metrics"] > self._rate_limits["metrics"]:
self.upload_metrics()
self._timers["metrics"] = time() # reset timer
self._metrics_queue = {} # reset queue
def on_model_save(self, trainer):
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - self._timers["ckpt"] > self._rate_limits["ckpt"]:
LOGGER.info(f"{PREFIX}Uploading checkpoint {self.model_id}")
self._upload_model(trainer.epoch, trainer.last, is_best)
self._timers["ckpt"] = time() # reset timer
def on_train_end(self, trainer):
# Upload final model and metrics with exponential standoff
LOGGER.info(f"{PREFIX}Training completed successfully ✅")
LOGGER.info(f"{PREFIX}Uploading final {self.model_id}")
# hack for fetching mAP
mAP = trainer.metrics.get("metrics/mAP50-95(B)", 0)
self._upload_model(trainer.epoch, trainer.best, map=mAP, final=True) # results[3] is mAP0.5:0.95
self.alive = False # stop heartbeats
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀")
def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
# Upload a model to HUB
file = None
if Path(weights).is_file():
with open(weights, "rb") as f:
file = f.read()
file_param = {"best.pt" if final else "last.pt": file}
endpoint = f"{self.api_url}/upload"
data = {"epoch": epoch}
if final:
data.update({"type": "final", "map": map})
else:
data.update({"type": "epoch", "isBest": bool(is_best)})
smart_request(
endpoint,
data=data,
files=file_param,
headers=self.auth_header,
retry=10 if final else None,
timeout=3600 if final else None,
code=4 if final else 3,
)
@threaded @threaded
def _heartbeats(self): def _start_heartbeat(self):
self.alive = True
while self.alive: while self.alive:
r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}', r = smart_request(
f"{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}",
json={ json={
"agent": AGENT_NAME, "agent": AGENT_NAME,
"agentId": self.agent_id}, "agentId": self.agent_id},
headers=self.auth_header, headers=self.auth_header,
retry=0, retry=0,
code=5, code=5,
thread=False) thread=False,
self.agent_id = r.json().get('data', {}).get('agentId', None) )
sleep(self.rate_limits['heartbeat']) self.agent_id = r.json().get("data", {}).get("agentId", None)
sleep(self._rate_limits["heartbeat"])

@ -172,7 +172,7 @@ class DetectionModel(BaseModel):
if nc and nc != self.yaml['nc']: if nc and nc != self.yaml['nc']:
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
self.yaml['nc'] = nc # override yaml value self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch], verbose=verbose) # model, savelist self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
self.inplace = self.yaml.get('inplace', True) self.inplace = self.yaml.get('inplace', True)
@ -282,7 +282,7 @@ class ClassificationModel(BaseModel):
if nc and nc != self.yaml['nc']: if nc and nc != self.yaml['nc']:
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
self.yaml['nc'] = nc # override yaml value self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch], verbose=verbose) # model, savelist self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
self.info() self.info()
@ -421,7 +421,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU() Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
if verbose: if verbose:
LOGGER.info(f"{colorstr('activation:')} {act}") # print LOGGER.info(f"{colorstr('activation:')} {act}") # print
ch = [ch]
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
m = eval(m) if isinstance(m, str) else m # eval strings m = eval(m) if isinstance(m, str) else m # eval strings

@ -49,6 +49,20 @@ CLI_HELP_MSG = \
GitHub: https://github.com/ultralytics/ultralytics GitHub: https://github.com/ultralytics/ultralytics
""" """
CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl'}
CFG_FRACTION_KEYS = {
'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'fl_gamma',
'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'degrees', 'translate', 'scale', 'shear', 'perspective', 'flipud',
'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou'}
CFG_INT_KEYS = {
'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
'line_thickness', 'workspace', 'nbs'}
CFG_BOOL_KEYS = {
'save', 'cache', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect',
'cos_lr', 'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt',
'save_conf', 'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks',
'boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader'}
def cfg2dict(cfg): def cfg2dict(cfg):
""" """
@ -88,11 +102,31 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG, override
check_cfg_mismatch(cfg, overrides) check_cfg_mismatch(cfg, overrides)
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides) cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
# Type checks # Special handling for numeric project/names
for k in 'project', 'name': for k in 'project', 'name':
if k in cfg and isinstance(cfg[k], (int, float)): if k in cfg and isinstance(cfg[k], (int, float)):
cfg[k] = str(cfg[k]) cfg[k] = str(cfg[k])
# Type and Value checks
for k, v in cfg.items():
if v is not None: # None values may be from optional args
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
elif k in CFG_FRACTION_KEYS:
if not isinstance(v, (int, float)):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
if not (0.0 <= v <= 1.0):
raise ValueError(f"'{k}={v}' is an invalid value. "
f"Valid '{k}' values are between 0.0 and 1.0.")
elif k in CFG_INT_KEYS and not isinstance(v, int):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"'{k}' must be an int (i.e. '{k}=0')")
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
# Return instance # Return instance
return IterableSimpleNamespace(**cfg) return IterableSimpleNamespace(**cfg)
@ -184,6 +218,7 @@ def entrypoint(debug=''):
try: try:
re.sub(r' *= *', '=', a) # remove spaces around equals sign re.sub(r' *= *', '=', a) # remove spaces around equals sign
k, v = a.split('=', 1) # split on first '=' sign k, v = a.split('=', 1) # split on first '=' sign
assert v, f"missing '{k}' value"
if k == 'cfg': # custom.yaml passed if k == 'cfg': # custom.yaml passed
LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {v}") LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {v}")
overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'} overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
@ -198,7 +233,7 @@ def entrypoint(debug=''):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
v = eval(v) v = eval(v)
overrides[k] = v overrides[k] = v
except (NameError, SyntaxError, ValueError) as e: except (NameError, SyntaxError, ValueError, AssertionError) as e:
raise argument_error(a) from e raise argument_error(a) from e
elif a in tasks: elif a in tasks:
@ -224,7 +259,7 @@ def entrypoint(debug=''):
mode = overrides.get('mode', None) mode = overrides.get('mode', None)
if mode is None: if mode is None:
mode = DEFAULT_CFG.mode or 'predict' mode = DEFAULT_CFG.mode or 'predict'
LOGGER.warning(f"WARNING ⚠️ 'mode=' is missing. Valid modes are {modes}. Using default 'mode={mode}'.") LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
elif mode not in modes: elif mode not in modes:
if mode != 'checks': if mode != 'checks':
raise ValueError(emojis(f"ERROR ❌ Invalid 'mode={mode}'. Valid modes are {modes}.")) raise ValueError(emojis(f"ERROR ❌ Invalid 'mode={mode}'. Valid modes are {modes}."))
@ -237,7 +272,7 @@ def entrypoint(debug=''):
task = overrides.pop('task', None) task = overrides.pop('task', None)
if model is None: if model is None:
model = task2model.get(task, 'yolov8n.pt') model = task2model.get(task, 'yolov8n.pt')
LOGGER.warning(f"WARNING ⚠️ 'model=' is missing. Using default 'model={model}'.") LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.engine.model import YOLO
overrides['model'] = model overrides['model'] = model
model = YOLO(model) model = YOLO(model)
@ -251,15 +286,15 @@ def entrypoint(debug=''):
if mode == 'predict' and 'source' not in overrides: if mode == 'predict' and 'source' not in overrides:
overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \ overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \
else "https://ultralytics.com/images/bus.jpg" else "https://ultralytics.com/images/bus.jpg"
LOGGER.warning(f"WARNING ⚠️ 'source=' is missing. Using default 'source={overrides['source']}'.") LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
elif mode in ('train', 'val'): elif mode in ('train', 'val'):
if 'data' not in overrides: if 'data' not in overrides:
overrides['data'] = task2data.get(task, DEFAULT_CFG.data) overrides['data'] = task2data.get(task, DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠️ 'data=' is missing. Using {model.task} default 'data={overrides['data']}'.") LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using {model.task} default 'data={overrides['data']}'.")
elif mode == 'export': elif mode == 'export':
if 'format' not in overrides: if 'format' not in overrides:
overrides['format'] = DEFAULT_CFG.format or 'torchscript' overrides['format'] = DEFAULT_CFG.format or 'torchscript'
LOGGER.warning(f"WARNING ⚠️ 'format=' is missing. Using default 'format={overrides['format']}'.") LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
# Run command in python # Run command in python
# getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml # getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml

@ -132,7 +132,7 @@ class YOLODataset(BaseDataset):
for lb in labels: for lb in labels:
lb["segments"] = [] lb["segments"] = []
if len_cls == 0: if len_cls == 0:
raise ValueError(f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}") raise ValueError(f"All labels empty in {cache_path}, can not start training without labels. {HELP_URL}")
return labels return labels
# TODO: use hyp config to set all these augmentations # TODO: use hyp config to set all these augmentations

@ -131,7 +131,11 @@ def yaml_save(file='data.yaml', data=None):
with open(file, 'w') as f: with open(file, 'w') as f:
# Dump data to file in YAML format, converting Path objects to strings # Dump data to file in YAML format, converting Path objects to strings
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False) yaml.safe_dump({k: str(v) if isinstance(v, Path) else v
for k, v in data.items()},
f,
sort_keys=False,
allow_unicode=True)
def yaml_load(file='data.yaml', append_filename=False): def yaml_load(file='data.yaml', append_filename=False):
@ -164,7 +168,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
None None
""" """
yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
dump = yaml.dump(yaml_dict, default_flow_style=False) dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True)
LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}") LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")

@ -7,6 +7,7 @@ import sys
import tempfile import tempfile
from . import USER_CONFIG_DIR from . import USER_CONFIG_DIR
from .torch_utils import TORCH_1_9
def find_free_network_port() -> int: def find_free_network_port() -> int:
@ -47,8 +48,9 @@ def generate_ddp_command(world_size, trainer):
using_cli = not file_name.endswith(".py") using_cli = not file_name.endswith(".py")
if using_cli: if using_cli:
file_name = generate_ddp_file(trainer) file_name = generate_ddp_file(trainer)
torch_distributed_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
return [ return [
sys.executable, "-m", "torch.distributed.run", "--nproc_per_node", f"{world_size}", "--master_port", sys.executable, "-m", torch_distributed_cmd, "--nproc_per_node", f"{world_size}", "--master_port",
f"{find_free_network_port()}", file_name] + sys.argv[1:] f"{find_free_network_port()}", file_name] + sys.argv[1:]

@ -24,6 +24,10 @@ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable
RANK = int(os.getenv('RANK', -1)) RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
TORCH_1_11 = check_version(torch.__version__, '1.11.0')
TORCH_1_12 = check_version(torch.__version__, '1.12.0')
@contextmanager @contextmanager
def torch_distributed_zero_first(local_rank: int): def torch_distributed_zero_first(local_rank: int):
@ -36,10 +40,10 @@ def torch_distributed_zero_first(local_rank: int):
dist.barrier(device_ids=[0]) dist.barrier(device_ids=[0])
def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')): def smart_inference_mode():
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
def decorate(fn): def decorate(fn):
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn) return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
return decorate return decorate
@ -49,7 +53,7 @@ def DDP_model(model):
assert not check_version(torch.__version__, '1.12.0', pinned=True), \ assert not check_version(torch.__version__, '1.12.0', pinned=True), \
'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \ 'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395' 'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
if check_version(torch.__version__, '1.11.0'): if TORCH_1_11:
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True) return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
else: else:
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
@ -267,7 +271,7 @@ def init_seeds(seed=0, deterministic=False):
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
# torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287 # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213 if deterministic and TORCH_1_12: # https://github.com/ultralytics/yolov5/pull/8213
torch.use_deterministic_algorithms(True) torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

Loading…
Cancel
Save