`ultralytics 8.0.44` export and task fixes (#1088)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Glenn Jocher, committed by GitHub
parent fe61018975
commit 3ea659411b

@@ -32,9 +32,14 @@ jobs:
 # key: ${{ runner.os }}-Benchmarks-${{ hashFiles('requirements.txt') }}
 # restore-keys: ${{ runner.os }}-Benchmarks-
 - name: Install requirements
+shell: bash # for Windows compatibility
 run: |
 python -m pip install --upgrade pip wheel
+if [ "${{ matrix.os }}" == "macos-latest" ]; then
+pip install -e . coremltools openvino-dev tensorflow-macos --extra-index-url https://download.pytorch.org/whl/cpu
+else
 pip install -e . coremltools openvino-dev tensorflow-cpu paddlepaddle x2paddle --extra-index-url https://download.pytorch.org/whl/cpu
+fi
 yolo export format=tflite
 - name: Check environment
 run: |
@@ -94,6 +99,7 @@ jobs:
 key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }}
 restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-pip-
 - name: Install requirements
+shell: bash # for Windows compatibility
 run: |
 python -m pip install --upgrade pip wheel
 if [ "${{ matrix.torch }}" == "1.8.0" ]; then
@@ -101,7 +107,6 @@ jobs:
 else
 pip install -e '.[export]' pytest --extra-index-url https://download.pytorch.org/whl/cpu
 fi
-shell: bash # for Windows compatibility
 - name: Check environment
 run: |
 echo "RUNNER_OS is ${{ runner.os }}"

@@ -78,13 +78,6 @@
 }
 ]
 },
-{
-"cell_type": "markdown",
-"source": [],
-"metadata": {
-"id": "ZOwTlorPd8-D"
-}
-},
 {
 "cell_type": "markdown",
 "metadata": {

@@ -3,7 +3,7 @@
 import subprocess
 from pathlib import Path
-from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS
+from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS, checks
 MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
 CFG = 'yolov8n'
@@ -49,6 +49,7 @@ def test_val_classify():
 # Predict checks -------------------------------------------------------------------------------------------------------
 def test_predict_detect():
 run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
+if checks.check_online():
 run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32')
 run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32')
 run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32')

@@ -9,7 +9,7 @@ from PIL import Image
 from ultralytics import YOLO
 from ultralytics.yolo.data.build import load_inference_source
-from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS
+from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS, checks
 MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
 CFG = 'yolov8n.yaml'
@@ -49,28 +49,20 @@ def test_predict_dir():
 def test_predict_img():
 model = YOLO(MODEL)
-output = model(source=Image.open(SOURCE), save=True, verbose=True) # PIL
-assert len(output) == 1, 'predict test failed'
-img = cv2.imread(str(SOURCE))
-output = model(source=img, save=True, save_txt=True) # ndarray
-assert len(output) == 1, 'predict test failed'
-output = model(source=[img, img], save=True, save_txt=True) # batch
-assert len(output) == 2, 'predict test failed'
-output = model(source=[img, img], save=True, stream=True) # stream
-assert len(list(output)) == 2, 'predict test failed'
-tens = torch.zeros(320, 640, 3)
-output = model(tens.numpy())
-assert len(output) == 1, 'predict test failed'
-# test multiple source
-imgs = [
-SOURCE, # filename
+im = cv2.imread(str(SOURCE))
+assert len(model(source=Image.open(SOURCE), save=True, verbose=True)) == 1 # PIL
+assert len(model(source=im, save=True, save_txt=True)) == 1 # ndarray
+assert len(model(source=[im, im], save=True, save_txt=True)) == 2 # batch
+assert len(list(model(source=[im, im], save=True, stream=True))) == 2 # stream
+assert len(model(torch.zeros(320, 640, 3).numpy())) == 1 # tensor to numpy
+batch = [
+str(SOURCE), # filename
 Path(SOURCE), # Path
-'https://ultralytics.com/images/zidane.jpg', # URI
+'https://ultralytics.com/images/zidane.jpg' if checks.check_online() else SOURCE, # URI
 cv2.imread(str(SOURCE)), # OpenCV
 Image.open(SOURCE), # PIL
 np.zeros((320, 640, 3))] # numpy
-output = model(imgs)
-assert len(output) == 6, 'predict test failed!'
+assert len(model(batch)) == len(batch) # multiple sources in a batch
 def test_predict_grey_and_4ch():
@@ -85,6 +77,11 @@ def test_val():
 model.val(data='coco8.yaml', imgsz=32)
+def test_val_scratch():
+model = YOLO(CFG)
+model.val(data='coco8.yaml', imgsz=32)
 def test_train_scratch():
 model = YOLO(CFG)
 model.train(data='coco8.yaml', epochs=1, imgsz=32)
@@ -103,6 +100,12 @@ def test_export_torchscript():
 YOLO(f)(SOURCE) # exported model inference
+def test_export_torchscript_scratch():
+model = YOLO(CFG)
+f = model.export(format='torchscript')
+YOLO(f)(SOURCE) # exported model inference
 def test_export_onnx():
 model = YOLO(MODEL)
 f = model.export(format='onnx')
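For reference, a minimal usage sketch of the source types the reworked test above exercises, outside the test harness (illustrative only; assumes ultralytics 8.0.44 is installed and a local image file 'bus.jpg' exists):

import cv2
import numpy as np
from PIL import Image
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
im = cv2.imread('bus.jpg')                            # BGR ndarray
assert len(model(Image.open('bus.jpg'))) == 1         # PIL image -> one Results object
assert len(model([im, im])) == 2                      # list of arrays -> batched results
assert len(list(model([im, im], stream=True))) == 2   # stream=True returns a generator
assert len(model(np.zeros((320, 640, 3)))) == 1       # raw numpy array also accepted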

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
-__version__ = '8.0.43'
+__version__ = '8.0.44'
 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils.checks import check_yolo as checks

@@ -15,7 +15,7 @@ EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_c
 def start(key=''):
 """
-Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
+Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY')
 """
 auth = Auth(key)
 try:
@@ -30,9 +30,9 @@ def start(key=''):
 session = HubTrainingSession(model_id=model_id, auth=auth)
 session.check_disk_space()
-trainer = YOLO(session.input_file)
-session.register_callbacks(trainer)
-trainer.train(**session.train_args)
+model = YOLO(session.input_file)
+session.register_callbacks(model)
+model.train(**session.train_args)
 except Exception as e:
 LOGGER.warning(f'{PREFIX}{e}')
@@ -93,6 +93,5 @@ def get_export(key='', format='torchscript'):
 return r.json()
-# temp. For checking
 if __name__ == '__main__':
 start()

@@ -26,6 +26,7 @@ class HubTrainingSession:
 self._timers = {} # rate limit timers (seconds)
 self._metrics_queue = {} # metrics queue
 self.model = self._get_model()
+self.alive = True
 self._start_heartbeat() # start heartbeats
 self._register_signal_handlers()
@@ -52,37 +53,6 @@ class HubTrainingSession:
 payload = {'metrics': self._metrics_queue.copy(), 'type': 'metrics'}
 smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2)
-def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
-# Upload a model to HUB
-file = None
-if Path(weights).is_file():
-with open(weights, 'rb') as f:
-file = f.read()
-if final:
-smart_request(
-f'{self.api_url}/upload',
-data={
-'epoch': epoch,
-'type': 'final',
-'map': map},
-files={'best.pt': file},
-headers=self.auth_header,
-retry=10,
-timeout=3600,
-code=4,
-)
-else:
-smart_request(
-f'{self.api_url}/upload',
-data={
-'epoch': epoch,
-'type': 'epoch',
-'isBest': bool(is_best)},
-headers=self.auth_header,
-files={'last.pt': file},
-code=3,
-)
 def _get_model(self):
 # Returns model from database by id
 api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
@@ -151,7 +121,7 @@ class HubTrainingSession:
 model_info = {
 'model/parameters': get_num_params(trainer.model),
 'model/GFLOPs': round(get_flops(trainer.model), 3),
-'model/speed(ms)': round(trainer.validator.speed[1], 3)}
+'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
 all_plots = {**all_plots, **model_info}
 self._metrics_queue[trainer.epoch] = json.dumps(all_plots)
 if time() - self._timers['metrics'] > self._rate_limits['metrics']:
@@ -169,52 +139,45 @@ class HubTrainingSession:
 def on_train_end(self, trainer):
 # Upload final model and metrics with exponential standoff
-LOGGER.info(f'{PREFIX}Training completed successfully ✅')
-LOGGER.info(f'{PREFIX}Uploading final {self.model_id}')
-# hack for fetching mAP
-mAP = trainer.metrics.get('metrics/mAP50-95(B)', 0)
-self._upload_model(trainer.epoch, trainer.best, map=mAP, final=True) # results[3] is mAP0.5:0.95
+LOGGER.info(f'{PREFIX}Training completed successfully ✅\n'
+f'{PREFIX}Uploading final {self.model_id}')
+self._upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
 self.alive = False # stop heartbeats
 LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')
 def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
 # Upload a model to HUB
-file = None
 if Path(weights).is_file():
 with open(weights, 'rb') as f:
 file = f.read()
-file_param = {'best.pt' if final else 'last.pt': file}
-endpoint = f'{self.api_url}/upload'
+else:
+LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload failed. Missing model {weights}.')
+file = None
 data = {'epoch': epoch}
 if final:
 data.update({'type': 'final', 'map': map})
 else:
 data.update({'type': 'epoch', 'isBest': bool(is_best)})
-smart_request(
-endpoint,
+smart_request(f'{self.api_url}/upload',
 data=data,
-files=file_param,
+files={'best.pt' if final else 'last.pt': file},
 headers=self.auth_header,
 retry=10 if final else None,
 timeout=3600 if final else None,
-code=4 if final else 3,
-)
+code=4 if final else 3)
 @threaded
 def _start_heartbeat(self):
-self.alive = True
 while self.alive:
-r = smart_request(
-f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
+r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
 json={
 'agent': AGENT_NAME,
 'agentId': self.agent_id},
 headers=self.auth_header,
 retry=0,
 code=5,
-thread=False,
-)
+thread=False)
 self.agent_id = r.json().get('data', {}).get('agentId', None)
 sleep(self._rate_limits['heartbeat'])
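A minimal sketch of the heartbeat lifecycle after this change (illustrative only; a plain print stands in for the real smart_request call): the alive flag is now set once in __init__ before the threaded loop starts, and cleared in on_train_end to stop it.

import threading
import time

class Session:
    def __init__(self):
        self.alive = True  # set before the heartbeat thread starts, as in the diff above
        threading.Thread(target=self._start_heartbeat, daemon=True).start()

    def _start_heartbeat(self):
        while self.alive:        # loop runs until the flag is cleared
            print('heartbeat')   # stand-in for the real heartbeat request
            time.sleep(1)

    def on_train_end(self):
        self.alive = False       # stop heartbeats once the final model is uploaded

s = Session()
time.sleep(2.5)
s.on_train_end()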

@@ -181,7 +181,6 @@ class AutoBackend(nn.Module):
 import tensorflow as tf
 keras = False # assume TF1 saved_model
 model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
-w = Path(w) / 'metadata.yaml'
 elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
 LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
 import tensorflow as tf
@@ -258,8 +257,9 @@ class AutoBackend(nn.Module):
 f'\n\n{EXPORT_FORMATS_TABLE}')
 # Load external metadata YAML
+w = Path(w)
 if xml or saved_model or paddle:
-metadata = Path(w).parent / 'metadata.yaml'
+metadata = (w if saved_model else w.parents[1] if paddle else w.parent) / 'metadata.yaml'
 if metadata.exists():
 metadata = yaml_load(metadata)
 stride, names = int(metadata['stride']), metadata['names'] # load metadata
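An illustrative sketch of the new metadata.yaml resolution, mirroring the expression above; the example paths are hypothetical export layouts, not taken from the diff.

from pathlib import Path

def metadata_path(w, saved_model=False, paddle=False):
    # Same expression as the AutoBackend change above
    w = Path(w)
    return (w if saved_model else w.parents[1] if paddle else w.parent) / 'metadata.yaml'

print(metadata_path('yolov8n_saved_model', saved_model=True))   # inside the SavedModel directory
print(metadata_path('yolov8n_openvino_model/yolov8n.xml'))      # next to the OpenVINO .xml file
print(metadata_path('some_paddle_model/subdir/model.pdmodel', paddle=True))  # two levels up for Paddle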

@@ -287,6 +287,7 @@ class ClassificationModel(BaseModel):
 LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
 self.yaml['nc'] = nc # override yaml value
 self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
+self.stride = torch.Tensor([1]) # no stride constraints
 self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
 self.info()
@@ -520,14 +521,15 @@ def guess_model_task(model):
 # Guess from model filename
 if isinstance(model, (str, Path)):
-model = Path(model).stem
-if '-seg' in model:
+model = Path(model)
+if '-seg' in model.stem or 'segment' in model.parts:
 return 'segment'
-elif '-cls' in model:
+elif '-cls' in model.stem or 'classify' in model.parts:
 return 'classify'
-else:
+elif 'detect' in model.parts:
 return 'detect'
 # Unable to determine task from model
-raise SyntaxError('YOLO is unable to automatically guess model task. Explicitly define task for your model, '
-"i.e. 'task=detect', 'task=segment' or 'task=classify'.")
+LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
+"Explicitly define task for your model, i.e. 'task=detect', 'task=segment' or 'task=classify'.")
+return 'detect' # assume detect
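A minimal sketch of the new path-based task guessing (illustrative, mirroring the logic above rather than importing the real guess_model_task):

from pathlib import Path

def guess_task_from_path(model):
    model = Path(model)
    if '-seg' in model.stem or 'segment' in model.parts:
        return 'segment'
    elif '-cls' in model.stem or 'classify' in model.parts:
        return 'classify'
    elif 'detect' in model.parts:
        return 'detect'
    return 'detect'  # previously a SyntaxError was raised; 'detect' is now assumed with a warning

print(guess_task_from_path('yolov8n-seg.pt'))                       # segment
print(guess_task_from_path('runs/classify/train/weights/best.pt'))  # classify
print(guess_task_from_path('custom_model.pt'))                      # detect (fallback)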

@@ -47,11 +47,12 @@ CLI_HELP_MSG = \
 GitHub: https://github.com/ultralytics/ultralytics
 """
-CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'}
+# Define keys for arg type checks
+CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'fl_gamma'}
 CFG_FRACTION_KEYS = {
-'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'fl_gamma',
-'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic',
-'mixup', 'copy_paste', 'conf', 'iou'}
+'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'label_smoothing',
+'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic', 'mixup', 'copy_paste',
+'conf', 'iou'} # fractional floats limited to 0.0 - 1.0
 CFG_INT_KEYS = {
 'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
 'line_thickness', 'workspace', 'nbs', 'save_period'}
@@ -224,7 +225,7 @@ def entrypoint(debug=''):
 assert v, f"missing '{k}' value"
 if k == 'cfg': # custom.yaml passed
 LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
-overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
+overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
 else:
 if v.lower() == 'none':
 v = None
@@ -255,7 +256,6 @@ def entrypoint(debug=''):
 check_cfg_mismatch(full_args_dict, {a: ''})
 # Defaults
-task2model = dict(detect='yolov8n.pt', segment='yolov8n-seg.pt', classify='yolov8n-cls.pt')
 task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100')
 # Mode
@@ -272,27 +272,28 @@ def entrypoint(debug=''):
 # Model
 model = overrides.pop('model', DEFAULT_CFG.model)
-task = overrides.pop('task', None)
 if model is None:
-model = task2model.get(task, 'yolov8n.pt')
+model = 'yolov8n.pt'
 LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
 from ultralytics.yolo.engine.model import YOLO
 overrides['model'] = model
 model = YOLO(model)
 # Task
-if task and task != model.task:
-LOGGER.warning(f"WARNING ⚠️ 'task={task}' conflicts with {model.task} model {overrides['model']}. "
-f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.")
-task = model.task
-overrides['task'] = task
+# if task and task != model.task:
+# LOGGER.warning(f"WARNING ⚠️ 'task={task}' conflicts with {model.task} model {overrides['model']}. "
+# f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.")
+overrides['task'] = overrides.get('task', model.task)
+model.task = overrides['task']
+# Mode
 if mode in {'predict', 'track'} and 'source' not in overrides:
 overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
 else 'https://ultralytics.com/images/bus.jpg'
 LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
 elif mode in ('train', 'val'):
 if 'data' not in overrides:
-overrides['data'] = task2data.get(task, DEFAULT_CFG.data)
+overrides['data'] = task2data.get(overrides['task'], DEFAULT_CFG.data)
 LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using {model.task} default 'data={overrides['data']}'.")
 elif mode == 'export':
 if 'format' not in overrides:
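A small sketch of the revised task and data defaulting in the CLI entrypoint (illustrative and simplified; the real code only fills 'data' for train/val modes): an explicit task now wins, otherwise the loaded model's task is inherited, and the data default follows the resolved task.

task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100')

def resolve(overrides, model_task, default_data='coco128.yaml'):
    overrides['task'] = overrides.get('task', model_task)  # explicit task wins, else inherit from the model
    overrides.setdefault('data', task2data.get(overrides['task'], default_data))
    return overrides

print(resolve({}, 'detect'))                   # {'task': 'detect', 'data': 'coco128.yaml'}
print(resolve({'task': 'segment'}, 'detect'))  # {'task': 'segment', 'data': 'coco128-seg.yaml'}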

@@ -6,7 +6,6 @@ Dataloaders and dataset utils
 import contextlib
 import glob
 import hashlib
-import json
 import math
 import os
 import random
@@ -27,11 +26,9 @@ from PIL import ExifTags, Image, ImageOps
 from torch.utils.data import DataLoader, Dataset, dataloader, distributed
 from tqdm import tqdm
-from ultralytics.yolo.data.utils import check_det_dataset
 from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
-is_kaggle, yaml_load)
-from ultralytics.yolo.utils.checks import check_requirements, check_yaml
-from ultralytics.yolo.utils.downloads import unzip_file
+is_kaggle)
+from ultralytics.yolo.utils.checks import check_requirements
 from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
 from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
@@ -1037,127 +1034,6 @@ def verify_image_label(args):
return [None, None, None, None, nm, nf, ne, nc, msg]
class HUBDatasetStats():
""" Class for generating HUB dataset JSON and `-hub` dataset directory
Arguments
path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
autodownload: Attempt to download dataset if not found locally
Usage
from ultralytics.yolo.data.dataloaders.v5loader import HUBDatasetStats
stats = HUBDatasetStats('coco128.yaml', autodownload=True) # usage 1
stats = HUBDatasetStats('path/to/coco128.zip') # usage 2
stats.get_json(save=False)
stats.process_images()
"""
def __init__(self, path='coco128.yaml', autodownload=False):
# Initialize class
zipped, data_dir, yaml_path = self._unzip(Path(path))
# try:
# data = yaml_load(check_yaml(yaml_path)) # data dict
# if zipped:
# data['path'] = data_dir
# except Exception as e:
# raise Exception('error/HUB/dataset_stats/yaml_load') from e
data = check_det_dataset(yaml_path, autodownload) # download dataset if missing
self.hub_dir = Path(str(data['path']) + '-hub')
self.im_dir = self.hub_dir / 'images'
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
self.stats = {'nc': data['nc'], 'names': list(data['names'].values())} # statistics dictionary
self.data = data
@staticmethod
def _find_yaml(dir):
# Return data.yaml file
files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
assert files, f'No *.yaml file found in {dir}'
if len(files) > 1:
files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
return files[0]
def _unzip(self, path):
# Unzip data.zip
if not str(path).endswith('.zip'): # path is data.yaml
return False, None, path
assert Path(path).is_file(), f'Error unzipping {path}, file not found'
unzip_file(path, path=path.parent)
dir = path.with_suffix('') # dataset directory == zip name
assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path
def _hub_ops(self, f, max_dim=1920):
# HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
f_new = self.im_dir / Path(f).name # dataset-hub image filename
try: # use PIL
im = Image.open(f)
r = max_dim / max(im.height, im.width) # ratio
if r < 1.0: # image too large
im = im.resize((int(im.width * r), int(im.height * r)))
im.save(f_new, 'JPEG', quality=50, optimize=True) # save
except Exception as e: # use OpenCV
LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
im = cv2.imread(f)
im_height, im_width = im.shape[:2]
r = max_dim / max(im_height, im_width) # ratio
if r < 1.0: # image too large
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
cv2.imwrite(str(f_new), im)
def get_json(self, save=False, verbose=False):
# Return dataset JSON for Ultralytics HUB
def _round(labels):
# Update labels to integer class and 6 decimal place floats
return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
for split in 'train', 'val', 'test':
if self.data.get(split) is None:
self.stats[split] = None # i.e. no test set
continue
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
x = np.array([
np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')]) # shape(128x80)
self.stats[split] = {
'instance_stats': {
'total': int(x.sum()),
'per_class': x.sum(0).tolist()},
'image_stats': {
'total': dataset.n,
'unlabelled': int(np.all(x == 0, 1).sum()),
'per_class': (x > 0).sum(0).tolist()},
'labels': [{
str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
# Save, print and return
if save:
stats_path = self.hub_dir / 'stats.json'
LOGGER.info(f'Saving {stats_path.resolve()}...')
with open(stats_path, 'w') as f:
json.dump(self.stats, f) # save stats.json
if verbose:
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
return self.stats
def process_images(self):
# Compress images for Ultralytics HUB
for split in 'train', 'val', 'test':
if self.data.get(split) is None:
continue
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
desc = f'{split} images'
total = dataset.n
with ThreadPool(NUM_THREADS) as pool:
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc):
pass
LOGGER.info(f'Done. All images saved to {self.im_dir}')
return self.im_dir
# Classification dataloaders -------------------------------------------------------------------------------------------
class ClassificationDataset(torchvision.datasets.ImageFolder):
"""

@@ -2,9 +2,11 @@
 import contextlib
 import hashlib
+import json
 import os
 import subprocess
 import time
+from multiprocessing.pool import ThreadPool
 from pathlib import Path
 from tarfile import is_tarfile
 from zipfile import is_zipfile
@@ -12,10 +14,11 @@ from zipfile import is_zipfile
 import cv2
 import numpy as np
 from PIL import ExifTags, Image, ImageOps
+from tqdm import tqdm
-from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, emojis, yaml_load
+from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, colorstr, emojis, yaml_load
 from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
-from ultralytics.yolo.utils.downloads import download, safe_download
+from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file
 from ultralytics.yolo.utils.ops import segments2boxes
 HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
@@ -290,3 +293,128 @@ def check_cls_dataset(dataset: str):
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
names = dict(enumerate(sorted(names)))
return {'train': train_set, 'val': test_set, 'nc': nc, 'names': names}
class HUBDatasetStats():
""" Class for generating HUB dataset JSON and `-hub` dataset directory
Arguments
path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
autodownload: Attempt to download dataset if not found locally
Usage
from ultralytics.yolo.data.utils import HUBDatasetStats
stats = HUBDatasetStats('coco128.yaml', autodownload=True) # usage 1
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco6.zip') # usage 2
stats.get_json(save=False)
stats.process_images()
"""
def __init__(self, path='coco128.yaml', autodownload=False):
# Initialize class
zipped, data_dir, yaml_path = self._unzip(Path(path))
try:
# data = yaml_load(check_yaml(yaml_path)) # data dict
data = check_det_dataset(yaml_path, autodownload) # data dict
if zipped:
data['path'] = data_dir
except Exception as e:
raise Exception('error/HUB/dataset_stats/yaml_load') from e
self.hub_dir = Path(str(data['path']) + '-hub')
self.im_dir = self.hub_dir / 'images'
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary
self.data = data
@staticmethod
def _find_yaml(dir):
# Return data.yaml file
files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
assert files, f'No *.yaml file found in {dir}'
if len(files) > 1:
files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
return files[0]
def _unzip(self, path):
# Unzip data.zip
if not str(path).endswith('.zip'): # path is data.yaml
return False, None, path
assert Path(path).is_file(), f'Error unzipping {path}, file not found'
unzip_file(path, path=path.parent)
dir = path.with_suffix('') # dataset directory == zip name
assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path
def _hub_ops(self, f, max_dim=1920):
# HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
f_new = self.im_dir / Path(f).name # dataset-hub image filename
try: # use PIL
im = Image.open(f)
r = max_dim / max(im.height, im.width) # ratio
if r < 1.0: # image too large
im = im.resize((int(im.width * r), int(im.height * r)))
im.save(f_new, 'JPEG', quality=50, optimize=True) # save
except Exception as e: # use OpenCV
LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
im = cv2.imread(f)
im_height, im_width = im.shape[:2]
r = max_dim / max(im_height, im_width) # ratio
if r < 1.0: # image too large
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
cv2.imwrite(str(f_new), im)
def get_json(self, save=False, verbose=False):
# Return dataset JSON for Ultralytics HUB
# from ultralytics.yolo.data import YOLODataset
from ultralytics.yolo.data.dataloaders.v5loader import LoadImagesAndLabels
def _round(labels):
# Update labels to integer class and 6 decimal place floats
return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
for split in 'train', 'val', 'test':
if self.data.get(split) is None:
self.stats[split] = None # i.e. no test set
continue
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
x = np.array([
np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
for label in tqdm(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80)
self.stats[split] = {
'instance_stats': {
'total': int(x.sum()),
'per_class': x.sum(0).tolist()},
'image_stats': {
'total': len(dataset),
'unlabelled': int(np.all(x == 0, 1).sum()),
'per_class': (x > 0).sum(0).tolist()},
'labels': [{
str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
# Save, print and return
if save:
stats_path = self.hub_dir / 'stats.json'
LOGGER.info(f'Saving {stats_path.resolve()}...')
with open(stats_path, 'w') as f:
json.dump(self.stats, f) # save stats.json
if verbose:
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
return self.stats
def process_images(self):
# Compress images for Ultralytics HUB
# from ultralytics.yolo.data import YOLODataset
from ultralytics.yolo.data.dataloaders.v5loader import LoadImagesAndLabels
for split in 'train', 'val', 'test':
if self.data.get(split) is None:
continue
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
with ThreadPool(NUM_THREADS) as pool:
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
pass
LOGGER.info(f'Done. All images saved to {self.im_dir}')
return self.im_dir
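A usage sketch for the relocated HUBDatasetStats class, following its docstring above (the dataset YAML is downloaded automatically when autodownload=True):

from ultralytics.yolo.data.utils import HUBDatasetStats

stats = HUBDatasetStats('coco128.yaml', autodownload=True)  # or HUBDatasetStats('path/to/dataset.zip')
stats.get_json(save=True, verbose=False)  # compute statistics and write stats.json into the '-hub' directory
stats.process_images()                    # save compressed copies of all images for HUB viewing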

@@ -208,12 +208,15 @@ class Exporter:
 self.file = file
 self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
 self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
+description = f'Ultralytics {self.pretty_name} model' + f'trained on {Path(self.args.data).name}' \
+if self.args.data else '(untrained)'
 self.metadata = {
-'description': f'Ultralytics {self.pretty_name} model trained on {Path(self.args.data).name}',
+'description': description,
 'author': 'Ultralytics',
 'license': 'GPL-3.0 https://ultralytics.com/license',
 'version': __version__,
 'stride': int(max(model.stride)),
+'task': model.task,
 'names': model.names} # model metadata
 LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
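For illustration, the metadata dictionary written alongside exports now gains a 'task' entry; the values below are hypothetical examples, not taken from the diff.

metadata = {
    'description': 'Ultralytics YOLOv8n model trained on coco128.yaml',  # or '(untrained)' for YAML-built models
    'author': 'Ultralytics',
    'license': 'GPL-3.0 https://ultralytics.com/license',
    'version': '8.0.44',
    'stride': 32,
    'task': 'detect',                      # new key recording the model task
    'names': {0: 'person', 1: 'bicycle'},  # truncated example
}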

@@ -9,22 +9,22 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
 guess_model_task, nn)
 from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.engine.exporter import Exporter
-from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, yaml_load
+from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, yaml_load
 from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
 from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
 from ultralytics.yolo.utils.torch_utils import smart_inference_mode
 # Map head to model, trainer, validator, and predictor classes
-MODEL_MAP = {
+TASK_MAP = {
 'classify': [
-ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
-'yolo.TYPE.classify.ClassificationPredictor'],
+ClassificationModel, yolo.v8.classify.ClassificationTrainer, yolo.v8.classify.ClassificationValidator,
+yolo.v8.classify.ClassificationPredictor],
 'detect': [
-DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
-'yolo.TYPE.detect.DetectionPredictor'],
+DetectionModel, yolo.v8.detect.DetectionTrainer, yolo.v8.detect.DetectionValidator,
+yolo.v8.detect.DetectionPredictor],
 'segment': [
-SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
-'yolo.TYPE.segment.SegmentationPredictor']}
+SegmentationModel, yolo.v8.segment.SegmentationTrainer, yolo.v8.segment.SegmentationValidator,
+yolo.v8.segment.SegmentationPredictor]}
 class YOLO:
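A short sketch of how the new TASK_MAP is indexed elsewhere in this diff (index 0 = model class, 1 = trainer, 2 = validator, 3 = predictor); it assumes TASK_MAP can be imported from the module where it is defined above.

from ultralytics.yolo.engine.model import TASK_MAP

model_cls, trainer_cls, validator_cls, predictor_cls = TASK_MAP['detect']
print(model_cls)      # DetectionModel
print(trainer_cls)    # yolo.v8.detect.DetectionTrainer
print(predictor_cls)  # used as TASK_MAP[task][3](overrides=...) in predict()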
@@ -33,52 +33,48 @@ class YOLO:
 Args:
 model (str, Path): Path to the model file to load or create.
-type (str): Type/version of models to use. Defaults to "v8".
 Attributes:
-type (str): Type/version of models being used.
-ModelClass (Any): Model class.
-TrainerClass (Any): Trainer class.
-ValidatorClass (Any): Validator class.
-PredictorClass (Any): Predictor class.
-predictor (Any): Predictor object.
-model (Any): Model object.
-trainer (Any): Trainer object.
-task (str): Type of model task.
-ckpt (Any): Checkpoint object if model loaded from *.pt file.
-cfg (str): Model configuration if loaded from *.yaml file.
-ckpt_path (str): Checkpoint file path.
-overrides (dict): Overrides for trainer object.
-metrics_data (Any): Data for metrics.
+predictor (Any): The predictor object.
+model (Any): The model object.
+trainer (Any): The trainer object.
+task (str): The type of model task.
+ckpt (Any): The checkpoint object if the model loaded from *.pt file.
+cfg (str): The model configuration if loaded from *.yaml file.
+ckpt_path (str): The checkpoint file path.
+overrides (dict): Overrides for the trainer object.
+metrics_data (Any): The data for metrics.
 Methods:
-__call__(): Alias for predict method.
-_new(cfg, verbose=True): Initializes a new model and infers the task type from the model definitions.
-_load(weights): Initializes a new model and infers the task type from the model head.
-_check_is_pytorch_model(): Raises TypeError if model is not a PyTorch model.
-reset(): Resets the model modules.
-info(verbose=False): Logs model info.
-fuse(): Fuse model for faster inference.
-predict(source=None, stream=False, **kwargs): Perform prediction using the YOLO model.
+__call__(source=None, stream=False, **kwargs):
+Alias for the predict method.
+_new(cfg:str, verbose:bool=True) -> None:
+Initializes a new model and infers the task type from the model definitions.
+_load(weights:str, task:str='') -> None:
+Initializes a new model and infers the task type from the model head.
+_check_is_pytorch_model() -> None:
+Raises TypeError if the model is not a PyTorch model.
+reset() -> None:
+Resets the model modules.
+info(verbose:bool=False) -> None:
+Logs the model info.
+fuse() -> None:
+Fuses the model for faster inference.
+predict(source=None, stream=False, **kwargs) -> List[ultralytics.yolo.engine.results.Results]:
+Performs prediction using the YOLO model.
 Returns:
-list(ultralytics.yolo.engine.results.Results): The prediction results.
+list[ultralytics.yolo.engine.results.Results]: The prediction results.
 """
-def __init__(self, model='yolov8n.pt', type='v8') -> None:
+def __init__(self, model='yolov8n.pt') -> None:
 """
 Initializes the YOLO model.
 Args:
 model (str, Path): model to load or create
-type (str): Type/version of models to use. Defaults to "v8".
 """
 self._reset_callbacks()
-self.type = type
-self.ModelClass = None # model class
-self.TrainerClass = None # trainer class
-self.ValidatorClass = None # validator class
-self.PredictorClass = None # predictor class
 self.predictor = None # reuse predictor
 self.model = None # model object
 self.trainer = None # trainer object
@@ -101,6 +97,10 @@ class YOLO:
 def __call__(self, source=None, stream=False, **kwargs):
 return self.predict(source, stream, **kwargs)
+def __getattr__(self, attr):
+name = self.__class__.__name__
+raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
 def _new(self, cfg: str, verbose=True):
 """
 Initializes a new model and infers the task type from the model definitions.
@@ -112,11 +112,15 @@ class YOLO:
 self.cfg = check_yaml(cfg) # check YAML
 cfg_dict = yaml_load(self.cfg, append_filename=True) # model dict
 self.task = guess_model_task(cfg_dict)
-self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
-self.model = self.ModelClass(cfg_dict, verbose=verbose and RANK == -1) # initialize
+self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model
 self.overrides['model'] = self.cfg
+# Below added to allow export from yamls
+args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args
+self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
+self.model.task = self.task
-def _load(self, weights: str):
+def _load(self, weights: str, task=''):
 """
 Initializes a new model and infers the task type from the model head.
@@ -127,8 +131,7 @@ class YOLO:
 if suffix == '.pt':
 self.model, self.ckpt = attempt_load_one_weight(weights)
 self.task = self.model.args['task']
-self.overrides = self.model.args
-self._reset_ckpt_args(self.overrides)
+self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
 self.ckpt_path = self.model.pt_path
 else:
 weights = check_file(weights)
@@ -136,7 +139,6 @@ class YOLO:
 self.task = guess_model_task(weights)
 self.ckpt_path = weights
 self.overrides['model'] = weights
-self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
 def _check_is_pytorch_model(self):
 """
@@ -189,12 +191,13 @@ class YOLO:
 """
 overrides = self.overrides.copy()
 overrides['conf'] = 0.25
-overrides.update(kwargs)
+overrides.update(kwargs) # prefer kwargs
 overrides['mode'] = kwargs.get('mode', 'predict')
 assert overrides['mode'] in ['track', 'predict']
 overrides['save'] = kwargs.get('save', False) # not save files by default
 if not self.predictor:
-self.predictor = self.PredictorClass(overrides=overrides)
+self.task = overrides.get('task') or self.task
+self.predictor = TASK_MAP[self.task][3](overrides=overrides)
 self.predictor.setup_model(model=self.model)
 else: # only update args if predictor is already setup
 self.predictor.args = get_cfg(self.predictor.args, overrides)
@@ -226,12 +229,15 @@ class YOLO:
 overrides['mode'] = 'val'
 args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
 args.data = data or args.data
+if 'task' in overrides:
+self.task = args.task
+else:
 args.task = self.task
 if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
 args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
 args.imgsz = check_imgsz(args.imgsz, max_dim=1)
-validator = self.ValidatorClass(args=args)
+validator = TASK_MAP[self.task][2](args=args)
 validator(model=self.model)
 self.metrics_data = validator.metrics
@@ -267,8 +273,7 @@ class YOLO:
 args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
 if args.batch == DEFAULT_CFG.batch:
 args.batch = 1 # default to 1 if not modified
-exporter = Exporter(overrides=args)
-return exporter(model=self.model)
+return Exporter(overrides=args)(model=self.model)
 def train(self, **kwargs):
 """
@@ -282,15 +287,15 @@ class YOLO:
 overrides.update(kwargs)
 if kwargs.get('cfg'):
 LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
-overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True)
-overrides['task'] = self.task
+overrides = yaml_load(check_yaml(kwargs['cfg']))
 overrides['mode'] = 'train'
 if not overrides.get('data'):
 raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
 if overrides.get('resume'):
 overrides['resume'] = self.ckpt_path
-self.trainer = self.TrainerClass(overrides=overrides)
+self.task = overrides.get('task') or self.task
+self.trainer = TASK_MAP[self.task][1](overrides=overrides)
 if not overrides.get('resume'): # manually set model only if not resuming
 self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
 self.model = self.trainer.model
@@ -311,13 +316,6 @@ class YOLO:
 self._check_is_pytorch_model()
 self.model.to(device)
-def _assign_ops_from_task(self):
-model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
-trainer_class = eval(train_lit.replace('TYPE', f'{self.type}'))
-validator_class = eval(val_lit.replace('TYPE', f'{self.type}'))
-predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}'))
-return model_class, trainer_class, validator_class, predictor_class
 @property
 def names(self):
 """
@@ -357,9 +355,8 @@ class YOLO:
 @staticmethod
 def _reset_ckpt_args(args):
-for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
-'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset', 'simplify':
-args.pop(arg, None)
+include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
+return {k: v for k, v in args.items() if k in include}
 @staticmethod
 def _reset_callbacks():
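A minimal sketch of the new _reset_ckpt_args behaviour: when a .pt checkpoint is loaded, only a small whitelist of its training arguments is kept (illustrative only).

def reset_ckpt_args(args):
    include = {'imgsz', 'data', 'task', 'single_cls'}  # same whitelist as above
    return {k: v for k, v in args.items() if k in include}

ckpt_args = {'imgsz': 640, 'task': 'detect', 'epochs': 100, 'device': 0, 'batch': 16}
print(reset_ckpt_args(ckpt_args))  # {'imgsz': 640, 'task': 'detect'}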

@@ -108,7 +108,6 @@ class BasePredictor:
 def postprocess(self, preds, img, orig_img):
 return preds
-@smart_inference_mode()
 def __call__(self, source=None, model=None, stream=False):
 if stream:
 return self.stream_inference(source, model)
@@ -136,6 +135,7 @@ class BasePredictor:
 self.source_type = self.dataset.source_type
 self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
+@smart_inference_mode()
 def stream_inference(self, source=None, model=None):
 if self.args.verbose:
 LOGGER.info('')
@@ -161,12 +161,14 @@ class BasePredictor:
 self.batch = batch
 path, im, im0s, vid_cap, s = batch
 visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
+# preprocess
 with self.dt[0]:
 im = self.preprocess(im)
 if len(im.shape) == 3:
 im = im[None] # expand for batch dim
-# Inference
+# inference
 with self.dt[1]:
 preds = self.model(im, augment=self.args.augment, visualize=visualize)

@@ -21,26 +21,30 @@ class Results:
 A class for storing and manipulating inference results.
 Args:
-boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
-masks (Masks, optional): A Masks object containing the detection masks.
-probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
-orig_img (tuple, optional): Original image size.
+orig_img (numpy.ndarray): The original image as a numpy array.
+path (str): The path to the image file.
+names (List[str]): A list of class names.
+boxes (List[List[float]], optional): A list of bounding box coordinates for each detection.
+masks (numpy.ndarray, optional): A 3D numpy array of detection masks, where each mask is a binary image.
+probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class.
 Attributes:
+orig_img (numpy.ndarray): The original image as a numpy array.
+orig_shape (tuple): The original image shape in (height, width) format.
 boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
 masks (Masks, optional): A Masks object containing the detection masks.
-probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
-orig_img (tuple, optional): Original image size.
-data (torch.Tensor): The raw masks tensor
+probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class.
+names (List[str]): A list of class names.
+path (str): The path to the image file.
+_keys (tuple): A tuple of attribute names for non-empty attributes.
 """
 def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None) -> None:
 self.orig_img = orig_img
 self.orig_shape = orig_img.shape[:2]
-self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
-self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
-self.probs = probs if probs is not None else None
+self.boxes = Boxes(boxes.cpu(), self.orig_shape) if boxes is not None else None # native size boxes
+self.masks = Masks(masks.cpu(), self.orig_shape) if masks is not None else None # native size or imgsz masks
+self.probs = probs.cpu() if probs is not None else None
 self.names = names
 self.path = path
 self._keys = (k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None)
@@ -99,24 +103,22 @@ class Results:
 def __getattr__(self, attr):
 name = self.__class__.__name__
-raise AttributeError(f"""
-'{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:
-Attributes:
-boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
-masks (Masks, optional): A Masks object containing the detection masks.
-probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
-orig_shape (tuple, optional): Original image size.
-""")
+raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
 def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
 """
-Plots the given result on an input RGB image. Accepts cv2(numpy) or PIL Image
+Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
 Args:
-show_conf (bool): Show confidence
-line_width (Float): The line width of boxes. Automatically scaled to img size if not provided
-font_size (Float): The font size of . Automatically scaled to img size if not provided
+show_conf (bool): Whether to show the detection confidence score.
+line_width (float, optional): The line width of the bounding boxes. If None, it is automatically scaled to the image size.
+font_size (float, optional): The font size of the text. If None, it is automatically scaled to the image size.
+font (str): The font to use for the text.
+pil (bool): Whether to return the image as a PIL Image.
+example (str): An example string to display in the plot. Useful for indicating the expected format of the output.
+Returns:
+None or PIL Image: If `pil` is True, the image will be returned as a PIL Image. Otherwise, nothing is returned.
 """
 img = deepcopy(self.orig_img)
 annotator = Annotator(img, line_width, font_size, font, pil, example)
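A usage sketch of the plot() API as documented above (assumes a local image 'bus.jpg'); per the docstring, a PIL Image is returned when pil=True.

from ultralytics import YOLO

model = YOLO('yolov8n.pt')
res = model('bus.jpg')[0]               # first Results object
res.plot(show_conf=True, line_width=2)  # draw detections on a copy of the original image
im = res.plot(pil=True)                 # per the docstring, returns a PIL Image when pil=True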
@@ -157,15 +159,24 @@ class Boxes:
 boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
 with shape (num_boxes, 6).
 orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width).
+is_track (bool): True if the boxes also include track IDs, False otherwise.
 Properties:
 xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format.
 conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes.
 cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes.
+id (torch.Tensor) or (numpy.ndarray): The track IDs of the boxes (if available).
 xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
 xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
 xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
 data (torch.Tensor): The raw bboxes tensor
+Methods:
+cpu(): Move the object to CPU memory.
+numpy(): Convert the object to a numpy array.
+cuda(): Move the object to CUDA memory.
+to(*args, **kwargs): Move the object to the specified device.
+pandas(): Convert the object to a pandas DataFrame (not yet implemented).
 """
 def __init__(self, boxes, orig_shape) -> None:
@ -257,22 +268,7 @@ class Boxes:
def __getattr__(self, attr): def __getattr__(self, attr):
name = self.__class__.__name__ name = self.__class__.__name__
raise AttributeError(f""" raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
'{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:
Attributes:
boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
with shape (num_boxes, 6).
orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width).
Properties:
xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format.
conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes.
cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes.
xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
""")
class Masks: class Masks:
@ -289,6 +285,17 @@ class Masks:
Properties: Properties:
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection masks. segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection masks.
Methods:
cpu(): Returns a copy of the masks tensor on CPU memory.
numpy(): Returns a copy of the masks tensor as a numpy array.
cuda(): Returns a copy of the masks tensor on GPU memory.
to(): Returns a copy of the masks tensor with the specified device and dtype.
__len__(): Returns the number of masks in the tensor.
__str__(): Returns a string representation of the masks tensor.
__repr__(): Returns a detailed string representation of the masks tensor.
__getitem__(): Returns a new Masks object with the masks at the specified index.
__getattr__(): Raises an AttributeError with a list of valid attributes and properties.
""" """
def __init__(self, masks, orig_shape) -> None: def __init__(self, masks, orig_shape) -> None:
@ -337,13 +344,4 @@ class Masks:
def __getattr__(self, attr): def __getattr__(self, attr):
name = self.__class__.__name__ name = self.__class__.__name__
raise AttributeError(f""" raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
'{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:
Attributes:
masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
orig_shape (tuple): Original image size, in the format (height, width).
Properties:
segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks.
""")

@ -44,7 +44,6 @@ class BaseTrainer:
Attributes: Attributes:
args (SimpleNamespace): Configuration for the trainer. args (SimpleNamespace): Configuration for the trainer.
check_resume (method): Method to check if training should be resumed from a saved checkpoint. check_resume (method): Method to check if training should be resumed from a saved checkpoint.
console (logging.Logger): Logger instance.
validator (BaseValidator): Validator instance. validator (BaseValidator): Validator instance.
model (nn.Module): Model instance. model (nn.Module): Model instance.
callbacks (defaultdict): Dictionary of callbacks. callbacks (defaultdict): Dictionary of callbacks.
@ -84,7 +83,6 @@ class BaseTrainer:
self.args = get_cfg(cfg, overrides) self.args = get_cfg(cfg, overrides)
self.device = select_device(self.args.device, self.args.batch) self.device = select_device(self.args.device, self.args.batch)
self.check_resume() self.check_resume()
self.console = LOGGER
self.validator = None self.validator = None
self.model = None self.model = None
self.metrics = None self.metrics = None
@ -180,11 +178,12 @@ class BaseTrainer:
if world_size > 1 and 'LOCAL_RANK' not in os.environ: if world_size > 1 and 'LOCAL_RANK' not in os.environ:
cmd, file = generate_ddp_command(world_size, self) # security vulnerability in Snyk scans cmd, file = generate_ddp_command(world_size, self) # security vulnerability in Snyk scans
try: try:
LOGGER.info(f'Running DDP command {cmd}')
subprocess.run(cmd, check=True) subprocess.run(cmd, check=True)
except Exception as e: except Exception as e:
self.console.warning(e) LOGGER.warning(e)
finally: finally:
ddp_cleanup(self, file) ddp_cleanup(self, str(file))
else: else:
self._do_train(RANK, world_size) self._do_train(RANK, world_size)
@ -193,7 +192,7 @@ class BaseTrainer:
# os.environ['MASTER_PORT'] = '9020' # os.environ['MASTER_PORT'] = '9020'
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
self.device = torch.device('cuda', rank) self.device = torch.device('cuda', rank)
self.console.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}') LOGGER.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size) dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)
def _setup_train(self, rank, world_size): def _setup_train(self, rank, world_size):
@ -262,7 +261,7 @@ class BaseTrainer:
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
last_opt_step = -1 last_opt_step = -1
self.run_callbacks('on_train_start') self.run_callbacks('on_train_start')
self.log(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n' LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n' f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
f"Logging results to {colorstr('bold', self.save_dir)}\n" f"Logging results to {colorstr('bold', self.save_dir)}\n"
f'Starting training for {self.epochs} epochs...') f'Starting training for {self.epochs} epochs...')
@ -278,14 +277,14 @@ class BaseTrainer:
pbar = enumerate(self.train_loader) pbar = enumerate(self.train_loader)
# Update dataloader attributes (optional) # Update dataloader attributes (optional)
if epoch == (self.epochs - self.args.close_mosaic): if epoch == (self.epochs - self.args.close_mosaic):
self.console.info('Closing dataloader mosaic') LOGGER.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'): if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'): if hasattr(self.train_loader.dataset, 'close_mosaic'):
self.train_loader.dataset.close_mosaic(hyp=self.args) self.train_loader.dataset.close_mosaic(hyp=self.args)
if rank in {-1, 0}: if rank in {-1, 0}:
self.console.info(self.progress_string()) LOGGER.info(self.progress_string())
pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT) pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
self.tloss = None self.tloss = None
self.optimizer.zero_grad() self.optimizer.zero_grad()
@ -372,12 +371,11 @@ class BaseTrainer:
if rank in {-1, 0}: if rank in {-1, 0}:
# Do final val with best.pt # Do final val with best.pt
self.log(f'\n{epoch - self.start_epoch + 1} epochs completed in ' LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.') f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
self.final_eval() self.final_eval()
if self.args.plots: if self.args.plots:
self.plot_metrics() self.plot_metrics()
self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
self.run_callbacks('on_train_end') self.run_callbacks('on_train_end')
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.run_callbacks('teardown') self.run_callbacks('teardown')
@ -450,18 +448,6 @@ class BaseTrainer:
self.best_fitness = fitness self.best_fitness = fitness
return metrics, fitness return metrics, fitness
def log(self, text, rank=-1):
"""
Logs the given text to given ranks process if provided, otherwise logs to all ranks.
Args"
text (str): text to log
rank (List[Int]): process rank
"""
if rank in {-1, 0}:
self.console.info(text)
def get_model(self, cfg=None, weights=None, verbose=True): def get_model(self, cfg=None, weights=None, verbose=True):
raise NotImplementedError("This task trainer doesn't support loading cfg files") raise NotImplementedError("This task trainer doesn't support loading cfg files")
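The removed log() helper is replaced throughout by direct, rank-gated LOGGER calls; a standalone sketch of that pattern (the logger name is arbitrary):
import logging

LOGGER = logging.getLogger('trainer')

def log_on_main(text, rank=-1):
    # only the main process writes to the console: rank -1 (single GPU/CPU) or rank 0 (DDP)
    if rank in {-1, 0}:
        LOGGER.info(text)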
@ -521,7 +507,7 @@ class BaseTrainer:
if f.exists(): if f.exists():
strip_optimizer(f) # strip optimizers strip_optimizer(f) # strip optimizers
if f is self.best: if f is self.best:
self.console.info(f'\nValidating {f}...') LOGGER.info(f'\nValidating {f}...')
self.metrics = self.validator(model=f) self.metrics = self.validator(model=f)
self.metrics.pop('fitness', None) self.metrics.pop('fitness', None)
self.run_callbacks('on_fit_epoch_end') self.run_callbacks('on_fit_epoch_end')
@ -564,7 +550,7 @@ class BaseTrainer:
self.best_fitness = best_fitness self.best_fitness = best_fitness
self.start_epoch = start_epoch self.start_epoch = start_epoch
if start_epoch > (self.epochs - self.args.close_mosaic): if start_epoch > (self.epochs - self.args.close_mosaic):
self.console.info('Closing dataloader mosaic') LOGGER.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'): if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'): if hasattr(self.train_loader.dataset, 'close_mosaic'):

@ -44,7 +44,6 @@ class BaseValidator:
Attributes: Attributes:
dataloader (DataLoader): Dataloader to use for validation. dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation. pbar (tqdm): Progress bar to update during validation.
logger (logging.Logger): Logger to use for validation.
args (SimpleNamespace): Configuration for the validator. args (SimpleNamespace): Configuration for the validator.
model (nn.Module): Model to validate. model (nn.Module): Model to validate.
data (dict): Data dictionary. data (dict): Data dictionary.
@ -56,7 +55,7 @@ class BaseValidator:
save_dir (Path): Directory to save results. save_dir (Path): Directory to save results.
""" """
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
""" """
Initializes a BaseValidator instance. Initializes a BaseValidator instance.
@ -69,14 +68,13 @@ class BaseValidator:
""" """
self.dataloader = dataloader self.dataloader = dataloader
self.pbar = pbar self.pbar = pbar
self.logger = logger or LOGGER
self.args = args or get_cfg(DEFAULT_CFG) self.args = args or get_cfg(DEFAULT_CFG)
self.model = None self.model = None
self.data = None self.data = None
self.device = None self.device = None
self.batch_i = None self.batch_i = None
self.training = True self.training = True
self.speed = None self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.jdict = None self.jdict = None
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
@ -123,14 +121,14 @@ class BaseValidator:
self.device = model.device self.device = model.device
if not pt and not jit: if not pt and not jit:
self.args.batch = 1 # export.py models default to batch-size 1 self.args.batch = 1 # export.py models default to batch-size 1
self.logger.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'): if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
self.data = check_det_dataset(self.args.data) self.data = check_det_dataset(self.args.data)
elif self.args.task == 'classify': elif self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data) self.data = check_cls_dataset(self.args.data)
else: else:
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌")) raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
if self.device.type == 'cpu': if self.device.type == 'cpu':
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
@ -179,7 +177,7 @@ class BaseValidator:
stats = self.get_stats() stats = self.get_stats()
self.check_stats(stats) self.check_stats(stats)
self.print_results() self.print_results()
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
self.finalize_metrics() self.finalize_metrics()
self.run_callbacks('on_val_end') self.run_callbacks('on_val_end')
if self.training: if self.training:
@ -187,11 +185,11 @@ class BaseValidator:
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')} results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
else: else:
self.logger.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' % LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
self.speed) tuple(self.speed.values()))
if self.args.save_json and self.jdict: if self.args.save_json and self.jdict:
with open(str(self.save_dir / 'predictions.json'), 'w') as f: with open(str(self.save_dir / 'predictions.json'), 'w') as f:
self.logger.info(f'Saving {f.name}...') LOGGER.info(f'Saving {f.name}...')
json.dump(self.jdict, f) # flatten and save json.dump(self.jdict, f) # flatten and save
stats = self.eval_json(stats) # update stats stats = self.eval_json(stats) # update stats
if self.args.plots or self.args.save_json: if self.args.plots or self.args.save_json:
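The validator's speed attribute changes from a tuple to a named dict, filled from the profiling timers and printed via tuple(self.speed.values()); a self-contained sketch of that bookkeeping (the Timer class and the numbers are stand-ins for the real profiling objects):
class Timer:  # stand-in for the profiling objects in dt, which expose a cumulative .t in seconds
    def __init__(self, t):
        self.t = t

speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
dt = (Timer(0.12), Timer(0.48), Timer(0.0), Timer(0.06))  # totals over the whole dataset
n_images = 100
speed = dict(zip(speed.keys(), (x.t / n_images * 1E3 for x in dt)))  # ms per image, keyed by stage
print('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
      tuple(speed.values()))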

@ -60,7 +60,7 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
# Export # Export
if format == '-': if format == '-':
filename = model.ckpt_path filename = model.ckpt_path or model.cfg
export = model # PyTorch format export = model # PyTorch format
else: else:
filename = model.export(imgsz=imgsz, format=format, half=half, device=device) # all others filename = model.export(imgsz=imgsz, format=format, half=half, device=device) # all others

@ -29,7 +29,7 @@ def on_pretrain_routine_start(trainer):
auto_connect_frameworks={'pytorch': False}) auto_connect_frameworks={'pytorch': False})
task.connect(vars(trainer.args), name='General') task.connect(vars(trainer.args), name='General')
except Exception as e: except Exception as e:
LOGGER.warning(f'WARNING ⚠️ ClearML not initialized correctly, not logging this run. {e}') LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
def on_train_epoch_end(trainer): def on_train_epoch_end(trainer):
@ -41,9 +41,9 @@ def on_fit_epoch_end(trainer):
task = Task.current_task() task = Task.current_task()
if task and trainer.epoch == 0: if task and trainer.epoch == 0:
model_info = { model_info = {
'Parameters': get_num_params(trainer.model), 'model/parameters': get_num_params(trainer.model),
'GFLOPs': round(get_flops(trainer.model), 3), 'model/GFLOPs': round(get_flops(trainer.model), 3),
'Inference speed (ms/img)': round(trainer.validator.speed[1], 3)} 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
task.connect(model_info, name='Model') task.connect(model_info, name='Model')

@ -16,7 +16,7 @@ def on_pretrain_routine_start(trainer):
experiment = comet_ml.Experiment(project_name=trainer.args.project or 'YOLOv8') experiment = comet_ml.Experiment(project_name=trainer.args.project or 'YOLOv8')
experiment.log_parameters(vars(trainer.args)) experiment.log_parameters(vars(trainer.args))
except Exception as e: except Exception as e:
LOGGER.warning(f'WARNING ⚠️ Comet not initialized correctly, not logging this run. {e}') LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')
def on_train_epoch_end(trainer): def on_train_epoch_end(trainer):
@ -36,7 +36,7 @@ def on_fit_epoch_end(trainer):
model_info = { model_info = {
'model/parameters': get_num_params(trainer.model), 'model/parameters': get_num_params(trainer.model),
'model/GFLOPs': round(get_flops(trainer.model), 3), 'model/GFLOPs': round(get_flops(trainer.model), 3),
'model/speed(ms)': round(trainer.validator.speed[1], 3)} 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
experiment.log_metrics(model_info, step=trainer.epoch + 1) experiment.log_metrics(model_info, step=trainer.epoch + 1)
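Both the ClearML and Comet callbacks now log the same namespaced model-info payload and read the inference time from the new speed dict; a sketch with a stand-in model and illustrative numbers (get_num_params mirrors the torch_utils helper, get_flops is omitted):
import torch.nn as nn

def get_num_params(model):  # mirrors ultralytics.yolo.utils.torch_utils.get_num_params
    return sum(p.numel() for p in model.parameters())

model = nn.Linear(10, 2)                                                        # stand-in model
speed = {'preprocess': 0.4, 'inference': 5.1, 'loss': 0.0, 'postprocess': 1.2}  # ms per image
model_info = {
    'model/parameters': get_num_params(model),
    'model/GFLOPs': 0.0,  # round(get_flops(model), 3) in the real callbacks
    'model/speed(ms)': round(speed['inference'], 3)}
print(model_info)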

@ -2,17 +2,24 @@
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from ultralytics.yolo.utils import LOGGER
writer = None # TensorBoard SummaryWriter instance writer = None # TensorBoard SummaryWriter instance
def _log_scalars(scalars, step=0): def _log_scalars(scalars, step=0):
if writer:
for k, v in scalars.items(): for k, v in scalars.items():
writer.add_scalar(k, v, step) writer.add_scalar(k, v, step)
def on_pretrain_routine_start(trainer): def on_pretrain_routine_start(trainer):
global writer global writer
try:
writer = SummaryWriter(str(trainer.save_dir)) writer = SummaryWriter(str(trainer.save_dir))
except Exception as e:
writer = None # TensorBoard SummaryWriter instance
LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
def on_batch_end(trainer): def on_batch_end(trainer):
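A standalone sketch of the guarded TensorBoard pattern introduced above: the writer is created inside try/except, and _log_scalars becomes a no-op when initialization failed (the save_dir default is a placeholder path):
import logging

LOGGER = logging.getLogger('tensorboard')
writer = None  # TensorBoard SummaryWriter instance

def _log_scalars(scalars, step=0):
    if writer:  # silently skip when TensorBoard failed to initialize
        for k, v in scalars.items():
            writer.add_scalar(k, v, step)

def on_pretrain_routine_start(save_dir='runs/train'):
    global writer
    try:
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter(str(save_dir))
    except Exception as e:
        writer = None
        LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')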

@ -254,7 +254,7 @@ def check_file(file, suffix='', download=True):
return file return file
else: # search else: # search
files = [] files = []
for d in 'models', 'datasets', 'tracker/cfg': # search directories for d in 'models', 'datasets', 'tracker/cfg', 'yolo/cfg': # search directories
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
if not files: if not files:
raise FileNotFoundError(f"'{file}' does not exist") raise FileNotFoundError(f"'{file}' does not exist")
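The config search in check_file gains a yolo/cfg directory; a sketch of that recursive lookup, with ROOT and the file name treated as placeholders:
import glob
from pathlib import Path

ROOT = Path('ultralytics')  # placeholder for the installed package root
file = 'default.yaml'       # placeholder config name
files = []
for d in ('models', 'datasets', 'tracker/cfg', 'yolo/cfg'):  # search directories
    files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))
print(files or f"'{file}' does not exist")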

@ -51,10 +51,9 @@ def generate_ddp_command(world_size, trainer):
file = generate_ddp_file(trainer) # if argv[0].endswith('yolo') else os.path.abspath(argv[0]) file = generate_ddp_file(trainer) # if argv[0].endswith('yolo') else os.path.abspath(argv[0])
# Build command # Build command
torch_distributed_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch' dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
cmd = [ port = find_free_network_port()
sys.executable, '-m', torch_distributed_cmd, '--nproc_per_node', f'{world_size}', '--master_port', cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
f'{find_free_network_port()}', file] + args
return cmd, file return cmd, file
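The DDP command is now assembled on one line with a pre-computed free port; a self-contained sketch (find_free_network_port is reimplemented locally, and the trainer file and args are placeholders):
import socket
import sys

def find_free_network_port() -> int:
    # bind to port 0 and let the OS pick an unused port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('127.0.0.1', 0))
        return s.getsockname()[1]

TORCH_1_9 = True                      # assume torch>=1.9 for this sketch
world_size, file, args = 2, '_temp_ddp_file.py', []
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
port = find_free_network_port()
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
print(cmd)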

@ -12,7 +12,7 @@ import requests
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from ultralytics.yolo.utils import LOGGER from ultralytics.yolo.utils import LOGGER, checks
GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \ GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \
[f'yolov5{size}u.pt' for size in 'nsmlx'] + \ [f'yolov5{size}u.pt' for size in 'nsmlx'] + \
@ -87,7 +87,7 @@ def safe_download(url,
try: try:
if curl or i > 0: # curl download with retry, continue if curl or i > 0: # curl download with retry, continue
s = 'sS' * (not progress) # silent s = 'sS' * (not progress) # silent
r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '9', '-C', '-']).returncode r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode
assert r == 0, f'Curl return value {r}' assert r == 0, f'Curl return value {r}'
else: # urllib download else: # urllib download
method = 'torch' method = 'torch'
@ -112,8 +112,10 @@ def safe_download(url,
break # success break # success
f.unlink() # remove partial downloads f.unlink() # remove partial downloads
except Exception as e: except Exception as e:
if i >= retry: if i == 0 and not checks.check_online():
raise ConnectionError(f'❌ Download failure for {url}') from e raise ConnectionError(f'❌ Download failure for {url}. Environment is not online.') from e
elif i >= retry:
raise ConnectionError(f'❌ Download failure for {url}. Retry limit reached.') from e
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...') LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
if unzip and f.exists() and f.suffix in {'.zip', '.tar', '.gz'}: if unzip and f.exists() and f.suffix in {'.zip', '.tar', '.gz'}:
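safe_download now fails fast with a clearer message when the environment is offline and otherwise retries up to the limit (the curl retry count also drops from 9 to 3); a sketch of that control flow, with check_online reimplemented locally as a stand-in for ultralytics.yolo.utils.checks.check_online and download_once standing in for the actual transfer:
import logging
import socket

LOGGER = logging.getLogger('downloads')

def check_online() -> bool:  # local stand-in: can we open a socket to a public host?
    try:
        socket.create_connection(('1.1.1.1', 443), timeout=2).close()
        return True
    except OSError:
        return False

def download_with_retry(download_once, url, retry=3):
    for i in range(retry + 1):
        try:
            return download_once(url)
        except Exception as e:
            if i == 0 and not check_online():
                raise ConnectionError(f'❌ Download failure for {url}. Environment is not online.') from e
            elif i >= retry:
                raise ConnectionError(f'❌ Download failure for {url}. Retry limit reached.') from e
            LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')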

@ -7,7 +7,7 @@ from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
from ultralytics.yolo import v8 from ultralytics.yolo import v8
from ultralytics.yolo.data import build_classification_dataloader from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.trainer import BaseTrainer from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import DEFAULT_CFG, RANK from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer
@ -64,6 +64,7 @@ class ClassificationTrainer(BaseTrainer):
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None) self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else: else:
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.') FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
return # dont return ckpt. Classification doesn't support resume return # dont return ckpt. Classification doesn't support resume
@ -93,7 +94,7 @@ class ClassificationTrainer(BaseTrainer):
def get_validator(self): def get_validator(self):
self.loss_names = ['loss'] self.loss_names = ['loss']
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console) return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
def criterion(self, preds, batch): def criterion(self, preds, batch):
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
@ -132,11 +133,12 @@ class ClassificationTrainer(BaseTrainer):
strip_optimizer(f) # strip optimizers strip_optimizer(f) # strip optimizers
# TODO: validate best.pt after training completes # TODO: validate best.pt after training completes
# if f is self.best: # if f is self.best:
# self.console.info(f'\nValidating {f}...') # LOGGER.info(f'\nValidating {f}...')
# self.validator.args.save_json = True # self.validator.args.save_json = True
# self.metrics = self.validator(model=f) # self.metrics = self.validator(model=f)
# self.metrics.pop('fitness', None) # self.metrics.pop('fitness', None)
# self.run_callbacks('on_fit_epoch_end') # self.run_callbacks('on_fit_epoch_end')
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
def train(cfg=DEFAULT_CFG, use_python=False): def train(cfg=DEFAULT_CFG, use_python=False):
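The new ClassificationModel.reshape_outputs call adapts a torchvision backbone's classification head to the dataset's class count; roughly, for a ResNet-style model it amounts to the swap below (nc and the backbone are illustrative):
import torch.nn as nn
import torchvision

nc = 10                                               # number of classes in the dataset
model = torchvision.models.resnet18(weights=None)     # stand-in backbone
if isinstance(model.fc, nn.Linear) and model.fc.out_features != nc:
    model.fc = nn.Linear(model.fc.in_features, nc)    # replace the final layer to output nc classes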

@ -2,14 +2,14 @@
from ultralytics.yolo.data import build_classification_dataloader from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.validator import BaseValidator from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import DEFAULT_CFG from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
from ultralytics.yolo.utils.metrics import ClassifyMetrics from ultralytics.yolo.utils.metrics import ClassifyMetrics
class ClassificationValidator(BaseValidator): class ClassificationValidator(BaseValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args) super().__init__(dataloader, save_dir, pbar, args)
self.args.task = 'classify' self.args.task = 'classify'
self.metrics = ClassifyMetrics() self.metrics = ClassifyMetrics()
@ -31,7 +31,7 @@ class ClassificationValidator(BaseValidator):
self.targets.append(batch['cls']) self.targets.append(batch['cls'])
def finalize_metrics(self, *args, **kwargs): def finalize_metrics(self, *args, **kwargs):
self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed)) self.metrics.speed = self.speed
def get_stats(self): def get_stats(self):
self.metrics.process(self.targets, self.pred) self.metrics.process(self.targets, self.pred)
@ -45,7 +45,7 @@ class ClassificationValidator(BaseValidator):
def print_results(self): def print_results(self):
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
self.logger.info(pf % ('all', self.metrics.top1, self.metrics.top5)) LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
def val(cfg=DEFAULT_CFG, use_python=False): def val(cfg=DEFAULT_CFG, use_python=False):

@ -66,10 +66,7 @@ class DetectionTrainer(BaseTrainer):
def get_validator(self): def get_validator(self):
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
return v8.detect.DetectionValidator(self.test_loader, return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
save_dir=self.save_dir,
logger=self.console,
args=copy(self.args))
def criterion(self, preds, batch): def criterion(self, preds, batch):
if not hasattr(self, 'compute_loss'): if not hasattr(self, 'compute_loss'):

@ -9,7 +9,7 @@ import torch
from ultralytics.yolo.data import build_dataloader from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
from ultralytics.yolo.engine.validator import BaseValidator from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import DEFAULT_CFG, colorstr, ops from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, colorstr, ops
from ultralytics.yolo.utils.checks import check_requirements from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@ -18,8 +18,8 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
class DetectionValidator(BaseValidator): class DetectionValidator(BaseValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args) super().__init__(dataloader, save_dir, pbar, args)
self.args.task = 'detect' self.args.task = 'detect'
self.is_coco = False self.is_coco = False
self.class_map = None self.class_map = None
@ -112,7 +112,7 @@ class DetectionValidator(BaseValidator):
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def finalize_metrics(self, *args, **kwargs): def finalize_metrics(self, *args, **kwargs):
self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed)) self.metrics.speed = self.speed
def get_stats(self): def get_stats(self):
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
@ -123,15 +123,15 @@ class DetectionValidator(BaseValidator):
def print_results(self): def print_results(self):
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
self.logger.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
if self.nt_per_class.sum() == 0: if self.nt_per_class.sum() == 0:
self.logger.warning( LOGGER.warning(
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels') f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
# Print results per class # Print results per class
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats): if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
for i, c in enumerate(self.metrics.ap_class_index): for i, c in enumerate(self.metrics.ap_class_index):
self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i))) LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
if self.args.plots: if self.args.plots:
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values())) self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
@ -212,7 +212,7 @@ class DetectionValidator(BaseValidator):
if self.args.save_json and self.is_coco and len(self.jdict): if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions pred_json = self.save_dir / 'predictions.json' # predictions
self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6') check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa from pycocotools.coco import COCO # noqa
@ -230,7 +230,7 @@ class DetectionValidator(BaseValidator):
eval.summarize() eval.summarize()
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50 stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e: except Exception as e:
self.logger.warning(f'pycocotools unable to run: {e}') LOGGER.warning(f'pycocotools unable to run: {e}')
return stats return stats
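For reference, the pycocotools evaluation these validators drive looks roughly like the sketch below (paths are placeholders; the validator writes predictions.json and reads the standard COCO annotations):
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

anno = COCO('annotations/instances_val2017.json')        # ground-truth annotations
pred = anno.loadRes('runs/detect/val/predictions.json')  # predictions saved by the validator
coco_eval = COCOeval(anno, pred, 'bbox')                 # 'segm' for the segmentation metrics
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
map50_95, map50 = coco_eval.stats[:2]                    # the two values folded back into stats above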

@ -68,11 +68,10 @@ class SegmentationPredictor(DetectionPredictor):
log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, " log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "
# Mask plotting # Mask plotting
self.annotator.masks( if self.args.save or self.args.show:
mask.masks, im_gpu = torch.as_tensor(im0, dtype=torch.float16, device=mask.masks.device).permute(
colors=[colors(x, True) for x in det.cls], 2, 0, 1).flip(0).contiguous() / 255 if self.args.retina_masks else im[idx]
im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() / self.annotator.masks(masks=mask.masks, colors=[colors(x, True) for x in det.cls], im_gpu=im_gpu)
255 if self.args.retina_masks else im[idx])
# Write results # Write results
for j, d in enumerate(reversed(det)): for j, d in enumerate(reversed(det)):
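The mask overlay is now drawn only when save or show is requested, and the GPU image tensor is built as in the sketch below (im0 is assumed to be an HxWx3 BGR uint8 frame; .flip(0) converts BGR to RGB after the channels are moved first):
import numpy as np
import torch

im0 = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder BGR frame
im_gpu = torch.as_tensor(im0, dtype=torch.float16).permute(2, 0, 1).flip(0).contiguous() / 255
print(im_gpu.shape, im_gpu.dtype)              # torch.Size([3, 480, 640]) torch.float16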

@ -32,10 +32,7 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
def get_validator(self): def get_validator(self):
self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
return v8.segment.SegmentationValidator(self.test_loader, return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
save_dir=self.save_dir,
logger=self.console,
args=copy(self.args))
def criterion(self, preds, batch): def criterion(self, preds, batch):
if not hasattr(self, 'compute_loss'): if not hasattr(self, 'compute_loss'):
@ -86,10 +83,6 @@ class SegLoss(Loss):
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
masks = batch['masks'].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
# pboxes # pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
@ -103,10 +96,15 @@ class SegLoss(Loss):
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# bbox loss
if fg_mask.sum(): if fg_mask.sum():
# bbox loss
loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor, loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
target_scores, target_scores_sum, fg_mask) target_scores, target_scores_sum, fg_mask)
# masks loss
masks = batch['masks'].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
for i in range(batch_size): for i in range(batch_size):
if fg_mask[i].sum(): if fg_mask[i].sum():
mask_idx = target_gt_idx[i][fg_mask[i]] mask_idx = target_gt_idx[i][fg_mask[i]]
@ -121,9 +119,9 @@ class SegLoss(Loss):
marea) # seg loss marea) # seg loss
# WARNING: Uncomment lines below in case of Multi-GPU DDP unused gradient errors # WARNING: Uncomment lines below in case of Multi-GPU DDP unused gradient errors
# else: # else:
# loss[1] += proto.sum() * 0 # loss[1] += proto.sum() * 0 + pred_masks.sum() * 0
# else: # else:
# loss[1] += proto.sum() * 0 # loss[1] += proto.sum() * 0 + pred_masks.sum() * 0
loss[0] *= self.hyp.box # box gain loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.box / batch_size # seg gain loss[1] *= self.hyp.box / batch_size # seg gain
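The ground-truth masks are now loaded and downsampled to the prototype resolution only inside the foreground branch; a sketch of that downsampling step (shapes are illustrative, 'nearest' matches the mode used above):
import torch
import torch.nn.functional as F

batch_masks = torch.rand(4, 640, 640).round()   # per-image GT masks at input resolution
mask_h, mask_w = 160, 160                       # prototype (proto) resolution
if tuple(batch_masks.shape[-2:]) != (mask_h, mask_w):
    batch_masks = F.interpolate(batch_masks[None], (mask_h, mask_w), mode='nearest')[0]
print(batch_masks.shape)                        # torch.Size([4, 160, 160])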

@ -7,7 +7,7 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from ultralytics.yolo.utils import DEFAULT_CFG, NUM_THREADS, ops from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, NUM_THREADS, ops
from ultralytics.yolo.utils.checks import check_requirements from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@ -16,8 +16,8 @@ from ultralytics.yolo.v8.detect import DetectionValidator
class SegmentationValidator(DetectionValidator): class SegmentationValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args) super().__init__(dataloader, save_dir, pbar, args)
self.args.task = 'segment' self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir) self.metrics = SegmentMetrics(save_dir=self.save_dir)
@ -120,7 +120,7 @@ class SegmentationValidator(DetectionValidator):
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def finalize_metrics(self, *args, **kwargs): def finalize_metrics(self, *args, **kwargs):
self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed)) self.metrics.speed = self.speed
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False): def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
""" """
@ -207,7 +207,7 @@ class SegmentationValidator(DetectionValidator):
if self.args.save_json and self.is_coco and len(self.jdict): if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions pred_json = self.save_dir / 'predictions.json' # predictions
self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6') check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa from pycocotools.coco import COCO # noqa
@ -228,7 +228,7 @@ class SegmentationValidator(DetectionValidator):
stats[self.metrics.keys[idx + 1]], stats[ stats[self.metrics.keys[idx + 1]], stats[
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50 self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e: except Exception as e:
self.logger.warning(f'pycocotools unable to run: {e}') LOGGER.warning(f'pycocotools unable to run: {e}')
return stats return stats
